#!/usr/bin/python

import argparse
import base64
import errno
import json
import os
import socket
import sys
import threading
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from SocketServer import ThreadingMixIn
from urlparse import urlparse, parse_qsl

from infra.vmagent.src.vmagent_pb import vmagent_pb2, vmagent_api_pb2
from library.python import resource

from . import log as setup_logging
from .auth import AuthError, TVMAuthContext, VMAGENT_ID, VMPROXY_ID
from .helpers import log_msg, log_trace
from .process import VMWorker

# Module-level protobuf instances. NOTE(review): nothing visible in this file
# reads these two globals (handle_status uses locals with the same names);
# kept in case other modules import them.
config = vmagent_pb2.VMConfig()
state = vmagent_pb2.VMState()

# VMState enum value -> lower-case name used as the unistat signal prefix.
# Each name is the lower-cased spelling of the enum member it maps from.
STATE_MAP = dict(
    (getattr(vmagent_pb2.VMState, state_name.upper()), state_name)
    for state_name in (
        "empty",
        "configured",
        "stopped",
        "running",
        "busy",
        "preparing",
        "crashed",
        "invalid",
    )
)


class ActionError(Exception):
    """Raised when the worker rejects an /action request.

    The exception's argument is the non-empty value returned by
    ``VMWorker.push_task``.
    """


class RequestHandlerClass(BaseHTTPRequestHandler):
    """HTTP front end of vmagent.

    GET /ping and GET /unistat are served without authentication; every other
    GET and every POST goes through ``auth_request`` first. The handlers
    delegate the actual work to ``self.server.worker``.
    """

    def send_reply(self, data, msg="OK", encode=True):
        """Send a 200 response with a text/plain body.

        :param data: payload byte string.
        :param msg: reason phrase sent on the status line.
        :param encode: when true the payload is base64-encoded before sending.
        """
        self.send_response(200, msg)
        self.send_header("Content-Type", "text/plain")
        encoded = base64.encodestring(data) if encode else data
        self.send_header("Content-Length", len(encoded))
        self.end_headers()
        self.wfile.write(encoded)

    def reply_error(self, code, message):
        """Send an error response whose plain-text body is ``message``.

        The response is built by hand rather than with ``send_error``, because
        ``send_error`` already terminates the header section and writes its
        own HTML body; any headers sent after it end up inside the body.

        :param code: HTTP status code.
        :param message: error text; non-string values (unicode, objects passed
            to ActionError, ...) are converted to a byte string.
        """
        if not isinstance(message, str):
            encode = getattr(message, "encode", None)
            message = encode("utf-8") if encode is not None else str(message)
        message += "\n"
        self.log_error("code %d, message %s", code, message.rstrip("\n"))
        self.send_response(code)
        self.send_header("Content-Type", "text/plain")
        # send_error used to add this; keep the same connection semantics.
        self.send_header("Connection", "close")
        self.send_header("Content-Length", len(message))
        self.end_headers()
        self.wfile.write(message)

    def _ensure_worker_alive(self):
        """Return True if the worker is alive.

        Otherwise reply 500, restart the worker and return False. The caller
        must stop processing, since a response has already been sent.
        """
        if self.server.worker.is_alive():
            return True
        self.reply_error(500, 'Worker is dead, restarting')
        self.server.start_worker()
        return False

    def handle_action(self):
        """POST /action: queue a base64-encoded VMActionRequest on the worker.

        Replies 200 with an empty body on success and 400 if the payload cannot
        be decoded or the worker rejects the task.
        """
        if not self._ensure_worker_alive():
            return

        try:
            req = vmagent_api_pb2.VMActionRequest()
            data_len = int(self.headers.get("Content-Length", 0))
            payload = self.rfile.read(data_len)
            req.ParseFromString(base64.decodestring(payload))
            ret = self.server.worker.push_task(req)

            # A truthy return value from push_task signals a rejected task.
            if ret:
                raise ActionError(ret)

        except Exception as e:
            self.reply_error(400, e.message)
            return

        self.send_reply("")

    def handle_status(self):
        """GET /status: reply with a base64-encoded VMStatusResponse.

        The response holds the worker's current config and state; any failure
        while collecting them is logged and reported as 500.
        """
        if not self._ensure_worker_alive():
            return

        try:
            resp = vmagent_api_pb2.VMStatusResponse()
            resp.config.CopyFrom(self.server.worker.get_config())
            resp.state.CopyFrom(self.server.worker.get_state())
            payload = resp.SerializeToString()
        except Exception as e:
            log_trace(sys.exc_info()[2])
            self.reply_error(500, e.message)
            return

        self.send_reply(payload)

    def handle_ping(self):
        """GET /ping: restart a dead worker, then always answer PONG."""
        if not self.server.worker.is_alive():
            self.server.start_worker()

        self.send_reply("PONG", encode=False)

    def handle_unistat(self):
        """GET /unistat: report state and network signals as a JSON list.

        Emits one 0/1 "<state>_ammv" signal per STATE_MAP entry, plus
        link_alive_ammv and net_alive_ammv from the worker's net state.
        """
        state = self.server.worker.get_state()
        signals = []

        for state_type, name in STATE_MAP.iteritems():
            value = 1 if state.type == state_type else 0
            signals += [["{0}_ammv".format(name), value]]

        # retrieve net status
        link_alive, net_alive = self.server.worker.get_net_state()

        signals += [["link_alive_ammv", 1 if link_alive else 0]]
        signals += [["net_alive_ammv", 1 if net_alive else 0]]

        self.send_reply(json.dumps(signals), encode=False)

    def handle_debug(self):
        """POST /debug?cmd=...: run a maintenance command.

        Supported commands: shutdown, restart_worker, purge. The reply is sent
        before the command runs; unknown commands get 400.
        """
        cmd = self.url_params.get('cmd', None)

        if cmd == 'shutdown':
            self.send_reply("")
            self.server.shutdown()
        elif cmd == 'restart_worker':
            self.send_reply("")
            self.server.start_worker()
        elif cmd == 'purge':
            self.send_reply("")
            self.server.worker.emergency_purge()
        else:
            self.reply_error(400, 'No such command')

    def handle_default(self):
        """Fallback for unknown paths: 404."""
        self.send_error(404)

    # Path -> handler dispatch tables (plain functions, called with self).
    GET_noauth_paths = {
        "/ping": handle_ping,
        "/unistat": handle_unistat
    }
    GET_paths = {
        "/status": handle_status
    }
    POST_paths = {
        "/action": handle_action,
        "/debug": handle_debug
    }

    def auth_request(self):
        """Authenticate the current request; callers map AuthError to 401.

        A local token takes precedence over a TVM ticket, which is verified
        against VMPROXY_ID.

        :raises AuthError: when neither credential is present. Failures from
            the TVM context's verify_* calls propagate unchanged.
        """
        tvm, local = self.server.tvm_auth.extract_tvm_auth(self.headers)
        if local:
            self.server.tvm_auth.verify_local_token(local)

        elif tvm:
            self.server.tvm_auth.verify_request_ticket(tvm, VMPROXY_ID)

        else:
            # Fail closed: without this branch a request carrying no
            # credentials at all was treated as authenticated.
            raise AuthError('No authentication credentials provided')

    def _parse_url(self):
        """Populate self.parsed_url and self.url_params from self.path."""
        self.parsed_url = urlparse(self.path)
        self.url_params = dict(parse_qsl(self.parsed_url.query))

    def do_GET(self):
        """Dispatch GET: no-auth paths first, then authenticated ones."""
        self._parse_url()

        noauth_handler = self.GET_noauth_paths.get(self.parsed_url.path)
        if noauth_handler:
            noauth_handler(self)
            return

        try:
            self.auth_request()
        except AuthError as e:
            self.reply_error(401, e.message)
            return

        self.GET_paths.get(self.parsed_url.path, RequestHandlerClass.handle_default)(self)

    def do_POST(self):
        """Dispatch POST; every POST path requires authentication."""
        try:
            self.auth_request()
        except AuthError as e:
            self.reply_error(401, e.message)
            return

        self._parse_url()

        self.POST_paths.get(self.parsed_url.path, RequestHandlerClass.handle_default)(self)


class MyHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded IPv6 HTTP server that owns the VMWorker instance.

    Each request is handled in its own thread (ThreadingMixIn), and handlers
    call start_worker() when they find the worker dead. Worker replacement is
    therefore serialised by a lock, so two concurrent restarts cannot init or
    start the same worker twice or leak one.
    """
    address_family = socket.AF_INET6
    test = False
    worker = None
    mode = "gencfg"
    # Re-entrant because start_worker() calls stop_worker() while holding it.
    # Class-level, so shared by all instances; serve() in this module creates
    # a single server.
    _worker_lock = threading.RLock()

    def server_bind(self):
        """Bind the listening socket with SO_REUSEADDR enabled.

        NOTE(review): this replaces HTTPServer.server_bind entirely, so
        server_name/server_port are not set and server_address is not
        refreshed from getsockname(). Confirm nothing relies on them.
        """
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)

    def stop_worker(self):
        """Stop the current worker, if any, and forget it."""
        with self._worker_lock:
            if self.worker:
                self.worker.stop()
                self.worker = None

    def start_worker(self):
        """Replace the current worker with a freshly initialised one."""
        with self._worker_lock:
            self.stop_worker()
            self.worker = VMWorker(mode=self.mode)
            self.worker.init()

            self.worker.start()


def serve(port, mode="gencfg"):
    """Run the vmagent HTTP server on ``port`` until it shuts down.

    :param port: TCP port to listen on, bound on all IPv6 addresses ("::").
    :param mode: worker mode passed through to VMWorker (main() restricts it
        to "gencfg" or "yp").
    """
    try:
        # Start a new session. EPERM means this process is already a process
        # group leader, which is fine.
        os.setsid()
    except OSError as e:
        if e.errno != errno.EPERM:
            # Bare raise keeps the original traceback (``raise e`` drops it
            # under Python 2).
            raise

    srv = MyHTTPServer(("::", port), RequestHandlerClass)
    try:
        secret = resource.find('/secrets/vmagent_secret').rstrip()
        srv.tvm_auth = TVMAuthContext(VMAGENT_ID, secret=secret)
        srv.mode = mode

        srv.start_worker()
        srv.serve_forever(poll_interval=1)
    finally:
        # Runs on normal shutdown and on setup failures alike, so the worker
        # is stopped and the listening socket is released in every case.
        log_msg("Shutting down server")
        srv.stop_worker()
        srv.server_close()


def main(argv=None):
    """Command-line entry point: parse arguments, set up logging, serve.

    :param argv: argument list without the program name; defaults to
        ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        dest="mode", help="vmagent mode", type=str,
        default="gencfg", choices=["gencfg", "yp"]
    )
    # A positional argument is mandatory and ignores ``default`` unless
    # nargs="?" is given. With it, omitting the port falls back to 7255.
    parser.add_argument("port", type=int, nargs="?", default=7255)
    args = parser.parse_args(sys.argv[1:] if argv is None else argv)
    setup_logging.setup_logging()
    serve(args.port, args.mode)


if __name__ == "__main__":
    main()
