#!/usr/bin/python

import base64
import os
import socket
import sys
import argparse
import json
import porto
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from SocketServer import ThreadingMixIn
from urlparse import urlparse, parse_qsl
import logging
from threading import RLock

from infra.qyp.proto_lib import vmagent_pb2, vmagent_api_pb2
from library.python import resource

from infra.qyp.vmagent.src import log as setup_logging
from infra.qyp.vmagent.src.auth import AuthError, TVMAuthContext, VMAGENT_ID, VMPROXY_ID
from infra.qyp.vmagent.src.helpers import log_msg, log_trace, \
    create_ip6tables_ssh_syn_rule, get_ip6tables_ssh_syn_rule_count
from infra.qyp.vmagent.src.process import VMWorker
from infra.qyp.vmagent.src.config import VmagentContext
from infra.qyp.vmagent.src import resource_manager
from infra.qyp.vmagent.src import qemu_ctl
from infra.qyp.vmagent.src import qemu_launcher
from infra.qyp.vmagent.src import volume_manager

# Module-level protobuf instances.
# NOTE(review): `config` and `state` do not appear to be read by the handlers in
# this file (they fetch config/state from the worker into locals) — confirm they
# are unused before relying on or removing them.
config = vmagent_pb2.VMConfig()
state = vmagent_pb2.VMState()

# Maps VMState type enum values to the lowercase names used to build the
# "<name>_ammv" 0/1 signals reported by the /unistat handler.
STATE_MAP = {
    vmagent_pb2.VMState.EMPTY: "empty",
    vmagent_pb2.VMState.CONFIGURED: "configured",
    vmagent_pb2.VMState.STOPPED: "stopped",
    vmagent_pb2.VMState.RUNNING: "running",
    vmagent_pb2.VMState.BUSY: "busy",
    vmagent_pb2.VMState.PREPARING: "preparing",
    vmagent_pb2.VMState.CRASHED: "crashed",
    vmagent_pb2.VMState.INVALID: "invalid"
}


class ActionError(Exception):
    """Raised when the worker refuses a VM action; the single argument is the worker's error value."""


class RequestHandlerClass(BaseHTTPRequestHandler):
    """HTTP request handler exposing the vmagent control API.

    Requests are dispatched through the ``GET_noauth_paths``, ``GET_paths`` and
    ``POST_paths`` tables defined below. Every route except those in
    ``GET_noauth_paths`` is authenticated with :meth:`auth_request` first.
    Handlers reach the VM worker through ``self.server`` (a ``MyHTTPServer``).
    Every handler sends exactly one response.
    """
    server = None  # type: MyHTTPServer

    @staticmethod
    def _exception_text(exc):
        """Return a non-empty, human-readable description of ``exc`` for an error reply.

        Prefers the Python 2 ``message`` attribute (keeps unicode text intact).
        Falls back to ``str(exc)``, then to the exception class name.
        """
        message = getattr(exc, 'message', None)
        if message:
            return message
        try:
            text = str(exc)
        except UnicodeError:  # non-ASCII unicode args cannot be str()'d on Python 2
            text = ''
        return text or exc.__class__.__name__

    def send_reply(self, data, msg="OK", encode=True):
        """Send a ``200`` text/plain response.

        :param data: response body (byte string).
        :param msg: reason phrase for the status line.
        :param encode: when true the body is sent base64-encoded with
            ``base64.encodestring`` (which inserts line breaks); when false
            ``data`` is written verbatim.
        """
        self.send_response(200, msg)
        self.send_header("Content-Type", "text/plain")
        encoded = base64.encodestring(data) if encode else data
        self.send_header("Content-Length", len(encoded))
        self.end_headers()
        self.wfile.write(encoded)

    def reply_error(self, code, message):
        """Send an error response whose plain-text body is ``message`` plus a newline.

        :param code: HTTP status code.
        :param message: error text; unicode is sent as UTF-8, other non-string
            values as their ``str()`` form.
        """
        # Built by hand instead of via send_error(): send_error() already ends the
        # headers and writes an HTML body, so nothing can be appended after it.
        if not isinstance(message, str):
            message = message.encode('utf-8') if hasattr(message, 'encode') else str(message)
        body = message + '\n'
        self.log_error("code %d, message %s", code, message)
        self.send_response(code)
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", len(body))
        # send_error() used to add this header; keep closing the connection after errors.
        self.send_header("Connection", "close")
        self.end_headers()
        self.wfile.write(body)

    def _live_worker_or_restart(self):
        """Return the server's worker if it is alive.

        Otherwise reply ``500 Worker is dead, restarting``, ask the server to
        restart the worker and return ``None``. In that case the caller must not
        send another response.
        """
        worker = self.server.worker  # read once: other threads may reset it to None
        if worker is not None and worker.is_alive():
            return worker
        self.reply_error(500, 'Worker is dead, restarting')
        self.server.start_worker()
        return None

    def handle_action(self):
        """Execute a VM action sent as a base64-encoded ``VMActionRequest`` body.

        ``STOP_QDMUPLOAD`` is routed to ``worker.stop_backup()``; any other action
        is queued with ``worker.push_task()``. A truthy return value from either
        call, or any exception while parsing or dispatching, yields a 400 reply.
        Success yields an empty 200 reply.
        """
        worker = self._live_worker_or_restart()
        if worker is None:
            return

        try:
            req = vmagent_api_pb2.VMActionRequest()
            data_len = int(self.headers.get("Content-Length", 0))
            payload = self.rfile.read(data_len)
            req.ParseFromString(base64.decodestring(payload))
            if req.action == vmagent_api_pb2.VMActionRequest.STOP_QDMUPLOAD:
                ret = worker.stop_backup()
            else:
                ret = worker.push_task(req)

            if ret:
                raise ActionError(ret)
        except Exception as e:
            self.reply_error(400, self._exception_text(e))
            return

        self.send_reply("")

    def handle_status(self):
        """Reply with a base64-encoded ``VMStatusResponse``.

        The response carries the worker's config, state, data transfer state and
        the vmagent version. With ``?find_issues=<non-zero int>`` the worker is
        also asked to annotate the reported state via ``find_issues``. Failures
        are logged and answered with 500.
        """
        worker = self._live_worker_or_restart()
        if worker is None:
            return

        try:
            find_issues = int(self.url_params.get('find_issues', '0'))
            resp = vmagent_api_pb2.VMStatusResponse()
            resp.config.CopyFrom(worker.get_config())
            resp.state.CopyFrom(worker.get_state())
            resp.data_transfer_state.CopyFrom(worker.get_data_transfer_state())
            resp.vmagent_version = VmagentContext.VMAGENT_VERSION

            if find_issues:
                worker.find_issues(resp.state)

            self.send_reply(resp.SerializeToString())
        except Exception as e:
            log_trace(sys.exc_info()[2])
            self.reply_error(500, self._exception_text(e))

    def handle_ping(self):
        """Liveness probe: restart a dead or missing worker, then always reply ``PONG``."""
        if not self.server.worker_is_alive():
            self.server.start_worker()

        self.send_reply("PONG", encode=False)

    def handle_unistat(self):
        """Reply with unistat signals as a JSON list of ``[name, value]`` pairs.

        Signals are one 0/1 ``<state>_ammv`` flag per ``STATE_MAP`` entry, plus
        ``link_alive_ammv`` and ``net_alive_ammv``. These are omitted while no
        worker exists. ``ssh_syn_count_dmmm`` is added when the ip6tables counter
        is available.
        """
        signals = []

        worker = self.server.worker  # read once: other threads may reset it to None
        if worker is not None:
            vm_state = worker.get_state()
            for state_type, name in STATE_MAP.items():
                signals.append(["{0}_ammv".format(name), 1 if vm_state.type == state_type else 0])
            signals.append(["link_alive_ammv", 1 if vm_state.link_alive else 0])
            signals.append(["net_alive_ammv", 1 if vm_state.net_alive else 0])

        ssh_syn_count = get_ip6tables_ssh_syn_rule_count()
        if ssh_syn_count is not None:
            signals.append(["ssh_syn_count_dmmm", ssh_syn_count])

        self.send_reply(json.dumps(signals), encode=False)

    def handle_debug(self):
        """Run the maintenance command named by ``?cmd=``.

        Supported commands are ``shutdown``, ``restart_worker`` and ``purge``.
        The empty reply is sent before the command runs. Unknown commands get 400.
        """
        cmd = self.url_params.get('cmd', None)

        if cmd == 'shutdown':
            self.send_reply("")
            self.server.shutdown()
        elif cmd == 'restart_worker':
            self.send_reply("")
            self.server.start_worker()
        elif cmd == 'purge':
            self.send_reply("")
            self.server.worker.emergency_purge()
        else:
            self.reply_error(400, 'No such command')

    def handle_default(self):
        """Fallback for unknown paths: 404."""
        self.send_error(404)

    def handle_iss_hook_notify(self):
        """Handle an ISS dynamic-resource change notification.

        ``?files_base64_encoded=`` is base64 of whitespace-separated tokens of the
        form ``<flag><resource name>``. The flag is ``!`` (change), ``+`` (add) or
        ``-`` (delete); any other flag is reported as ``undefined``.

        - A changed ``vmagent`` resource stops the worker and the web server.
        - A changed ``vm_config_id`` restarts the worker.
        - Otherwise the parsed ``{resource: status}`` map is echoed back as JSON.

        A missing or undecodable parameter is answered with 400.
        """
        files_base64_encoded = self.url_params.get('files_base64_encoded', None)
        if files_base64_encoded is None:
            self.reply_error(400, 'Missing files_base64_encoded parameter')
            return
        try:
            files_changes = base64.b64decode(files_base64_encoded)
        except (TypeError, ValueError) as e:  # Py2 raises TypeError, Py3 binascii.Error
            self.reply_error(400, 'Malformed files_base64_encoded: {}'.format(self._exception_text(e)))
            return

        statuses = {"!": 'change', "+": 'add', "-": 'delete'}
        result = {}
        # split() rather than split(' '): it never yields empty tokens, which would
        # crash on file_info[0].
        for file_info in files_changes.split():
            resource_name, resource_status = file_info[1:], statuses.get(file_info[0], 'undefined')
            result[resource_name] = resource_status
            self.server.log.info("Dynamic Resource {} status: {}".format(resource_name, resource_status))

        if result.get('vmagent') == 'change':
            msg = "Vmagent Resource has been changed. Shutting down Worker and web server..."
            self.send_reply(msg)
            self.server.log.info(msg)
            # Raise the flag before stopping, so that a concurrent /ping or /status
            # cannot start a fresh worker in between.
            with self.server.worker_lock:
                self.server.worker_shutdown_needed = True
                self.server.stop_worker()
            self.server.log.info('Worker success stopped')
            self.server.shutdown()
            self.server.log.info('Web Server success stopped')
        elif result.get('vm_config_id') == 'change':
            msg = "Run handle config"
            self.server.start_worker()
            self.send_reply(msg)
        else:
            self.send_reply(json.dumps(result))

    def handle_graceful_stop(self):
        """
        Described in QEMUKVM-469
        1. Call worker's graceful_stop method which stops all running tasks and shutdown vm
        2. Set server's flag to prevent worker restart
        3. It is designed to be synchronous call, will reply after processing ends
        """
        self.server.log.info('Signal graceful stop received. Shutting down worker...')
        self.server.on_iss_hook_stop()
        self.send_reply('VM has been shut down')

    # NOTE(review): /iss_hook_notify is unauthenticated yet can shut down the
    # worker and the server — confirm the port is reachable only by trusted callers.
    GET_noauth_paths = {
        "/ping": handle_ping,
        "/unistat": handle_unistat,
        "/iss_hook_notify": handle_iss_hook_notify,
    }
    GET_paths = {
        "/status": handle_status,
    }
    POST_paths = {
        "/graceful_stop": handle_graceful_stop,
        "/action": handle_action,
        "/debug": handle_debug
    }

    def auth_request(self):
        """Authenticate the request using the server's TVM auth context.

        A local token is checked in preference to a TVM service ticket; a ticket
        is verified against ``VMPROXY_ID``. Callers treat ``AuthError`` as 401.

        NOTE(review): if ``extract_tvm_auth`` can return neither credential, no
        check is performed and the request is accepted — confirm it raises
        ``AuthError`` in that case.
        """
        tvm, local = self.server.tvm_auth.extract_tvm_auth(self.headers)
        if local:
            self.server.tvm_auth.verify_local_token(local)

        elif tvm:
            self.server.tvm_auth.verify_request_ticket(tvm, VMPROXY_ID)

    def _parse_request_url(self):
        """Populate ``parsed_url`` and ``url_params`` (query as a flat dict) from ``self.path``."""
        self.parsed_url = urlparse(self.path)
        self.url_params = dict(parse_qsl(self.parsed_url.query))

    def do_GET(self):
        """Dispatch a GET request; routes outside ``GET_noauth_paths`` require auth."""
        self._parse_request_url()

        noauth_handler = self.GET_noauth_paths.get(self.parsed_url.path)
        if noauth_handler:
            noauth_handler(self)
            return

        try:
            self.auth_request()
        except AuthError as e:
            self.reply_error(401, self._exception_text(e))
            return

        self.GET_paths.get(self.parsed_url.path, RequestHandlerClass.handle_default)(self)

    def do_POST(self):
        """Authenticate, then dispatch a POST request via ``POST_paths``."""
        try:
            self.auth_request()
        except AuthError as e:
            self.reply_error(401, self._exception_text(e))
            return

        self._parse_request_url()

        self.POST_paths.get(self.parsed_url.path, RequestHandlerClass.handle_default)(self)


class MyHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded IPv6 HTTP server that owns and supervises the VM worker.

    Request handlers run in their own threads and reach the worker through
    ``self.server``. All replacement of ``worker`` happens under
    ``worker_lock``. Once ``worker_shutdown_needed`` is set, ``start_worker``
    only stops the current worker and never starts a new one.
    """
    address_family = socket.AF_INET6
    test = False
    worker = None  # type: VMWorker
    mode = "gencfg"
    log = logging.getLogger('VMAgentServer')

    worker_lock = RLock()  # re-entrant: start_worker() calls stop_worker() while holding it
    worker_shutdown_needed = False

    def server_bind(self):
        """Bind the listening socket with SO_REUSEADDR.

        This replaces HTTPServer.server_bind entirely.
        NOTE(review): ``server_name``/``server_port`` are therefore not set here.
        """
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)

    def worker_is_alive(self):
        """Return True if a worker exists and reports itself alive."""
        worker = self.worker  # read once: another thread may reset it to None
        return worker is not None and worker.is_alive()

    def stop_worker(self):
        """Stop the current worker, if any, and forget it."""
        with self.worker_lock:
            if self.worker:
                self.worker.stop()
                self.worker = None

    def on_iss_hook_stop(self):
        """Forbid further worker restarts and forward the stop hook to the current worker."""
        with self.worker_lock:
            self.worker_shutdown_needed = True
            if self.worker:
                self.worker.on_iss_hook_stop()

    def start_worker(self, context=None):
        """(Re)start the VM worker.

        The existing worker is stopped first. Nothing new is started once
        ``worker_shutdown_needed`` is set.

        :param context: ``VmagentContext`` to run with; defaults to
            ``VmagentContext.build_from_pod_spec()``.
        """
        # The whole stop/check/create sequence runs under the lock. Otherwise two
        # handlers restarting a dead worker concurrently could each create one,
        # leaking the first.
        with self.worker_lock:
            self.stop_worker()
            if self.worker_shutdown_needed:
                return
            context = context or VmagentContext.build_from_pod_spec()
            _resource_manager = resource_manager.ResourceManager(
                vm_id=context.VM_ID,
                cluster=context.CLUSTER,
                node_id=context.NODE_HOSTNAME,
                vmagent_version=context.VMAGENT_VERSION,
                extras_folder_path=context.EXTRAS_FOLDER_PATH,
                logs_folder_path=context.LOGS_FOLDER_PATH
            )
            _qemu_ctl = qemu_ctl.QemuCtl(
                porto_connection=porto.Connection(),
                mon_socket_file_path=context.MONITOR_PATH
            )
            self.worker = VMWorker(
                context=context,
                resource_manager=_resource_manager,
                qemu_ctl=_qemu_ctl,
                qemu_launcher=qemu_launcher.QEMULauncher(context.QEMU_SYSTEM_CMD_BIN_PATH),
                qemu_img_cmd=volume_manager.QEMUImgCmd(context.QEMU_IMG_CMD_BIN_PATH),
            )
            self.worker.init()
            self.worker.start()


def serve(port, vmagent_context, mode="gencfg"):
    """Run the vmagent HTTP API on ``[::]:port`` until the server is shut down.

    Setup steps:

    - Starts a new session (``setsid``).
    - Configures TVM auth from the bundled ``/secrets/vmagent_secret`` resource.
    - Installs the ip6tables SSH SYN rule.
    - Starts the VM worker with ``vmagent_context``.

    It then serves requests, polling every second. On exit the worker is
    stopped and the listening socket is closed.

    :param port: TCP port to listen on.
    :param vmagent_context: context passed to the first worker start.
    :param mode: vmagent mode, stored on the server as ``srv.mode``.
    """
    import errno

    try:
        os.setsid()
    except OSError as e:
        # EPERM: the process is already a process-group leader; keep running as is.
        if e.errno != errno.EPERM:
            raise  # bare raise keeps the original traceback (``raise e`` loses it on Py2)

    srv = MyHTTPServer(("::", port), RequestHandlerClass)
    secret = resource.find('/secrets/vmagent_secret').rstrip()
    srv.tvm_auth = TVMAuthContext(VMAGENT_ID, secret=secret)
    srv.mode = mode

    create_ip6tables_ssh_syn_rule()

    srv.start_worker(vmagent_context)

    try:
        srv.serve_forever(poll_interval=1)
    finally:
        log_msg("Shutting down server")
        srv.stop_worker()
        srv.server_close()


def main():
    """Command-line entry point.

    Usage: ``vmagent [--mode {gencfg,yp}] [port]``. ``port`` defaults to 7255.
    Builds the vmagent context from the pod spec, configures logging with it,
    then runs :func:`serve`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        dest="mode", help="vmagent mode", type=str,
        default="gencfg", choices=["gencfg", "yp"]
    )
    # nargs="?" makes the positional optional; without it argparse never applies
    # the default and the port is mandatory.
    parser.add_argument("port", type=int, nargs="?", default=7255, help="HTTP port to listen on")
    args = parser.parse_args(sys.argv[1:])
    vmagent_context = VmagentContext.build_from_pod_spec()

    setup_logging.setup_logging(vmagent_context)
    serve(args.port, vmagent_context, args.mode)


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
