from __future__ import absolute_import

import base64

import six

from .. import msgpackutils as msgpack
from ..transport.envelope import Envelope, build_tree
from ..utils import LRUCache, short, genuuid, bytesio, log as root
from ..transport.protocol import Node, handle

from ya.skynet.util.functional import singleton


class Server(Node):
    """Server node that receives routed messages and dispatches them by type.

    Each ``@handle('<type>')`` method below serves one incoming message type.
    Most handlers hand the real work to ``self.taskmgr`` (or
    ``self.debugmanager``) through ``self.spawn`` so that the receiving side
    is not blocked. 'multipart' messages are buffered in a
    :class:`TaskCombiner` and dispatched again through ``self.handle`` once
    every part has arrived.
    """

    def __init__(
        self, ip, port,
        impl='msgpack',
        send_own_ip=False,
        netns_pid=None,
        netns_container=None,
        privileges_lock=None,
        taskmgr=None,
        debugmanager=None,
        log=None,
    ):
        """Create the server node.

        :param ip: listen address, passed to :class:`Node`.
        :param port: listen port, passed to :class:`Node`.
        :param impl: implementation name, passed to :class:`Node` and stored.
        :param send_own_ip: passed to :class:`Node`.
        :param netns_pid: passed to :class:`Node`.
        :param netns_container: stored on the instance only (not given to :class:`Node`).
        :param privileges_lock: passed to :class:`Node`.
        :param taskmgr: executor whose methods the task handlers call; may be None.
        :param debugmanager: object serving 'debug_request' messages; may be None.
        :param log: logger passed to :class:`Node`.
        """
        super(Server, self).__init__(
            ip, port,
            impl=impl,
            send_own_ip=send_own_ip,
            netns_pid=netns_pid,
            privileges_lock=privileges_lock,
            log=log,
        )

        self.impl = impl
        self.netns_container = netns_container
        # base64.encodestring was removed in Python 3.9; encodebytes is the same
        # function under its Python 3 name, but does not exist on Python 2.
        encode = getattr(base64, 'encodebytes', None) or base64.encodestring
        self.local_address = encode(msgpack.dumps(('S', self.protocol.listenaddr())))
        self.taskmgr = taskmgr
        self.debugmanager = debugmanager
        self.combiner = TaskCombiner()

    def shutdown(self):
        """Shut down the node, then the task manager if one was supplied."""
        super(Server, self).shutdown()
        # taskmgr defaults to None, so the attribute merely existing is not
        # enough; getattr also covers a shutdown() before __init__ assigned it.
        taskmgr = getattr(self, 'taskmgr', None)
        if taskmgr is not None:
            taskmgr.shutdown()

    def _route(self, task, addrs, hostid=None):
        """Wrap ``task`` in an :class:`Envelope` built from ``addrs`` and send it.

        A path item for ``hostid`` is inserted at the front of the envelope
        path, then every ``(addr, envelope)`` pair yielded by ``env.next()``
        is passed to ``self.protocol.route``.
        """
        env = Envelope(task, build_tree(addrs))
        env.path.insert(0, self.protocol._make_pathitem(hostid))

        for (addr, next_envelope) in env.next():
            self.protocol.route(next_envelope, addr)

    def _request_missing_part(self, msg, path, hostid):
        """Ask the first hop in ``path`` for the lowest multipart chunk not yet received.

        This runs in a spawned task, so by the time it executes the combiner
        may no longer hold the task (completed by another part, or dropped
        from its cache). In that case nothing is requested.
        """
        try:
            partid = self.combiner.missing(msg)
        except KeyError:
            partid = None
        if partid is None:
            self.log.debug("[%s] multipart nothing to request", short(msg['taskid']))
            return
        self.log.debug("[%s] multipart requesting part %d", short(msg['taskid']), partid)
        self._route(part_request(msg['taskid'], partid), path[:1], hostid=hostid)

    @handle('task')
    def __task(self, msg, data, path, hostid, iface):
        """Run a 'task' message via ``taskmgr.execute`` in a spawned task."""
        self.log.info('new task %s[%s]', short(msg['uuid']), hostid)
        self.spawn(self.taskmgr.execute,
                   self._route,
                   self.local_address,
                   msg,
                   data,
                   path,
                   hostid,
                   iface,
                   _daemon=True)

    @handle('task_py2')
    def __task_py2(self, msg, data, path, hostid, iface):
        """Run a 'task_py2' message via ``taskmgr.execute_py2`` in a spawned task."""
        self.log.info('new task_py2 %s[%s]', short(msg['uuid']), hostid)
        self.spawn(self.taskmgr.execute_py2,
                   self._route,
                   self.local_address,
                   msg,
                   data,
                   path,
                   hostid,
                   iface,
                   _daemon=True)

    @handle('task_py3')
    def __task_py3(self, msg, data, path, hostid, iface):
        """Run a 'task_py3' message via ``taskmgr.execute_py3`` in a spawned task."""
        self.log.info('new task_py3 %s[%s]', short(msg['uuid']), hostid)
        self.spawn(self.taskmgr.execute_py3,
                   self._route,
                   self.local_address,
                   msg,
                   data,
                   path,
                   hostid,
                   iface,
                   _daemon=True)

    @handle('multi_task')
    def __multi_task(self, msg, data, path, hostids, iface):
        """Run a 'multi_task' message via ``taskmgr.container_execute`` in a spawned task."""
        self.log.info('new multitask %s[%s]', short(msg['uuid']), hostids)
        self.spawn(self.taskmgr.container_execute,
                   self._route,
                   self.local_address,
                   msg,
                   data,
                   path,
                   hostids,
                   iface,
                   _daemon=True)

    @handle('task_portoshell_slow')
    def __task_portoshell(self, msg, data, path, hostids, iface):
        """Run a 'task_portoshell_slow' message via ``taskmgr.portoshell_execute``."""
        self.log.info('new portoshell task %s[%s]', short(msg['uuid']), hostids)
        self.spawn(self.taskmgr.portoshell_execute,
                   self._route,
                   self.local_address,
                   msg,
                   data,
                   path,
                   hostids,
                   iface,
                   _daemon=True)

    @handle('stop')
    def __stop(self, msg, data, path, hostid, iface):
        """Forward a 'stop' message for ``msg['taskid']`` to ``taskmgr.stop_task``."""
        self.log.info('stop task %s[%s]', short(msg['taskid']), hostid)
        self.spawn(self.taskmgr.stop_task, msg['taskid'], hostid, _daemon=True)

    @handle('multipart')
    def __multipart(self, msg, data, path, hostid, iface):
        """Buffer one chunk of a split message; dispatch the whole once complete.

        If the combiner reports the task incomplete, a request for the next
        missing chunk is spawned.
        """
        self.log.info('multipart %s [%d/%d]', short(msg['taskid']), msg['sn'], msg['total'] - 1)
        task = self.combiner.add(msg, data, path, hostid)
        if task is not None:
            self.handle(*task)
        else:
            self.spawn(self._request_missing_part, msg, path, hostid, _daemon=True)

    @handle('ping')
    def __ping(self, msg, data, path, hostid, iface):
        """Answer a 'ping' via ``taskmgr.ping`` in a spawned task."""
        self.log.info('ping from %s', path[0])
        self.spawn(self.taskmgr.ping, self._route, self.local_address, msg, path, hostid, iface, _daemon=True)

    @handle('task_started')
    def __task_started(self, msg, data, path, hostid, iface):
        """Notify ``taskmgr.task_started`` with ``path[0]`` and ``path[0][0]``."""
        self.log.info('task notified about start %s[%s] on %s', short(msg['taskid']), path[0][0], path[0])
        self.spawn(self.taskmgr.task_started, msg['taskid'], path[0], path[0][0], _daemon=True)

    @handle('debug_request')
    def __debug_request(self, msg, data, path, hostid, iface):
        """Pass a 'debug_request' to ``debugmanager.process_request``."""
        # NOTE(review): debugmanager defaults to None in __init__; the spawned
        # call then fails with AttributeError. Confirm callers always set it.
        self.log.info('debug request from %s', path[0])
        self.spawn(self.debugmanager.process_request, self._route, msg, path, hostid, iface, _daemon=True)

    @handle('request_stats')
    def __stats(self, msg, data, path, hostid, iface):
        """Pass a 'request_stats' message to ``taskmgr.report_stats``."""
        # NOTE(review): unlike every other handler this spawn omits
        # _daemon=True; confirm that is intentional.
        self.log.debug('stats request from %s', path[0])
        self.spawn(self.taskmgr.report_stats, self._route, msg, path, hostid, iface)


class TaskCombiner(object):
    """Reassemble messages that arrive split into numbered 'multipart' chunks.

    Chunks are buffered per ``msg['taskid']`` in ``self.parts`` (an
    ``LRUCache(20)``) as a dict mapping ``msg['sn'] -> (msg['data'], data)``.
    """

    def __init__(self):
        self.parts = LRUCache(20)

    def add(self, msg, data, path, hostid):
        """Store one chunk and return the reassembled task once all have arrived.

        :returns: None while fewer than ``msg['total']`` chunks are buffered.
            Otherwise the buffer is dropped and ``(message, payload, path,
            hostid)`` is returned. ``message`` is ``msgpack.loads`` of every
            non-None ``msg['data']`` concatenated in ``sn`` order. ``payload``
            is the concatenation of every non-None chunk ``data`` in the same
            order, or None if no chunk carried any.
        """
        pending = self.parts.setdefault(msg['taskid'], {})
        pending[msg['sn']] = (msg['data'], data)
        if len(pending) != msg['total']:
            return None

        del self.parts[msg['taskid']]

        header = bytesio()
        payload = None
        for sn in sorted(pending):
            header_chunk, payload_chunk = pending[sn]
            if header_chunk is not None:
                header.write(header_chunk)
            if payload_chunk is not None:
                if payload is None:
                    payload = bytesio()
                payload.write(payload_chunk)

        combined = msgpack.loads(header.getvalue())
        # FIXME data will fail on run_in_container, because it's not str, but {int => str} instead
        if payload is not None:
            payload = payload.getvalue()
        return combined, payload, path, hostid

    def missing(self, msg):
        """Return the lowest sn in ``0 .. msg['total'] - 1`` not yet buffered.

        Returns None when nothing is missing. Raises KeyError when no chunks
        are buffered for ``msg['taskid']``.
        """
        total = msg['total']
        received = self.parts[msg['taskid']]
        return next(
            (sn for sn in six.moves.xrange(total) if sn not in received),
            None,
        )


def part_request(taskid, nextsn):
    """Build a 'part_request' message asking for chunk ``nextsn`` of ``taskid``.

    Every call stamps the message with a fresh ``uuid`` from ``genuuid()``.
    """
    request = dict(
        uuid=genuuid(),
        taskid=taskid,
        nextsn=nextsn,
        type='part_request',
    )
    return request


@singleton
def log():
    """Return the 'server' child of the logger returned by ``root()``.

    Wrapped in ``singleton``, presumably so that it is built only once.
    """
    parent = root()
    return parent.getChild('server')
