from __future__ import absolute_import

import os
import sys
import signal
import tempfile
import threading
from functools import partial

from ya.skynet.util import logging
from ya.skynet.util.console import setProcTitle
from ya.skynet.util.errors import formatException

from porto.exceptions import ContainerDoesNotExist
import six

from .. import cfg, containers_cfg, utils
from ..utils import log as root, sleep, run_daemon, configure_log

from . import Server
from .auth import create_auth
from .taskmgr import TaskManager, InPortoTaskManager
from .debugmanager import DebugManager
from .root_controls import RootControls
from .rootwrapper import get_portoconn
from .watcher import Watcher

try:
    from infra.skylib.keys_storage import KeysStorage
except ImportError:
    KeysStorage = None

try:
    from infra.skylib.certificates import CAStorage
except ImportError:
    CAStorage = None

# The real InContainerWatcher is only used in the standalone binary build
# (``sys.is_standalone_binary``). Everywhere else, or when ``.container``
# cannot be imported, a stub is installed instead: it fails loudly if
# ``server(..., bind_containers=True)`` tries to construct a watcher.
try:
    if not getattr(sys, 'is_standalone_binary', False):
        # Route non-standalone builds into the fallback below.
        raise ImportError("no portoshell here")

    from .container import InContainerWatcher
except ImportError:
    def InContainerWatcher(*args, **kwargs):
        """Fallback stub: container binding is unsupported in this build."""
        raise RuntimeError("InContainerWatcher not supported")


def start_ppid_watcher(log):
    """Start a daemon thread that sends SIGINT to this process when its parent changes.

    The parent PID is captured when this function is called and re-checked
    once per second. As soon as ``os.getppid()`` reports a different value,
    a warning is logged and SIGINT is delivered to this process. Any
    exception raised while polling simply ends the watcher without sending
    a signal.

    :param log: logger used to report the parent's death.
    """
    original_parent = os.getppid()
    own_pid = os.getpid()

    def _poll_parent():
        try:
            while original_parent == os.getppid():
                sleep(1)
        except BaseException:
            # Deliberately best-effort: a failure while polling just stops
            # the watcher instead of signalling the process.
            return
        log.warning("Parent died, sending SIGINT to self")
        os.kill(own_pid, signal.SIGINT)

    watcher_thread = threading.Thread(target=_poll_parent, name='watch_ppid')
    watcher_thread.daemon = True
    watcher_thread.start()


def start_root_controls(path):
    """Serve ``RootControls`` on ``path`` from a background daemon thread.

    A ``path`` without any ``/`` gets a leading NUL byte prepended. This
    presumably selects the Linux abstract socket namespace; confirm against
    ``RootControls``, which is not visible here.

    :param path: socket path or bare socket name.
    """
    address = path if '/' in path else '\x00' + path

    controls = RootControls(address)
    serving_thread = threading.Thread(
        target=controls.serve_forever,
        name='root_controls',
    )
    serving_thread.daemon = True
    serving_thread.start()


def _run_server(srv):
    try:
        srv.run()
    except BaseException:
        srv.log.exception('Unhandled exception:')
    finally:
        srv.log.info('Stopped cqudp at %s', srv.protocol.listenaddr())


def _start_container_server(
    log,
    interfaces,
    container,
    portoconn,
    port, impl,
    cfg,
    privileges_lock,
    auth,
    debugmanager,
    threads,
    keys_storage,
    ca_storage,
):
    """Start a cqudp ``Server`` bound to a porto container.

    The container's ``root_pid`` is passed to ``Server`` as ``netns_pid``,
    together with the container name. The server gets its own
    ``InPortoTaskManager``, which receives ``interfaces`` as
    ``interface_map``. The server runs in a daemon thread, and that thread
    is appended to ``threads``.

    :param cfg: configuration object (shadows the module-level ``cfg``);
        only ``cfg.server.Transport.SendOwnIp`` is read.
    :param threads: list that collects the started thread.
    :returns: the started ``Server`` instance.
    """
    target = portoconn.Find(container)
    root_pid = int(target.GetData('root_pid'))

    container_taskmgr = InPortoTaskManager(
        privileges_lock=privileges_lock,
        auth=auth,
        interface_map=interfaces,
        keys_storage=keys_storage,
        ca_storage=ca_storage,
        log=log.getChild('taskmgr'),
    )

    send_own_ip = cfg.server.Transport.SendOwnIp
    srv = Server(
        None, port,
        impl=impl,
        send_own_ip=send_own_ip,
        netns_pid=root_pid,
        netns_container=container,
        privileges_lock=privileges_lock,
        taskmgr=container_taskmgr,
        debugmanager=debugmanager,
        log=log.getChild('server'),
    )

    runner = run_daemon(_run_server, srv)
    threads.append(runner)

    # Give the server up to 5 seconds to come up; the wait result is not checked.
    srv.spawned.wait(5)
    return srv


def configure_yappi(mode):
    """Prepare yappi for profiling greenlet-based code.

    Does nothing when ``mode`` is None. Otherwise it does three things:

    * selects yappi's ``'cpu'`` clock;
    * makes yappi use the id of the current greenlet as the context id;
    * makes yappi use the greenlet's class name as the context name.

    NOTE(review): the value of ``mode`` is not otherwise consulted, and
    profiling itself is not started here.
    """
    if mode is not None:
        import yappi
        import greenlet

        def _context_id():
            return greenlet and id(greenlet.getcurrent() or 0)

        def _context_name():
            return greenlet and type(greenlet.getcurrent()).__name__ or ''

        yappi.set_clock_type('cpu')
        yappi.set_context_id_callback(_context_id)
        yappi.set_context_name_callback(_context_name)


def check_porto_containers(privileges_lock, log):
    """Reconcile existing ``cqudp-*`` porto containers with the containers config.

    Does nothing unless ``containers_cfg`` has a non-None ``users``
    attribute. Every top-level configured container name must start with
    ``cqudp-``.

    For each existing container whose name starts with ``cqudp-``:

    * task containers (``private == 'CQUDP-TASK'``) are destroyed when dead
      or stopped. Live ones are remembered as blockers.
    * top-level containers not marked ``private == 'CQUDP'`` are scheduled
      for removal as being of unknown origin.
    * other dead containers are destroyed.
    * containers whose path is no longer in the config are scheduled for
      removal.
    * the remaining containers get their configured ``options`` re-applied.
      A container whose option cannot be set is scheduled for removal.

    Scheduled containers that are ancestors of a live task container are
    kept. The rest are destroyed, longest names first, and removal failures
    are only logged.

    :param privileges_lock: lock passed to ``get_portoconn``.
    :param log: logger for progress and problems.
    :raises RuntimeError: if a top-level container name lacks the
        ``cqudp-`` prefix.
    """
    if getattr(containers_cfg, 'users', None) is None:
        return

    # This must be a real list: on Python 3 ``filter`` returns a lazy
    # iterator, which is always truthy and would reject every config.
    wrong_keys = [
        key for key in containers_cfg.containers.keys()
        if not key.startswith('cqudp-')
    ]
    if wrong_keys:
        log.error("top-level container names MUST start with 'cqudp-', found: %s", wrong_keys)
        raise RuntimeError("cqudp config is broken")

    conn = get_portoconn(privileges_lock)

    to_remove = set()
    blocked_by_tasks = []

    for container in conn.ListContainers():
        try:
            if not container.name.startswith('cqudp-'):
                continue

            private = container.GetProperty('private')
            state = container.GetProperty('state')

            if private == 'CQUDP-TASK':
                # Task containers are never scheduled for removal. Finished
                # ones are destroyed right away; live ones protect their
                # ancestors (see below).
                if state in ('dead', 'stopped'):
                    container.Destroy()
                else:
                    blocked_by_tasks.append(container.name)
                continue

            if private != 'CQUDP' and '/' not in container.name:
                log.info('required to remove container (unknown origin): %r', container.name)
                to_remove.add(container.name)
                continue

            if state == 'dead':
                container.Destroy()
                continue

            # Walk the config tree along the container's "a/b/c" path.
            opts = containers_cfg.data
            for part in container.name.split('/'):
                opts = opts.get('containers', {}).get(part)
                if opts is None:
                    log.info('required to remove container (not used anymore): %r', container.name)
                    to_remove.add(container.name)
                    break
            else:
                # Still configured: re-apply its options.
                for k, v in six.iteritems(opts.get('options', {})):
                    try:
                        container.SetProperty(k, v)
                    except Exception as e:
                        log.warning("%r: failed to set %r=%r: %s", container.name, k, v, e)
                        log.warning("%r will be removed", container.name)
                        to_remove.add(container.name)
                        break
        except ContainerDoesNotExist:
            # Presumably the container disappeared while being inspected.
            pass

    # Never remove an ancestor of a still-running task container.
    for blocker in blocked_by_tasks:
        name_parts = blocker.split('/')
        for idx in range(len(name_parts)):
            name = '/'.join(name_parts[:idx + 1])
            if name in to_remove:
                log.info('%r cannot be removed (blocked by task %r)', name, blocker)
                to_remove.discard(name)

    # Longest names first, so children are destroyed before their parents.
    for name in sorted(to_remove, key=len, reverse=True):
        try:
            conn.Destroy(name)
        except Exception as e:
            log.info("failed to remove %r: %s", name, e)


def server(
    privileges_lock,
    msgpack_port=cfg.server.bus_port_msgpack,
    netlibus_port=cfg.server.netlibus_port,
    watcher=None,
    offset=0,
    bind_containers=False,
    log=None,
):
    """Run the cqudp servers until all of their threads finish.

    Setup happens in three steps:

    * the optional key and CA storages are created, each with a daemon
      update loop;
    * the shared auth, task manager and debug manager are created;
    * one ``Server`` is started per enabled transport (``msgpack`` and
      ``netlibus``) on ``port + offset``. A port of 0 disables that
      transport.

    With ``bind_containers``, the netlibus transport is also registered with
    an in-container watcher that starts extra servers inside porto
    containers.

    The function blocks until the server threads exit. Any exception is
    logged rather than propagated. On exit, ``privileges_lock`` (if given)
    is acquired and left held, and every top-level server started here is
    shut down on a best-effort basis.

    :param privileges_lock: lock guarding privileged operations; shared with
        all components created here.
    :param msgpack_port: base port of the msgpack transport (0 disables it).
    :param netlibus_port: base port of the netlibus transport (0 disables it).
    :param watcher: optional watcher that is given the lock and every
        started top-level server.
    :param offset: added to each transport port.
    :param bind_containers: also serve inside porto containers (netlibus only).
    :param log: logger; defaults to the package root logger.
    """
    log = log or root()

    keys_storage = None
    if KeysStorage is not None and cfg.server.Auth.storage_dir:
        keys_storage = KeysStorage(
            cfg.server.Auth.storage_dir.format(TempDir=tempfile.gettempdir()),
            log=log.getChild('keys_storage'),
            privileges_lock=privileges_lock,
        )
        run_daemon(keys_storage.update_loop)

    ca_storage = None
    if CAStorage is not None:
        ca_storage = CAStorage(
            insecure_ca_files=cfg.server.Auth.InsecureCertificateAuthorityFiles,
            secure_ca_files=cfg.server.Auth.SecureCertificateAuthorityFiles,
            krl_file=cfg.server.Auth.CertificateKRLFile or b'',
            serveradmins_file=cfg.server.Auth.CertificateServerAdminsFile,
            static_cas=cfg.server.Auth.CertificateAuthorityData,
            log=log.getChild('ca_storage'),
        )
        run_daemon(ca_storage.update_loop)

    # Top-level servers started below; they are shut down in ``finally``.
    servers = []
    # noinspection PyBroadException
    try:
        if watcher is not None:
            watcher.set_lock(privileges_lock)

        auth = create_auth(ca_storage=ca_storage, log=log.getChild('auth'))
        taskmgr = TaskManager(
            privileges_lock=privileges_lock,
            auth=auth,
            log=log.getChild('taskmgr'),
            keys_storage=keys_storage,
            ca_storage=ca_storage,
        )
        debugmanager = DebugManager(
            privileges_lock=privileges_lock,
            auth=auth,
            log=log.getChild('dbgmngr'),
        )

        in_container_watcher = InContainerWatcher(
            log,
            privileges_lock,
            300.
        ) if bind_containers else None
        threads = []
        impls = [
            (msgpack_port, 'msgpack'),
            (netlibus_port, 'netlibus')
        ]

        for p, impl in impls:
            if p == 0:
                continue  # transport disabled

            if impl == 'netlibus' and cfg.transport.netlibus.Bandwidth:  # FIXME
                import netlibus
                log.info("netlibus bandwidth is set to %s", cfg.transport.netlibus.Bandwidth)
                netlibus.set_max_bandwidth_per_ip(cfg.transport.netlibus.Bandwidth)

            s = Server(
                None, p + offset,
                impl=impl,
                send_own_ip=cfg.server.Transport.SendOwnIp,
                taskmgr=taskmgr,
                debugmanager=debugmanager,
                log=log.getChild('server'),
            )
            servers.append(s)
            t = run_daemon(_run_server, s)
            if watcher is not None:
                watcher.add(s)
            threads.append(t)

            if bind_containers and impl == 'netlibus':  # FIXME
                in_container_watcher.watch(
                    partial(_start_container_server,
                            port=p,
                            impl=impl,
                            cfg=cfg,
                            auth=auth,
                            debugmanager=debugmanager,
                            threads=threads,
                            keys_storage=keys_storage,
                            ca_storage=ca_storage,
                            ),
                    privileges_lock,
                )

        # ``threads`` may keep growing while we wait: container servers
        # started through ``_start_container_server`` append their threads.
        for t in threads:
            t.join()
    except BaseException:
        log.error('Unhandled exception: %s', formatException())
    finally:
        if privileges_lock is not None:
            # Taken and intentionally never released, presumably to block
            # further privileged operations during shutdown.
            privileges_lock.acquire()
        for s in servers:
            # Best effort: one failing shutdown must not skip the others.
            try:
                s.shutdown()
            except Exception:
                log.error('Failed to shut down server: %s', formatException())
        log.info('Stopped cqudp')


def main(state_file, args):
    """Entry point of the cqudp server process.

    Setup, in order:

    * configures the environment and logging;
    * optionally starts the parent-PID watcher and the root-controls socket
      (Linux only);
    * applies task-manager overrides from ``args`` and configures yappi;
    * reconciles porto containers when the ``porto`` executer is enabled.

    It then starts the state ``Watcher`` and the ``server`` in daemon
    threads and sleeps forever. However the loop is left (e.g. through
    KeyboardInterrupt), the watcher is stopped and its thread joined.

    :param state_file: state file handed to ``Watcher``.
    :param args: parsed command-line arguments.
    """
    for env_name in ('PYTHONDONTWRITEBYTECODE', 'PYTHONNOUSERSITE'):
        os.environ[env_name] = '1'
    configure_log(suffix='server', level=args.loglevel, debug=args.debug)
    log = logging.MessageAdapter(
        root(),
        fmt='[pid:%(pid)s] %(message)s',
        data={'pid': os.getpid()},
    )

    if args.watch_ppid:
        start_ppid_watcher(log)

    msgpack_port, netlibus_port = args.msgpack_port, args.netlibus_port
    setProcTitle('cqudp [{}, {}]'.format(msgpack_port, netlibus_port))
    log.info('Initializing cqudp')

    utils.in_server = True
    privileges_lock = threading.RLock()

    # os.uname() is only consulted when a socket was actually requested.
    if args.rc_socket:
        if os.uname()[0].lower() == 'linux':
            start_root_controls(args.rc_socket)

    if args.taskpath:
        TaskManager.taskpath = args.taskpath
    if args.interpreter:
        TaskManager.interpreter = args.interpreter
        TaskManager.in_arcadia = False

    configure_yappi(args.yappi)

    if 'porto' in cfg.server.Executers:
        check_porto_containers(privileges_lock, log)

    server_watcher = Watcher(state_file, watch_ips=cfg.server.watch_ips)
    watcher_thread = run_daemon(server_watcher.run)
    try:
        run_daemon(
            server,
            privileges_lock,
            msgpack_port=msgpack_port,
            netlibus_port=netlibus_port,
            watcher=server_watcher,
            offset=args.port_offset,
            bind_containers=args.bind_containers,
            log=log,
        )
        while True:
            sleep(1)
    finally:
        server_watcher.stop()
        watcher_thread.join()
