from __future__ import print_function
import os
import time
import logging
import argparse
from collections import defaultdict
from itertools import chain, groupby

import six

from ya.skynet.library import config
from ya.skynet.util.console import setProcTitle
from ya.skynet.util.misc import daemonthr

from library.config import detect_hostname
from infra.skylib import keys_storage
from infra.skylib.certificates import CAStorage
from infra.skylib.debug_server.reverse import ReplServer, repl_client

from .connserver import Server
from .slots.slot_sources import find_best_matching_slot
from .slots.slot_proxy import MTNSlotInfoProxy, MTNYpLiteSlotInfoProxy, MTNPodInfoProxy
from .slots.slot import YpLiteSlot
from .slots import iss_fetcher, yp_fetcher
from .portotools import get_portoconn
from . import logger


def parse_args(args=None):
    """Parse command-line options of the portoshell daemon.

    Args:
        args: list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. ``tools_tarball`` is turned into an absolute
        path when given, and ``iss`` is converted from its string form to
        a bool.

    Raises:
        ValueError: ``--attach`` was given without ``--repl-socket``.
    """
    parser = argparse.ArgumentParser(description='portoshell daemon')
    opt = parser.add_argument

    opt('-p', '--port', type=int, default=10045, help='port to listen on')
    opt('-s', '--ssh-port', type=int, default=10046, help='SSH port to listen on')
    opt('--yappi', action='store_true', default=False, help=argparse.SUPPRESS)
    opt('--container-ssh-port', type=int, default=22, help='SSH port to listen in in container')
    opt('--storage-dir', type=str, help='directory to store keys in')
    opt('--pidfile', type=str, help='file to store server pid in')
    opt('--noauth', action='store_true', default=False,
        help='disable auth checking (use only for testing!)')
    opt('-t', '--tools-tarball', type=str, help='path to tarball with utils')
    opt('--insecure-ca-files', help='comma-separated list of files with insecure CA keys')
    opt('--secure-ca-files', help='comma-separated list of files with secure CA keys')
    opt('--krl-file', help='path to files with KRL')
    # The default is the *string* "true" rather than a bool so that the value
    # can be substituted from the genisys config, e.g. --iss "${enable_iss}".
    opt('--iss', default='true', choices=('True', 'true', 'False', 'false'),
        help="if disabled, portoshell won't attempt to get slots info from ISS")
    opt('--list-slots', action='store_true', default=False,
        help='read available slots and exit (use for testing)')
    opt('--repl-socket', default=None, help='UDS to use for debug REPL')
    opt('--attach', action='store_true', default=False, help='attach to REPL socket')

    parsed = parser.parse_args(args)

    if parsed.tools_tarball is not None:
        parsed.tools_tarball = os.path.abspath(parsed.tools_tarball)

    # `choices` limits the value to these four spellings, so membership is
    # the same as a case-insensitive comparison with "true".
    parsed.iss = parsed.iss in ('True', 'true')

    if parsed.attach and not parsed.repl_socket:
        raise ValueError('When using --attach, --repl-socket should also be specified')

    return parsed


def filter_non_sshed(slot):
    """Return ``slot`` when it exists and has ``mtn_ssh_enabled`` set, else None."""
    if slot is None or not slot.mtn_ssh_enabled:
        return None
    return slot


def get_iss_conf():
    """Build the per-container MTN configuration from ISS slots and YP pods.

    Returns:
        defaultdict(dict) mapping ``mtn_slot_container`` to
        ``{name: proxy}``:
          * YP pods are keyed by ``pod.box`` and wrapped in MTNPodInfoProxy;
            a pod is dropped when any ISS slot shares its container;
          * ISS slots are kept only if they have ``mtn_interfaces``,
            ``mtn_hostname`` and ``mtn_slot_container`` set. For each
            (container, slot name) group one slot is chosen by
            ``find_best_matching_slot`` and kept only if MTN SSH is enabled.
            It is keyed by ``slot.slot`` and wrapped in
            MTNYpLiteSlotInfoProxy (for YpLiteSlot) or MTNSlotInfoProxy.

    Raises:
        Whatever the fetchers raise, except that an error from
        ``yp_fetcher.fetch_slots`` whose ``code`` attribute equals 404 is
        treated as "no pods".
    """
    slots = list(chain.from_iterable(iss_fetcher.fetch_slots().values()))

    try:
        pods = yp_fetcher.fetch_slots()
    except Exception as e:
        if getattr(e, 'code', None) == 404:
            pods = []
        else:
            raise

    # Filter out yp-lite pods which appear in both sources. This is computed
    # eagerly from the full ISS slot list: a lazy Python 3 `filter` would close
    # over the name `slots`, see it only after it is rebound further down, and
    # consume it while iterating.
    iss_containers = set(slot.mtn_slot_container for slot in slots)
    pods = [pod for pod in pods if pod.mtn_slot_container not in iss_containers]

    def group_key(slot):
        return slot.mtn_slot_container, slot.slot

    # Keep only MTN slots, sorted so that groupby sees each
    # (container, slot name) run contiguously.
    mtn_slots = sorted(
        (slot for slot in slots
         if slot.mtn_interfaces and slot.mtn_hostname and slot.mtn_slot_container),
        key=group_key,
    )

    # Pick a single slot from each configuration set; drop it unless SSH is enabled.
    selected = []
    for (_, slot_name), group in groupby(mtn_slots, key=group_key):
        best = filter_non_sshed(find_best_matching_slot(slot_name, list(group), permissive=True))
        if best:
            selected.append(best)

    new_conf = defaultdict(dict)
    for pod in pods:
        new_conf[pod.mtn_slot_container][pod.box] = MTNPodInfoProxy(pod)

    for slot in selected:
        proxy_cls = MTNYpLiteSlotInfoProxy if isinstance(slot, YpLiteSlot) else MTNSlotInfoProxy
        new_conf[slot.mtn_slot_container][slot.slot] = proxy_cls(slot)

    return new_conf


def _shutdown_obsolete_configurations(log, mtn_infos, new_conf, socket_infos):
    for mtn in mtn_infos.keys():
        if mtn not in new_conf:
            old = mtn_infos.pop(mtn)
            log.info("Shutting down old socket for MTN %s and interface map: %s", mtn, old['infos'])
            del socket_infos[old['socket']]
            old['socket'].close()


def _start_new_configurations(log, server, mtn_infos, new_conf, socket_infos, ssh_port):
    for mtn, data in six.iteritems(new_conf):
        if mtn in mtn_infos and mtn_infos[mtn]['infos'] != data:
            old = mtn_infos.pop(mtn)
            log.info("Replacing server for MTN %s and interface map %s with %s", mtn, old['infos'], data)
            del socket_infos[old['socket']]
            old['socket'].close()

        if mtn not in mtn_infos:
            log.info("Starting new server for MTN %s and interface map: %s", mtn, data)
            try:
                result = server.bind_socket(('::', ssh_port), mtn)
                if result is not None:
                    inode, sock = result
                    mtn_infos[mtn]['socket'] = sock
                    socket_infos[sock] = mtn_infos[mtn]['metadata'] = {
                        'netns_container': mtn,
                        'inode': inode,
                        'interface_map': {
                            iface: (slotinfo.get_lookup_args(), slotinfo.get_lookup_args(False))
                            for slotinfo in sorted(six.itervalues(data), key=lambda _: isinstance(_, MTNSlotInfoProxy))
                            for iface in slotinfo.mtn_interfaces
                        },
                    }
                    mtn_infos[mtn]['infos'] = data
                else:
                    log.error("Server creation failed for MTN %s and interface map: %s", mtn, data)
            except Exception as e:
                log.error("Failed to bind server for MTN %s and interface map %s: %s", mtn, data, e)
                info = mtn_infos.pop(mtn, None)
                if info and info['socket']:
                    info['socket'].close()


def _update_attributes(log, mtn_infos, new_conf):
    """Copy ACL and pod-label changes from ``new_conf`` into running proxies.

    A slot is considered only if its container already has a running server
    and the slot name exists in both the running and the new slot map. Only
    MTNYpLiteSlotInfoProxy / MTNPodInfoProxy instances whose exact type
    matches in old and new data are updated, and the running proxy object is
    mutated in place.
    """
    for mtn, fresh_slots in six.iteritems(new_conf):
        if mtn not in mtn_infos:
            continue
        running_slots = mtn_infos[mtn]['infos']

        for name, fresh in six.iteritems(fresh_slots):
            if name not in running_slots:
                continue
            running = running_slots[name]

            if not isinstance(running, (MTNYpLiteSlotInfoProxy, MTNPodInfoProxy)):
                continue
            if type(running) != type(fresh):
                continue

            if running.acl != fresh.acl:
                log.info("updating acl of slot %s", fresh)
                running.acl = fresh.acl
            if running.pod_labels != fresh.pod_labels:
                log.info("updating pod labels of slot %s", fresh)
                running.pod_labels = fresh.pod_labels


def _check_existing_configurations(log, server, mtn_infos, socket_infos, ssh_port):
    """Rebind servers whose container network namespace has changed.

    For every running server with a recorded inode, the inode of
    ``/proc/<root_pid>/ns/net`` of the container's root process is compared
    with the one recorded at bind time. When they differ (presumably the
    container was recreated), the old socket is closed and a new one is
    bound. A failed rebind leaves ``socket`` as None and keeps the old
    metadata, so the rebind is retried on the next call.

    Args:
        log: logger.
        server: object providing ``bind_socket(addr, container)`` returning
            ``(inode, socket)`` or None.
        mtn_infos: per-container state (``infos``/``metadata``/``socket``),
            updated in place.
        socket_infos: ``{socket: metadata}`` mapping, updated in place.
        ssh_port: port to bind inside the container.
    """
    for mtn, data in list(mtn_infos.items()):
        old_inode = data['metadata']['inode']
        if old_inode is None:
            continue

        container = data['metadata']['netns_container']
        try:
            root_pid = get_portoconn(False).Find(container).GetData('root_pid')
            inode = os.stat('/proc/%s/ns/net' % (root_pid,)).st_ino
        except Exception:
            log.exception("Checking server inode failed for MTN %s", mtn)
            continue

        if inode == old_inode:
            continue

        log.info("Restarting server for MTN %s and interface map: %s", mtn, data['infos'])

        old_sock = data['socket']
        data['socket'] = None
        # old_sock is already None if the previous rebind attempt failed.
        if old_sock is not None:
            socket_infos.pop(old_sock, None)
            old_sock.close()

        # A failure for one container must not abort checks of the others.
        try:
            result = server.bind_socket(('::', ssh_port), mtn)
        except Exception:
            log.exception("Failed to rebind server for MTN %s", mtn)
            continue

        if result is None:
            log.error("Server creation failed for MTN %s and interface map: %s",
                      mtn, data['infos'])
            continue

        new_inode, sock = result
        data['socket'] = sock
        socket_infos[sock] = data['metadata'] = {
            'netns_container': mtn,
            'inode': new_inode,
            # Proxies that are MTNSlotInfoProxy instances sort last, so for an
            # interface listed by several proxies their lookup args win.
            'interface_map': {
                iface: (slotinfo.get_lookup_args(), slotinfo.get_lookup_args(False))
                for slotinfo in sorted(data['infos'].values(), key=lambda p: isinstance(p, MTNSlotInfoProxy))
                for iface in slotinfo.mtn_interfaces
            },
        }


def watch_in_container_servers(log, ssh_port, server):
    """Reconcile in-container SSH servers with ISS/YP slot data, forever.

    Each round fetches the configuration, then in order: closes servers of
    vanished containers, (re)starts servers for new or changed ones, updates
    ACLs and pod labels in place, and rebinds servers whose network
    namespace changed. A failing round is logged and the loop continues
    after a 10 second pause. Never returns.

    Args:
        log: logger.
        ssh_port: port to bind inside each container.
        server: server object whose ``mtn_sockets`` mapping is kept in sync.
    """
    sockets = server.mtn_sockets
    state = defaultdict(lambda: {'infos': None, 'metadata': None, 'socket': None})

    while True:
        try:
            conf = get_iss_conf()  # {mtn => {slot => MTNBaseInfoProxy}}
            _shutdown_obsolete_configurations(log, state, conf, sockets)
            _start_new_configurations(log, server, state, conf, sockets, ssh_port)
            _update_attributes(log, state, conf)
            _check_existing_configurations(log, server, state, sockets, ssh_port)
        except Exception:
            log.exception("in-container slots reconfiguration failed")
        time.sleep(10)


def print_slots():
    """Print every ISS slot and YP pod known to the fetchers, then exit with 0.

    The root logger is switched to DEBUG and sent to stdout. That logger is
    also passed to the fetchers, so their log output appears alongside the
    listing.
    """
    import sys

    root_log = logging.getLogger()
    root_log.addHandler(logging.StreamHandler(sys.stdout))
    root_log.setLevel(logging.DEBUG)

    sources = (
        (iss_fetcher, "  --- SLOT: %s"),
        (yp_fetcher, " --- POD: %s"),
    )
    for fetcher, template in sources:
        for entry in fetcher._fetch_slots(root_log):
            print(template % (entry,))
            print('')

    raise SystemExit(0)


def setup_yappi(enabled):
    """Start yappi CPU profiling with a periodic dump thread when ``enabled``.

    When enabled, a daemon thread wakes every 300 seconds. It stops
    profiling, saves function stats to /tmp/portoshell.profile in callgrind
    format, clears the stats and starts profiling again. yappi is only
    imported when profiling is enabled.
    """
    if enabled:
        import yappi
        import threading

        def dump_periodically():
            while True:
                time.sleep(300)
                yappi.stop()
                yappi.get_func_stats().save('/tmp/portoshell.profile', 'callgrind')
                yappi.clear_stats()
                yappi.start(True, True)

        yappi.set_clock_type('cpu')
        yappi.start(True, True)
        dumper = threading.Thread(target=dump_periodically)
        dumper.daemon = True
        dumper.start()


def main():
    """Run the portoshell daemon.

    Parses arguments. ``--list-slots`` prints slots and exits;
    ``--attach`` exits with the REPL client's result. Otherwise it
    optionally starts profiling, writes the pidfile, sets up logging and
    the optional debug REPL, starts the key/CA storage update threads, the
    child collector and the in-container server watcher (with ISS enabled),
    and then serves forever. On exit the REPL is stopped and the pidfile
    is removed.
    """
    args = parse_args()

    if args.list_slots:
        print_slots()  # raises SystemExit

    if args.attach:
        raise SystemExit(repl_client(args.repl_socket))

    setup_yappi(args.yappi)

    if args.pidfile:
        with open(args.pidfile, 'w') as f:
            f.write(str(os.getpid()))

    repl_server = None

    try:
        setProcTitle('portoshell [%s, %s]' % (args.port, args.ssh_port))
        cfg = config.query('skynet.services.portoshell')
        logger.setup_logging(cfg)
        logger.logging.captureWarnings(True)

        log = logger.logging.getLogger('portoshell')

        def pid_prefixed_log(name):
            # Child logger wrapped in a MessageAdapter with the
            # '[<pid>] <message>' format (assuming MessageAdapter applies fmt).
            return logger.logging.MessageAdapter(
                log.getChild(name),
                fmt='[%(pid)s] %(message)s',
                data={'pid': os.getpid()},
            )

        if args.repl_socket:
            repl_server = ReplServer(sock_path=args.repl_socket)
            repl_server.start()

        storage = None
        if args.storage_dir and args.iss:
            storage = keys_storage.KeysStorage(args.storage_dir, log=pid_prefixed_log('fetch-keys'))
            daemonthr(storage.update_loop)

        # Must be bound even when no CA/KRL option is given: it is always
        # passed to Server below (previously this raised UnboundLocalError).
        ca_storage = None
        if args.insecure_ca_files or args.secure_ca_files or args.krl_file:
            ca_storage = CAStorage(
                insecure_ca_files=args.insecure_ca_files.split(',') if args.insecure_ca_files else [],
                secure_ca_files=args.secure_ca_files.split(',') if args.secure_ca_files else [],
                krl_file=args.krl_file or None,
                log=pid_prefixed_log('fetch-ca'),
            )
            daemonthr(ca_storage.update_loop)

        daemonthr(Server.collect_children, log)
        server = Server(
            log,
            args.port,
            args.ssh_port,
            check_auth=not args.noauth,
            tools_tarball=args.tools_tarball,
            iss=args.iss,
            keys_storage=storage,
            ca_storage=ca_storage,
            hostname=detect_hostname()
        )
        log.info("Started on ports [%s, %s]", args.port, args.ssh_port)
        if args.iss:
            daemonthr(
                watch_in_container_servers,
                log,
                args.container_ssh_port,
                server,
            )

        server.serve_forever()
    finally:
        if repl_server:
            repl_server.stop_repl()
        if args.pidfile:
            os.unlink(args.pidfile)


# Run the daemon when this module is executed directly as a script.
if __name__ == '__main__':
    main()
