import os
from collections import defaultdict
from itertools import chain, groupby

from ya.skynet.util import logging

import six

from ...utils import Threaded, as_user, sleep
from ..rootwrapper import get_portoconn

from ya.skynet.services.portoshell.slots import iss_fetcher, yp_fetcher
from ya.skynet.services.portoshell.slots.slot import YpLiteSlot
from ya.skynet.services.portoshell.slots.slot_proxy import MTNSlotInfoProxy, MTNYpLiteSlotInfoProxy, MTNPodInfoProxy
from ya.skynet.services.portoshell.slots.slot_sources import find_best_matching_slot


def filter_non_host_skyneted(slot):
    """Pass ``slot`` through only when host skynet is enabled for it.

    :param slot: a slot object (or ``None``) exposing ``mtn_host_skynet_enabled``.
    :returns: ``slot`` itself when it is not ``None`` and its
        ``mtn_host_skynet_enabled`` is truthy, otherwise ``None``.
    """
    if slot is None or not slot.mtn_host_skynet_enabled:
        return None
    return slot


def _get_slots(log):
    """Build the desired in-container server configuration.

    Combines ISS slots from ``iss_fetcher.fetch_slots()`` with YP pods from
    ``yp_fetcher._fetch_slots(log)``. A pod fetch error whose ``code``
    attribute is 404 is treated as "no pods"; any other error is re-raised.

    :param log: logger passed through to the YP fetcher.
    :returns: ``defaultdict(dict)`` mapping an MTN slot container to
        ``{slot name or pod box: proxy}``. Proxies are ``MTNPodInfoProxy``
        for YP pods, ``MTNYpLiteSlotInfoProxy`` for ``YpLiteSlot`` instances
        and ``MTNSlotInfoProxy`` for other ISS slots.
    """
    iss_slots = list(chain.from_iterable(iss_fetcher.fetch_slots().values()))

    try:
        pods = yp_fetcher._fetch_slots(log)
    except Exception as e:
        if getattr(e, 'code', None) == 404:
            pods = []
        else:
            raise

    # Filter out yp-lite pods which appear in both sources. The result is
    # materialized right here: on Python 3 a lazy filter() whose lambda
    # closes over a name that is rebound later would compare against (and
    # exhaust) the wrong collection once it is finally iterated.
    iss_containers = set(slot.mtn_slot_container for slot in iss_slots)
    pods = [pod for pod in pods if pod.mtn_slot_container not in iss_containers]

    def _group_key(slot):
        return slot.mtn_slot_container, slot.slot

    # Keep only MTN slots, sorted by the grouping key so that groupby() sees
    # every (mtn container, slot name) configuration set as one run.
    mtn_slots = sorted(
        (slot for slot in iss_slots
         if slot.mtn_interfaces and slot.mtn_hostname and slot.mtn_slot_container),
        key=_group_key,
    )

    # Select a single slot from each configuration set; drop sets without a
    # match or whose chosen slot does not have host skynet enabled.
    selected = []
    for (_, slot_name), group in groupby(mtn_slots, key=_group_key):
        best = filter_non_host_skyneted(find_best_matching_slot(slot_name, list(group), permissive=True))
        if best:
            selected.append(best)

    # {mtn => {slot => MTN*InfoProxy}}
    new_conf = defaultdict(dict)

    for pod in pods:
        new_conf[pod.mtn_slot_container][pod.box] = MTNPodInfoProxy(pod)

    for slot in selected:
        proxy_cls = MTNYpLiteSlotInfoProxy if isinstance(slot, YpLiteSlot) else MTNSlotInfoProxy
        new_conf[slot.mtn_slot_container][slot.slot] = proxy_cls(slot)

    return new_conf


def _make_adapter(log, container):
    """Wrap ``log`` so messages are prefixed with ``[cnt:<container>]``.

    Container names longer than 18 characters are shortened to their first
    15 characters followed by ``...``.

    :param log: logger to wrap.
    :param container: container name (text).
    :returns: a ``logging.MessageAdapter`` around ``log``.
    """
    cnt = (container[:15] + '...') if len(container) > 18 else container
    # The prefix is substituted via '%(cnt)s', so it should be a native str.
    # On Python 2 encode unicode names to UTF-8 as before. On Python 3 keep
    # text as-is; encoding would presumably render as "b'...'" in the prefix.
    if six.PY2 and isinstance(cnt, six.text_type):
        cnt = cnt.encode('utf-8')

    return logging.MessageAdapter(
        log,
        fmt='[cnt:%(cnt)s] %(message)s',
        data={'cnt': cnt},
    )


class InContainerWatcher(Threaded):
    """Keep one in-container server running per MTN slot container.

    The watcher loop periodically rebuilds the desired configuration with
    ``_get_slots`` and reconciles running servers against it:
    - obsolete servers are shut down;
    - new or changed configurations get a (re)started server;
    - acl / pod labels of yp-lite slots and pods are updated in place;
    - servers whose container network namespace inode changed are respawned.
    """

    def __init__(self, log, privileges_lock, interval):
        super(InContainerWatcher, self).__init__()
        self.log = log
        self.privileges_lock = privileges_lock
        # Delay passed to sleep() between reconfiguration passes.
        self.interval = interval
        self.portoconn = get_portoconn(privileges_lock)
        self.stopped = False

    def shutdown(self):
        """Ask the watcher loop to exit before its next reconfiguration pass."""
        self.stopped = True

    def watch(self, start_fn, privileges_lock):
        """Spawn the watcher loop via ``Threaded.spawn`` with ``_daemon=True``.

        :param start_fn: server factory, called as ``start_fn(log=...,
            privileges_lock=..., interfaces=..., container=..., portoconn=...)``.
            The returned server must expose ``running``, ``inode()``,
            ``netns_container`` and ``shutdown()``.
        :param privileges_lock: lock held around privileged operations.
        """
        self.spawn(self._watcher_loop, start_fn, privileges_lock, _daemon=True)

    def _watcher_loop(self, start_fn, privileges_lock):
        """Reconfigure every ``self.interval`` until stopped, then shut all servers down."""
        # {mtn => {'infos': {slot => proxy}, 'server': server}}
        servers = defaultdict(lambda: {'infos': None, 'server': None})
        try:
            while not self.stopped:
                try:
                    self._reconfigure(servers, start_fn, privileges_lock)
                except Exception:
                    self.log.exception("in-container slots reconfiguration failed")

                sleep(self.interval)
        finally:
            self.log.info("watcher loop shutting down all remaining servers")
            # One failing shutdown must not leave the remaining servers running.
            for mtn, srv in list(servers.items()):
                self._shutdown_server(mtn, srv['server'])

    def _shutdown_server(self, mtn, server):
        """Shut ``server`` down, logging instead of raising on failure."""
        if server is None:
            return
        try:
            server.shutdown()
        except Exception:
            self.log.exception("Failed to shut down server for MTN %s", mtn)

    def _start_server(self, mtn, infos, start_fn, privileges_lock):
        """Call ``start_fn`` for ``mtn`` with every MTN interface found in ``infos``.

        :param infos: ``{slot name: proxy}`` for this MTN container.
        :returns: whatever ``start_fn`` returns; exceptions propagate.
        """
        # Map every MTN interface to the slot info that owns it.
        ifaces = {
            iface: slotinfo
            for slotinfo in six.itervalues(infos)
            for iface in slotinfo.mtn_interfaces
        }
        return start_fn(log=_make_adapter(self.log, mtn),
                        privileges_lock=privileges_lock,
                        interfaces=ifaces,
                        container=mtn,
                        portoconn=self.portoconn)

    def _shutdown_obsolete_configurations(self, servers, new_conf):
        """Stop servers whose MTN container is no longer in ``new_conf``."""
        # Iterate over a snapshot: entries are popped inside the loop, which
        # would raise RuntimeError on a live Python 3 key view.
        for mtn in list(servers):
            if mtn not in new_conf:
                old = servers.pop(mtn)
                self.log.info("Shutting down old server for MTN %s and interface map: %s", mtn, old['infos'])
                self._shutdown_server(mtn, old['server'])

    def _start_new_configurations(self, servers, new_conf, start_fn, privileges_lock):
        """Start servers for new MTN containers; replace those whose slot map changed."""
        for mtn, data in six.iteritems(new_conf):
            self.log.debug("found MTN %s and interface map: %s", mtn, data)
            if mtn in servers and servers[mtn]['infos'] != data:
                old = servers.pop(mtn)
                self.log.info("Replacing server for MTN %s and interface map %s with %s", mtn, old['infos'], data)
                self._shutdown_server(mtn, old['server'])

            if mtn in servers:
                continue

            self.log.info("Starting new server for MTN %s and interface map: %s", mtn, data)
            try:
                srv = self._start_server(mtn, data, start_fn, privileges_lock)
                if not srv.running:
                    self.log.error("Server creation failed for MTN %s and interface map: %s", mtn, data)
                    continue
            except Exception as e:
                self.log.error("Failed to bind server for MTN %s and interface map %s: %s", mtn, data, e)
                continue

            # Register only servers that came up, so failed ones are retried
            # on the next pass.
            servers[mtn] = {'infos': data, 'server': srv}

    def _update_attributes(self, servers, new_conf):
        """Refresh acl and pod labels of running yp-lite slots / pods from ``new_conf``."""
        for mtn, data in six.iteritems(new_conf):
            if mtn not in servers:
                continue

            old_data = servers[mtn]['infos']

            for slot_name, slot in six.iteritems(data):
                if slot_name not in old_data:
                    continue

                old_slot = old_data[slot_name]
                # Only yp-lite slot and pod proxies are updated, and only when
                # the proxy type is unchanged. The old proxy is mutated in place
                # because it is the object the running server was given via
                # its ``interfaces`` map.
                if (
                    isinstance(old_slot, (MTNYpLiteSlotInfoProxy, MTNPodInfoProxy))
                    and type(old_slot) == type(slot)
                ):
                    if old_slot.acl != slot.acl:
                        self.log.info("Updating acl for slot %s", slot)
                        old_slot.acl = slot.acl
                    if old_slot.pod_labels != slot.pod_labels:
                        self.log.info("Updating pod labels for slot %s", slot)
                        old_slot.pod_labels = slot.pod_labels

    def _respawn_restarted(self, servers, start_fn, privileges_lock):
        """Restart servers whose container's network namespace inode changed.

        The current inode is read from ``/proc/<root_pid>/ns/net`` of the
        server's ``netns_container`` and compared with ``srv.inode()``.
        Servers that fail to restart are dropped from ``servers``, so the
        next reconfiguration pass starts them from scratch.
        """
        for mtn, data in list(servers.items()):
            srv = data['server']
            if srv.inode() is None:
                continue

            try:
                root_pid = self.portoconn.Find(srv.netns_container).GetData('root_pid')
                with privileges_lock:
                    inode = as_user('root', os.stat, '/proc/%s/ns/net' % (root_pid,)).st_ino
            except Exception:
                self.log.exception("Checking server inode failed for MTN %s", mtn)
                continue

            if inode == srv.inode():
                continue

            self.log.info("Restarting server for MTN %s and interface map: %s", mtn, data['infos'])
            self._shutdown_server(mtn, srv)
            self.log.info("Starting new server for MTN %s and interface map: %s",
                          mtn, data['infos'])
            try:
                new_srv = self._start_server(mtn, data['infos'], start_fn, privileges_lock)
                if not new_srv.running:
                    self.log.error("Server creation failed for MTN %s and interface map: %s",
                                   mtn, data['infos'])
                    servers.pop(mtn, None)
                    continue
            except Exception as e:
                self.log.error("Failed to bind server for MTN %s and interface map %s: %s",
                               mtn, data['infos'], e)
                servers.pop(mtn, None)
                continue

            data['server'] = new_srv

    def _reconfigure(self, servers, start_fn, privileges_lock):
        """Run one reconciliation pass of ``servers`` against freshly fetched slots."""
        # {mtn => {slot => {MTNBaseInfoProxy}}
        new_conf = _get_slots(self.log)
        self.log.debug("got configuration with %s items", len(new_conf.keys()))
        self._shutdown_obsolete_configurations(servers, new_conf)
        self._start_new_configurations(servers, new_conf, start_fn, privileges_lock)
        self._update_attributes(servers, new_conf)
        self._respawn_restarted(servers, start_fn, privileges_lock)
