from __future__ import absolute_import

from weakref import WeakValueDictionary
from collections import defaultdict

from ..transport.envelope import Envelope, build_tree
from ..transport.protocol import Node, handle, work_loop

from ..utils import genuuid, short, poll_select, LRUCache, sleep, monotime
from ..utils import log as root
from ..exceptions import CQueueRuntimeError
from .. import cfg
from .session import Session


class Client(Node):
    """Client node that starts task sessions and drives their scheduling.

    Live sessions are tracked in a ``WeakValueDictionary`` keyed by task id,
    so an entry disappears once nothing else references the session.  A
    daemon scheduling loop (spawned from ``__init__``) periodically asks each
    live session what to send and routes the resulting messages.  Incoming
    ``rpc``/``result``/``heartbeat``/``response``/``part_request`` messages
    are dispatched back to the owning session by the ``@handle`` methods.
    """

    def __init__(self,
                 ip=None,
                 transport='msgpack',
                 session_class=Session,
                 select_function=poll_select,
                 allow_no_keys=False,
                 send_own_ip=False,
                 log=None,
                 ):
        """Create the client node and start its scheduling loop.

        :param ip: passed through to ``Node.__init__``.
        :param transport: transport implementation name.  ``_get_cfg_port``
            only knows ``'msgpack'`` and ``'netlibus'`` and raises
            ``RuntimeError`` for anything else.
        :param session_class: class instantiated by ``start_session`` and
            ``start_ping_session`` (and by ``start_custom_session`` unless it
            is given its own ``session_class``).
        :param select_function: stored as ``self.select_function``; not used
            inside this class (presumably consumed by sessions -- verify).
        :param allow_no_keys: when true, sessions may be started with a
            signer that has no key fingerprints.
        :param send_own_ip: passed through to ``Node.__init__``.
        :param log: logger to use; defaults to the ``client`` child of the
            package root logger.
        """
        self.log = log or root().getChild('client')

        if transport == 'netlibus' and cfg.transport.netlibus.Bandwidth:  # FIXME
            # Imported lazily: netlibus is only needed for this transport.
            import netlibus
            self.log.info("netlibus bandwidth is set to %s", cfg.transport.netlibus.Bandwidth)
            netlibus.set_max_bandwidth_per_ip(cfg.transport.netlibus.Bandwidth)

        super(Client, self).__init__(ip=ip, impl=transport, send_own_ip=send_own_ip, log=self.log)

        self.allow_no_keys = allow_no_keys
        self.Session = session_class
        # taskid -> session; weak so abandoned sessions are dropped automatically.
        self.sessions = WeakValueDictionary()
        # Explicit override set via set_default_port(); falls back to config.
        self._default_port = None
        self._default_cfg_port = _get_cfg_port(transport)
        self.select_function = select_function
        # host -> monotime() of the last message seen from it (bounded cache).
        self.alive_hosts = LRUCache(1000)
        self.allowed_unpickles = defaultdict(set)

        self.spawn(self.run, _daemon=True)

    def default_port(self):
        """Return the explicitly set default port, else the configured one."""
        return self._default_port or self._default_cfg_port

    def set_default_port(self, port):
        """Override the port returned by ``default_port``."""
        self._default_port = port

    def route(self, task, addrs, data=None, intermediates=None):
        """Route ``task`` to ``addrs`` through ``protocol.route_tree``.

        ``'C'`` is passed as the origin marker; ``send_shutdown`` uses the
        same tag for the client's own path entry, so it presumably identifies
        this client in the routing path (confirm against the protocol).
        """
        self.protocol.route_tree(task, addrs, 'C', data, intermediates)

    def send_shutdown(self, host):
        """Send a ``shutdown`` message directly to ``host``."""
        envelope = Envelope(shutdown_msg(), build_tree([(0, host)]))
        # Prepend this client as the origin of the envelope's path.
        envelope.path.insert(0, ('C', self.protocol.listenaddr()))
        for (addr, next_envelope) in envelope.next():
            self.protocol.route(next_envelope, addr)

    def _check_session_start(self, hosts, signer):
        """Raise ``CQueueRuntimeError`` if a signed session cannot be started.

        Checks, in order: that ``signer`` has keys (unless ``allow_no_keys``)
        and that ``hosts`` is non-empty.
        """
        if not self.allow_no_keys and not signer.fingerprints():
            raise CQueueRuntimeError('No private keys available. Cannot sign the task.')

        if not hosts:
            raise CQueueRuntimeError("Empty host list provided")

    def start_ping_session(self, hosts, **kwargs):
        """Start an unsigned ``ping`` session against ``hosts``.

        :param hosts: non-empty collection of hosts.
        :param kwargs: passed to the session as ``task_opts``.
        :raises CQueueRuntimeError: if ``hosts`` is empty.
        :returns: the new session; keep a reference to it, since the client
            only holds it weakly.
        """
        if not hosts:
            raise CQueueRuntimeError("Empty host list provided")

        s = self.Session(self,
                         None,
                         hosts,
                         None,
                         task_type='ping',
                         task_opts=kwargs,
                         log=self.log,
                         )
        self.sessions[s.taskid] = s
        s.log.info('starting ping session')

        return s

    def start_session(self, hosts, runnable, params, signer, description=None, task_type='task', **kwargs):
        """Start a signed session running ``runnable`` on ``hosts``.

        :param params: passed positionally to the session class after
            ``runnable``.
        :param signer: object providing ``fingerprints()``; must have keys
            unless ``allow_no_keys`` is set.
        :param description: text for the start log line; defaults to
            ``str(runnable)``.
        :param kwargs: passed to the session as ``task_opts``.
        :raises CQueueRuntimeError: if no keys are available or ``hosts`` is
            empty.
        :returns: the new session (held only weakly by the client).
        """
        self._check_session_start(hosts, signer)

        s = self.Session(self,
                         signer,
                         hosts,
                         runnable,
                         params,
                         task_type=task_type,
                         task_opts=kwargs,
                         log=self.log,
                         )
        self.sessions[s.taskid] = s

        description = description or str(runnable)
        s.log.info('starting session %s', description)

        return s

    def start_custom_session(self,
                             hosts,
                             function,
                             params,
                             signer,
                             task_type,
                             username=None,
                             timeout=None,
                             runnable=b'',
                             description=None,
                             session_class=None,
                             **kwargs):
        """Start a signed session that executes ``function`` on ``hosts``.

        :param function: stored in ``task_opts['exec_fn']``.
        :param params: passed to the session as ``task_hosts_data``.
        :param username: if given, stored in ``task_opts['user']``.
        :param timeout: if given, stored in ``task_opts['session_timeout']``.
        :param session_class: overrides the client's default session class.
        :param kwargs: merged into ``task_opts`` last, so they can override
            the keys above.
        :raises CQueueRuntimeError: if no keys are available or ``hosts`` is
            empty.
        :returns: the new session (held only weakly by the client).
        """
        self._check_session_start(hosts, signer)

        session_class = session_class or self.Session

        # TODO: split exec_args and task_opts.
        task_opts = {'exec_fn': function}
        if timeout is not None:
            task_opts['session_timeout'] = timeout
        if username is not None:
            task_opts['user'] = username
        task_opts.update(kwargs)
        s = session_class(self,
                          signer,
                          hosts,
                          runnable=runnable,
                          task_hosts_data=params,
                          task_opts=task_opts,
                          task_type=task_type,
                          log=self.log,
                          )

        self.sessions[s.taskid] = s
        description = description or str(runnable or function)
        s.log.info('starting %s session %s', task_type, description)

        return s

    def _handle_in_session(self, datatype, msg, data, path, hostid, iface):
        """Dispatch an incoming message to the session owning ``msg['taskid']``.

        Messages for unknown (or already collected) sessions are dropped and
        logged at debug level.  For known sessions the sender's liveness is
        recorded in ``alive_hosts`` before delegating to ``session.handle``.
        """
        taskid = msg['taskid']
        # Compare against None explicitly: the mapping holds Session objects,
        # whose truthiness is not guaranteed to mean "exists".
        session = self.sessions.get(taskid)
        if session is None:
            index = msg.get('index')
            if index is not None:
                self.log.debug('[%s]: drop %s %s from %s', short(taskid), datatype, index, path[0])
            else:
                self.log.debug('[%s]: drop %s from %s', short(taskid), datatype, path[0])
            return

        # When the first path element is tagged 'I' (presumably an
        # intermediate/aggregator), prefer the first entry of 'aggr_path' as
        # the originating host.  Fall back to path[0][0] when 'aggr_path' is
        # missing or empty instead of crashing on the index.
        aggr_path = msg.get('aggr_path') if path[0][0] == 'I' else None
        sender = aggr_path[0][0] if aggr_path else path[0][0]
        self.alive_hosts[sender] = monotime()

        session.log.debug('handling %s in session', datatype)
        session.handle(msg, data, path, hostid, iface)

    @handle('rpc')
    def __rpc(self, msg, data, path, hostid, iface):
        self._handle_in_session('rpc', msg, data, path, hostid, iface)

    @handle('result')
    def __result(self, msg, data, path, hostid, iface):
        self._handle_in_session('result', msg, data, path, hostid, iface)

    @handle('heartbeat')
    def __heartbeat(self, msg, data, path, hostid, iface):
        self._handle_in_session('heartbeat', msg, data, path, hostid, iface)

    @handle('response')
    def __response(self, msg, data, path, hostid, iface):
        self._handle_in_session('response', msg, data, path, hostid, iface)

    @handle('part_request')
    def __part_request(self, msg, data, path, hostid, iface):
        self._handle_in_session('part_request', msg, data, path, hostid, iface)

    @work_loop
    def __scheduling_loop(self):
        """One scheduling pass over all live sessions, then sleep 1 second.

        For each session, routes its pending task messages, stop messages
        and heartbeats, using the session's good hosts as intermediates.
        """
        # Snapshot the keys: entries can vanish (weak references) or be added
        # while this pass runs.
        for taskid in list(self.sessions.keys()):
            session = self.sessions.get(taskid)
            if session is None:
                # Collected after the snapshot was taken.
                continue

            good_hosts = session.get_good_hosts()
            hb_hosts, hb_msg, task_msgs, stop_hosts, stop_msg = session.schedule()

            for task_hosts, task_msg, task_host_data in task_msgs:
                session.log.debug('scheduled %s tasks', len(task_hosts))
                self.route(task_msg, task_hosts, data=task_host_data, intermediates=good_hosts)

            if stop_hosts:
                session.log.debug('scheduled %s stops', len(stop_hosts))
                self.route(stop_msg, stop_hosts, intermediates=good_hosts)

            session.log.debug('scheduled %s heartbeats', len(hb_hosts))
            # Copy presumably so route() cannot mutate the session's own
            # mapping (confirm whether route_tree modifies its data argument).
            tree_data = hb_hosts.copy()
            addrs = [(hostid, session.hosts[hostid].actual_addr) for hostid in tree_data]
            self.route(hb_msg, addrs, tree_data, good_hosts)

        sleep(1)


def shutdown_msg():
    """Return a new ``shutdown`` control message with a fresh unique id."""
    message = {'uuid': genuuid()}
    message['type'] = 'shutdown'
    return message


def _get_cfg_port(transport):
    if transport == 'msgpack':
        return cfg.server.bus_port_msgpack
    elif transport == 'netlibus':
        return cfg.server.netlibus_port
    raise RuntimeError('Unknown transport: {}'.format(transport))
