from __future__ import absolute_import

from collections import defaultdict
import collections
import threading
import functools
import socket
import sys

import six

from .envelope import Envelope, build_back_tree, build_tree, direct_path
from .messagebus import MsgpackBus, Netlibus, SPSCQueue, wait_read
from ..utils import Threaded, short, sleep, auto_restart, log as root

from ya.skynet.util.functional import singleton, cached
from ya.skynet.util.net.getifaddrs import getIfAddrs, Flags as IfFlags
from ya.skynet.util.net.socketserver import getValidAddrs

# Wire-format version history (the integer sent as the first element of
# every message):
#   10 - envelope and task are both classes
#   11 - envelope.task renamed to envelope.msg
#   12 - envelope.tree_data added for heartbeats
#   13 - new field: envelope.msgs is a list   <- current
CURRENT_VERSION = 13


def parse_message(data):
    """Split a received ``(version, payload)`` pair into ``(cmd, envelope)``.

    An integer first element is a wire-format version and the payload is
    handed to ``parse_new_format``.  A string first element is taken as the
    command itself and the pair is returned unchanged (presumably the
    pre-versioned format -- confirm).

    Raises:
        RuntimeError: if the first element is neither an int nor a string.
    """
    version, payload = data
    if isinstance(version, int):
        return parse_new_format(version, payload)
    if isinstance(version, six.string_types):
        return version, payload
    raise RuntimeError("Unsupported format version: %r" % (version,))


def parse_new_format(version, data):
    """Decode a versioned ``(cmd, envelope)`` payload.

    Versions below ``CURRENT_VERSION`` are upgraded via
    ``convert_format_from``.  ``CURRENT_VERSION`` and newer versions up to 19
    are returned as ``(data[0], data[1])`` without conversion.

    Raises:
        RuntimeError: if ``version`` is outside the supported 10..19 range.
    """
    if not 10 <= version < 20:
        raise RuntimeError('Unsupported format version: %s' % (version,))
    if version >= CURRENT_VERSION:
        return data[0], data[1]
    return convert_format_from(version, data)


def convert_format_from(old_version, data):
    """Upgrade an older-version ``(cmd, envelope)`` payload in place.

    Applies every upgrade step from version 10 up to the current format,
    mutating the envelope so it exposes ``msg`` and ``msgs``.

    Args:
        old_version: wire-format version the payload was sent with.  Not
            consulted; every step is guarded by attribute checks instead.
        data: sequence whose first two items are the command and envelope.

    Returns:
        ``(cmd, envelope)`` with the envelope upgraded.
    """
    cmd, env = data[0], data[1]

    # 10 -> 11: `task` was renamed to `msg`.  Senders speaking version 11+
    # may no longer carry `task`, so only copy it when it is present rather
    # than raising AttributeError on every such envelope.
    if hasattr(env, 'task'):
        env.msg = env.task

    # 11 -> 12: heartbeat tree data is exposed to handlers as msg['next'].
    if hasattr(env, 'data') and env.msg['type'] == 'heartbeat':
        env.msg['next'] = env.data

    # 12 -> 13: envelopes carry a list of messages.
    if not hasattr(env, 'msgs'):
        env.msgs = [env.msg]

    return cmd, env


def _convert_format_to_old(env):
    """Mutate ``env`` so peers on older wire versions can still read it.

    Exposes the first entry of ``msgs`` as ``msg`` (version 12 view) and, if
    ``task`` is not already set, mirrors ``msg`` into ``task`` (version 10
    view).  Note: with more than one message, old peers only see the first.
    """
    # Step down 13 -> 12: a single `msg` instead of the `msgs` list.
    if hasattr(env, 'msgs'):
        first = env.msgs[0]
        env.msg = first

    # Step down to 10: an existing `task` is left untouched.
    if hasattr(env, 'task'):
        return
    env.task = env.msg


class Protocol(object):
    """Versioned envelope transport on top of a message bus.

    Wraps a bus (built by ``_create_bus`` or supplied via ``bus``) and speaks
    the ``(version, (cmd, envelope))`` wire format.  Outgoing envelopes are
    stamped with ``CURRENT_VERSION`` after ``_convert_format_to_old`` adds
    the fields older peers read.  Incoming messages are decoded by
    ``parse_message``.
    """

    class Timeout(Exception):
        """Raised by ``receive`` when the underlying bus times out."""
        pass

    def __init__(self, ip=None, port=None, impl='msgpack', send_own_ip=False,
                 reuse=True, netns_pid=None, privileges_lock=None, log=None,
                 bus=None):
        """Use ``bus`` if given, otherwise create and start one on ``ip:port``.

        A ``port`` of None is passed to ``_create_bus`` as 0.  With
        ``send_own_ip`` set, path items also carry this host's IPs.
        """
        self.log = log or _log()
        if bus is None:
            self.log.info("initializing {} bus on {}:{}".format(impl, ip, port))
            bus = _create_bus(
                ip, port or 0,
                impl=impl,
                reuse=reuse,
                netns_pid=netns_pid,
                privileges_lock=privileges_lock,
                log=self.log.getChild('bus'),
            )
        self.bus = bus
        self.ip = ip
        self.send_own_ip = send_own_ip

    def receive(self, block=True, timeout=None):
        """Receive one message and return ``(cmd, envelope, iface)``.

        This node's path item is appended to ``envelope.path``.

        Raises:
            Protocol.Timeout: when the bus raises its own ``Timeout``.
        """
        try:
            raw, sender, iface = self.bus.receive(block=block, timeout=timeout)
        except self.bus.Timeout:
            raise Protocol.Timeout()

        cmd, env = parse_message(raw)
        env.path.append(self._make_pathitem(env.hostid))
        first_uuid = short(env.msgs[0]['uuid'])
        self.log.debug('received {} {} from {}'.format(cmd, first_uuid, sender))
        return cmd, env, iface

    def route(self, envelope, addr):
        """Send ``envelope`` to ``addr`` as a ``'route'`` command.

        Mutates ``envelope`` first (adds ``msg``/``task``) for older peers.
        """
        _convert_format_to_old(envelope)
        self.bus.send((CURRENT_VERSION, ('route', envelope)), addr)

    def _make_pathitem(self, hostid):
        """Return ``(hostid, listenaddr)``, plus ``ips()`` if ``send_own_ip``."""
        item = (hostid, self.listenaddr())
        if self.send_own_ip:
            item += (self.ips(),)
        return item

    def _send_from(self, env, hostid, aggr=False):
        """Prepend our path item to ``env`` and route each next-hop envelope.

        With ``aggr`` set, every next-hop envelope is marked ``aggr = True``
        before it is routed.
        """
        env.path.insert(0, self._make_pathitem(hostid))
        for addr, hop in env.next():
            if aggr:
                hop.aggr = True
            self.route(hop, addr)

    def route_back(self, msg, hostid, path):
        """Send ``msg`` back along ``path`` with next hops marked for aggregation."""
        env = Envelope(msg, build_back_tree(path))
        self._send_from(env, hostid, aggr=True)

    def route_direct(self, msg, dest, hostid):
        """Send ``msg`` straight to the single destination ``dest``."""
        env = Envelope(msg, direct_path([dest]))
        self._send_from(env, hostid)

    def route_tree(self, msg, addrs, hostid, data=None, intermediates=None):
        """Fan ``msg`` out to ``addrs`` through a tree from ``build_tree``.

        Does nothing when ``addrs`` is empty.  ``data`` is attached to the
        envelope as ``tree_data``.
        """
        if not addrs:
            return
        self._log_route([msg], addrs)
        tree = build_tree(addrs, intermediates)
        self._send_from(Envelope(msg, tree, tree_data=data), hostid)

    def listenaddr(self):
        """Return the bus's listen address."""
        return self.bus.listenaddr()

    def ips(self):
        """Return ``(ip,)`` when an IP was configured, else the detected host IPs."""
        if self.ip is None:
            return _get_ips()
        return (self.ip,)

    def inode(self):
        """Return the bus's ``inode`` attribute."""
        return self.bus.inode

    def shutdown(self):
        """Shut the bus down; a no-op if ``__init__`` never set ``bus``."""
        bus = getattr(self, 'bus', None)
        if bus is None:
            return
        bus.shutdown()

    def __del__(self):
        self.shutdown()

    def _log_route(self, msgs, addrs):
        """Debug-log a routing decision; log an error when ``msgs`` is empty."""
        count = len(msgs)
        if count == 0:
            self.log.error('attempt to route empty list of messages to %s', addrs)
        else:
            head = msgs[0]
            # Prefer the task id; fall back to the message uuid.
            label = '[{:>8}]'.format(short(head.get('taskid', None) or head.get('uuid', '')))
            if count == 1:
                what = head.get('type', 'NOTYPE')
            else:
                what = '{} messages'.format(count)
            if len(addrs) == 1:
                where = addrs[0]
            else:
                where = '{} addresses'.format(len(addrs))
            self.log.debug('%s route %s to %s', label, what, where)

        summary = [(m.get('type', 'NOTYPE'), m.get('taskid', '')) for m in msgs]
        self.log.debug('route %s to %s', summary, addrs)


class _Handle(object):
    """Decorator factory that marks a method as the handler for a message type.

    ``@handle('ping')`` sets ``_handles = 'ping'`` on the wrapped function.
    ``Dispatcher._register_handlers`` collects methods carrying that
    attribute.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        @functools.wraps(fn)
        def decorated(*args, **kwargs):
            # Propagate the handler's result; previously it was discarded,
            # so Dispatcher.handle always returned None for known types.
            return fn(*args, **kwargs)
        decorated._handles = self.name
        return decorated


handle = _Handle


def work_loop(fn):
    """Turn a single-step method into one that repeats until ``self.stopped``.

    The wrapper checks ``self.stopped`` before each call and is tagged with
    ``_work_loop = True`` so ``Node.run`` can discover and spawn it.
    """
    @functools.wraps(fn)
    def looped(self, *args, **kwargs):
        while True:
            if self.stopped:
                break
            fn(self, *args, **kwargs)

    looped._work_loop = True
    return looped


class Dispatcher(object):
    """Dispatch messages to handler methods chosen by ``msg['type']``.

    Handlers are discovered once, at construction, by scanning ``dir(self)``
    for callables carrying a ``_handles`` attribute (set by the ``handle``
    decorator).  Each is stored as its underlying function and called as
    ``fn(self, msg, data, path, hostid, iface)``.
    """

    def __init__(self, log=None):
        self._handlers = self._register_handlers()
        self.log = log or _log()
        # Cooperative super() so later classes in the MRO (Node also inherits
        # Threaded) get initialised too.
        super(Dispatcher, self).__init__()

    def _register_handlers(self):
        """Return a ``{message type: function}`` map built from this object.

        ``dir()`` is sorted, so if two methods declare the same type the one
        whose attribute name sorts last wins.
        """
        handlers = {}

        for attr in dir(self):
            method = getattr(self, attr)
            # Builtin callable() instead of collections.Callable: that alias
            # was removed from `collections` in Python 3.10, which made this
            # scan raise AttributeError there.
            if callable(method) and hasattr(method, '_handles'):
                handlers[method._handles] = method.__func__

        return handlers

    def handle(self, msg, data, path, hostid, iface):
        """Run the handler registered for ``msg['type']``.

        Returns whatever the handler returns, or False (after an info log)
        when no handler is registered for that type.
        """
        msgtype = msg['type']
        if msgtype not in self._handlers:
            self.log.info('[%s] unsupported message type `%s` from %s',
                          short(msg.get('taskid') or msg.get('uuid', 'unknown')),
                          msgtype,
                          path[0])
            return False

        return self._handlers[msgtype](self, msg, data, path, hostid, iface)


class Node(Dispatcher, Threaded):
    """Message-routing node: receives envelopes and dispatches or forwards them.

    ``run()`` drives the receive loop on the calling thread and spawns every
    ``work_loop``-decorated method (the dispatch and aggregation loops) via
    ``self.spawn``.  Data flow: ``protocol.receive`` -> ``_dispatch_queue``
    -> ``_dispatch``.  Undelivered envelopes that have exactly one next hop
    and are marked ``aggr`` go through ``_aggr_queue`` and are re-sent in
    per-destination batches by ``_aggregate_step``.
    """

    def __init__(self,
                 ip=None,
                 port=None,
                 impl='msgpack',
                 send_own_ip=False,
                 netns_pid=None,
                 privileges_lock=None,
                 log=None,
                 protocol=None,
                 ):
        """Create the node, reusing ``protocol`` or building a new Protocol.

        ``ip``, ``port``, ``impl``, ``send_own_ip``, ``netns_pid`` and
        ``privileges_lock`` are only used when ``protocol`` is None.
        """
        super(Node, self).__init__(log=log)
        self.running = False
        self.stopped = False
        self.spawned = threading.Event()  # set once run() has started
        if protocol is not None:
            self.protocol = protocol
        else:
            self.protocol = Protocol(ip, port,
                                     impl=impl,
                                     send_own_ip=send_own_ip,
                                     netns_pid=netns_pid,
                                     privileges_lock=privileges_lock,
                                     log=log.getChild('proto') if log is not None else None,
                                     )
        self._aggr_queue = SPSCQueue()
        self._dispatch_queue = SPSCQueue()

    def run(self):
        """Spawn the work loops, then receive until stopped.

        No-op if already running or stopped.  Each attribute flagged
        ``_work_loop`` is spawned wrapped in ``auto_restart`` with
        ``_daemon=True``.  ``shutdown()`` runs when the receive loop exits,
        including on error.
        """
        if self.running or self.stopped:
            return

        self.running = True
        self.spawned.set()
        for fn_name in dir(self):
            fn = getattr(self, fn_name, None)
            if fn is not None and getattr(fn, '_work_loop', False):
                self.spawn(auto_restart(fn), _daemon=True)

        try:
            self.log.debug('%s started on %s', self.__class__.__name__, self.protocol.listenaddr())
            while not self.stopped:
                sleep(0)  # if monkey-patched, give gevent chance to context-switch
                self._receive_step()
        finally:
            self.shutdown()

    def _receive_step(self):
        """Receive one item (1s timeout) and queue it for dispatch.

        Timeouts are ignored.  Any other receive error is logged with its
        traceback and swallowed so the receive loop keeps going.
        """
        try:
            item = self.protocol.receive(timeout=1.0)
        except Protocol.Timeout:
            return
        except Exception as e:
            self.log.exception('%s cannot receive the message: %s', self.protocol.listenaddr(), e)
            return

        self._dispatch_queue.put(item)

    def shutdown(self):
        """Mark the node stopped and shut down its protocol; idempotent."""
        if self.stopped:
            return

        self.log.debug('{} stopping on {}'.format(self.__class__.__name__, self.protocol.listenaddr()))

        self.running = False
        self.stopped = True
        self.protocol.shutdown()

    def inode(self):
        """Return the protocol's bus inode."""
        return self.protocol.inode()

    def _dispatch(self, cmd, envelope, iface):
        """Process one received ``(cmd, envelope, iface)`` item.

        For ``'route'``: a delivered envelope has each of its messages passed
        to ``handle``.  Otherwise the envelope is forwarded to its next hops,
        or queued for aggregation when it has a single next hop and
        ``aggr`` is set.  Other commands are logged and ignored.
        """
        if cmd == 'route':
            if envelope.delivered():
                data = getattr(envelope, 'data', None)
                for m in envelope.msgs:
                    self.handle(m, data, envelope.path, envelope.hostid, iface)
            else:
                next_routes = list(envelope.next())
                # Aggregate if there is only one destination for the envelope,
                # AND if the envelope is specifically marked.
                if len(next_routes) == 1 and getattr(envelope, 'aggr', False):
                    self._aggr_queue.put((next_routes[0][0], envelope))
                else:
                    for addr, next_envelope in next_routes:
                        self.protocol.route(next_envelope, addr)

        else:
            self.log.warning("unknown protocol cmd: `{}`, ignoring".format(cmd))

    @work_loop
    def __dispatch_loop(self):
        """One dispatch-worker iteration: wait for queued items, then drain them."""
        # Wait in 1s slices (passed to wait_read) so a stop request is
        # noticed even when the queue stays empty.
        while not self._dispatch_queue and not self.stopped:
            wait_read(self._dispatch_queue, 1.)

        if self.stopped:
            return

        self._dispatch_step()

    def _dispatch_step(self):
        """Dispatch every queued item, logging per-item failures instead of raising."""
        while self._dispatch_queue:
            cmd, envelope, iface = self._dispatch_queue.get(block=False)
            # noinspection PyBroadException
            try:
                self._dispatch(cmd, envelope, iface)
            except Exception:
                self.log.exception('{} handle failed: '.format(cmd), exc_info=sys.exc_info())

    @work_loop
    def __aggr_loop(self):
        """One aggregation-worker iteration: wait for queued envelopes, then batch them."""
        while not self._aggr_queue and not self.stopped:
            wait_read(self._aggr_queue, 1.)

        if self.stopped:
            return

        self._aggregate_step()

    def _aggregate_step(self):
        """Drain the aggregation queue and send one batch per destination.

        Each message is tagged with its envelope's path as ``aggr_path``.
        Batches go out through ``route_direct`` with hostid ``'I'``.
        """
        collected = defaultdict(list)

        while self._aggr_queue:
            dest, env = self._aggr_queue.get(block=False)
            for msg in env.msgs:
                msg['aggr_path'] = env.path
                collected[dest].append(msg)

        for dest, msgs in collected.items():
            self.protocol.route_direct(msgs, dest, 'I')

    def __del__(self):
        # __init__ may have raised before `protocol` was assigned (e.g. while
        # building the Protocol).  shutdown() needs it, along with `log` and
        # `stopped`, which are set earlier.  Skip cleanup in that case rather
        # than raising AttributeError from the finalizer.
        if getattr(self, 'protocol', None) is not None:
            self.shutdown()


def _create_bus(ip, port, impl='msgpack', reuse=True, netns_pid=None, privileges_lock=None, log=None):
    """Construct the bus selected by ``impl`` and start it.

    ``impl`` is ``'netlibus'`` (Netlibus) or ``'msgpack'`` (MsgpackBus).
    ``reuse`` is only passed to MsgpackBus.  Both constructors take
    ``(port, ip)`` positionally, in that order.  The returned bus has
    already had ``run()`` called.

    Raises:
        RuntimeError: for any other ``impl``.
    """
    if impl == 'netlibus':
        bus = Netlibus(port, ip, netns_pid=netns_pid, privileges_lock=privileges_lock, log=log)
    elif impl == 'msgpack':
        bus = MsgpackBus(port, ip, reuse=reuse, netns_pid=netns_pid, privileges_lock=privileges_lock, log=log)
    else:
        # Name the offending value, as parse_message's version errors do.
        raise RuntimeError("Unsupported bus type: %r" % (impl,))
    bus.run()
    return bus


@cached(300)
def _get_ips():
    """Return the host's usable interface addresses as a list of strings.

    Keeps IPv4/IPv6 addresses that ``getValidAddrs`` accepts, then drops
    loopback interfaces, addresses with a non-zero scope id, and addresses
    containing ``%`` (zone suffix).  Memoized by ``cached(300)``
    (presumably a 300-second TTL -- confirm).
    """
    candidates = getValidAddrs((socket.AF_INET, socket.AF_INET6), getIfAddrs())
    return [
        candidate.addr
        for candidate in candidates
        if candidate.flags & IfFlags.IFF_LOOPBACK != IfFlags.IFF_LOOPBACK
        and not candidate.scopeid
        and '%' not in candidate.addr
    ]


@singleton
def _log():
    """Return the module's default logger: the ``protocol`` child of ``root()``.

    Wrapped in ``singleton`` (presumably so every call returns the same
    logger object -- confirm).
    """
    parent = root()
    return parent.getChild('protocol')
