# coding=utf-8
from __future__ import absolute_import

import six
import select
import socket
import errno
import sys
import os
import threading
import collections

from ya.skynet.util.functional import Cache
from ya.skynet.util import pickle

# True when the interpreter reports itself as a standalone binary
# (``sys.is_standalone_binary``); such builds take nstools from the
# ``library.python`` tree instead of the sibling ``..nstools`` package.
in_arcadia = bool(getattr(sys, 'is_standalone_binary', False))

if in_arcadia:
    from library.python.nstools.nstools import move_to_ns, Network, Uts
else:
    from ..nstools import move_to_ns, Network, Uts

from ..utils import FdHolder, Threaded, LRUCache, as_user, monotime
from ..utils import singleton, genuuid, short, sleep, fqdn, getaddrinfo, auto_restart, rotate_list, poll_select
from ..utils import log as root
from .. import msgpackutils as msgpack


# Packet kinds: the first element of every datagram exchanged by MessageBus.
MSG = 0  # data part: (MSG, (uuid, sn, count, part), timestamp)
ACK = 1  # acknowledgement: (ACK, (uuid, sn), timestamp)

# Seconds after which an unacknowledged part is dropped by MessageBus;
# also passed to netlibus as its delivery timeout.
TOTAL_TIMEOUT = 15
# Seconds the receiver thread waits for readability before re-checking `running`.
RECEIVE_TIMEOUT = 0.5

# Bounds (seconds) applied to measured round-trip times before they feed the
# adaptive retransmission timeout; MINIMUM_TIMEOUT is also its initial value.
MINIMUM_TIMEOUT = 0.02
MAXIMUM_TIMEOUT = 5.0


def wait_read(fd, timeout):
    """Wait up to *timeout* seconds for *fd* to become readable.

    :param fd: a raw integer descriptor or any object with ``fileno()``.
    :param timeout: seconds to wait (0 polls without blocking).
    :return: True if the descriptor is readable, False on timeout or when
             the wait was interrupted by a signal (EINTR).
    :raises select.error: for any other polling failure.
    """
    raw_fd = fd if isinstance(fd, int) else fd.fileno()
    try:
        # poll and epoll in python always return raw fd, not object,
        # hence we cannot just check 'fd in res'
        readable = poll_select([raw_fd], [], [], timeout)[0]
    except select.error as err:
        if err.args[0] != errno.EINTR:
            raise
        return False
    return bool(readable)


class SPSCQueue(object):
    """Single-producer / single-consumer queue with a pollable descriptor.

    Items are kept in a ``collections.deque``. Every ``put`` also writes one
    byte into an internal pipe, so a consumer can sleep on the read end
    (``fileno()``) instead of busy-polling. The pipe is only a wake-up
    signal. The deque is the source of truth for emptiness, and the number
    of bytes in the pipe need not match ``len(self)``.
    """

    _r = FdHolder('r')  # read end of the wake-up pipe
    _w = FdHolder('w')  # write end of the wake-up pipe

    def __init__(self):
        self._queue = collections.deque()
        self._r, self._w = os.pipe()

    def __len__(self):
        return len(self._queue)

    def fileno(self):
        """Return the read end of the wake-up pipe (readable after a ``put``)."""
        return self._r

    def empty(self):
        return not bool(self._queue)

    def put(self, item):
        """Append *item* and wake the consumer.

        The item is appended before the wake-up byte is written. A consumer
        woken by the byte therefore always finds the item in the deque. The
        write may block if the pipe buffer is full. The consumer drains it
        as soon as the deque runs empty.
        """
        self._queue.append(item)
        # os.write needs bytes on Python 3; a str argument raises TypeError there.
        os.write(self._w, b'1')

    def get(self, block=True, timeout=None):
        """Remove and return the oldest item (``queue.Queue.get`` semantics).

        :param block: if False, raise ``queue.Empty`` immediately when empty.
        :param timeout: when blocking, None waits indefinitely (in 5 s poll
                        slices); otherwise wait at most *timeout* seconds.
        :raises six.moves.queue.Empty: nothing arrived in time.
        :raises ValueError: *timeout* is negative while blocking.
        """
        if not block:
            if not self._queue:
                raise six.moves.queue.Empty
        elif timeout is None:
            while not self._queue:
                if wait_read(self._r, 5.):
                    # Consume the wake-up bytes right away. Otherwise a stale
                    # byte keeps the pipe readable and this loop spins.
                    self._drain()
        elif timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        else:
            deadline = monotime() + timeout
            while not self._queue:
                remaining = deadline - monotime()
                if remaining <= 0:
                    raise six.moves.queue.Empty
                if wait_read(self._r, min(remaining, 5.)):
                    self._drain()

        result = self._queue.popleft()
        if not self._queue:
            # Bytes accumulate while items are taken without waiting (deque
            # already non-empty). Clear them once the deque runs dry. This keeps
            # the next blocking wait from waking spuriously and stops a purely
            # non-blocking consumer from letting the pipe fill up.
            self._drain()
        return result

    def _drain(self):
        """Discard pending wake-up bytes without blocking (best effort)."""
        try:
            while wait_read(self._r, 0):
                if not os.read(self._r, 4096):
                    break  # EOF: write end closed, nothing more will arrive
        except (EnvironmentError, select.error):
            # A failed drain only leaves stale wake-up bytes behind. Callers
            # always re-check the deque, so no item can be lost.
            pass

    def __del__(self):
        try:
            del self._r
        finally:
            del self._w


class MessageBus(Threaded):
    """Reliable message bus on top of a single UDP socket.

    Outgoing envelopes are serialized and split into 1100-byte parts. Each
    part is sent as ``(MSG, (uuid, sn, count, part), timestamp)`` and kept in
    ``_in_flight`` until the peer answers ``(ACK, (uuid, sn), timestamp)``.
    The resender thread retransmits unacknowledged parts, rotating over the
    destination's resolved addresses, and drops them after ``TOTAL_TIMEOUT``
    seconds. Incoming parts are reassembled in ``parts``. Complete messages
    are handed out by ``receive()``.

    The retransmission timeout adapts to observed round-trip times as
    ``mean + 3 * stddev`` of the clamped RTT (see ``_adjust_timeout``).

    Subclasses must implement ``serialize``/``unserialize``.
    """

    def __init__(self, port, ip=None, reuse=True, netns_pid=None, privileges_lock=None, log=None):
        """
        :param port: UDP port to bind; 0/None lets the OS choose (see ``run``).
        :param ip: listen address; '' / None means all interfaces.
        :param reuse: set SO_REUSEADDR (only when a non-zero port is given).
        :param netns_pid: if set, create the socket inside the network and UTS
                          namespaces of this pid (Linux only).
        :param privileges_lock: lock held while switching to root for netns setup.
        :param log: logger; defaults to the module's ``bus`` logger.
        """
        super(MessageBus, self).__init__()
        self.log = log or _log()
        # Hand-off queues between the public API and the worker threads.
        self.send_queue = SPSCQueue()  # (serialized envelope, dst[, ips]), see send()
        self.recv_queue = SPSCQueue()  # (reassembled payload, addr), see _add_part()
        self.sock = None
        self.port = port
        self.running = False
        self.inode = None  # inode of the target net namespace; set only with netns_pid
        self._in_flight = {}  # (uuid, sn) -> DeliveryPiece still waiting for an ACK
        self.parts = LRUCache(100)  # uuid -> {sn: part} of partially received messages
        self.counters = {'sent': 0, 'sent_envs': 0, 'resent': 0, 'expired': 0, 'received': 0, 'received_envs': 0}

        # Adaptive retransmission timeout, maintained by _adjust_timeout().
        self.timeout = MINIMUM_TIMEOUT
        self.mean = RollingAverage(resistance=0.999, value=MINIMUM_TIMEOUT)
        self.variance = RollingAverage(resistance=0.999, value=0)

        self.resolver = Resolver(socket.has_ipv6, log=self.log)
        if socket.has_ipv6:
            # Dual-stack socket: an IPv4 listen address is converted to its
            # IPv4-mapped IPv6 form, and IPV6_V6ONLY is cleared.
            if not ip:
                self.ip = ''
            elif ':' not in ip:
                self.ip = '::ffff:' + ip
            else:
                self.ip = ip
            self.create_socket(socket.AF_INET6, netns_pid, privileges_lock)
            self.sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        else:
            self.ip = ip or ''
            self.create_socket(socket.AF_INET, netns_pid, privileges_lock)

        # Bigger kernel buffers are best effort: failures are only logged.
        try:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024 * 4)
        except socket.error:
            self.log.warning("not increasing socket receive buffer, no buffer space available")
        try:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024 * 4)
        except socket.error:
            self.log.warning("not increasing socket send buffer, no buffer space available")

        if self.port and reuse:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    def create_socket(self, family, netns_pid, privileges_lock):
        """Create the UDP socket, optionally inside another process's namespaces.

        Without *netns_pid* the socket is created in the current namespace.
        Otherwise a throw-away thread runs ``as_user('root', _make_sock)``.
        That thread enters the net and UTS namespaces of *netns_pid* and
        creates the socket there, leaving the calling thread untouched
        (presumably ``move_to_ns`` wraps the per-thread ``setns(2)``).
        Sets ``self.sock``, ``self.fqdn`` and, for netns, ``self.inode``.

        :raises RuntimeError: namespaces requested on a non-Linux platform.
        Any exception raised while creating the socket in the helper thread
        is re-raised here, instead of being lost and leaving ``self.sock`` None.
        """
        if netns_pid is None:
            self.sock = socket.socket(family, socket.SOCK_DGRAM)
            self.fqdn = fqdn()
            return

        if 'linux' not in sys.platform:
            raise RuntimeError("Namespaces aren't available in your OS")

        errors = []

        def _make_sock():
            with open('/proc/%s/ns/net' % (netns_pid,)) as netfd:
                with open('/proc/%s/ns/uts' % (netns_pid,)) as utsfd:
                    self.inode = os.fstat(netfd.fileno()).st_ino
                    move_to_ns(netfd, Network)
                    move_to_ns(utsfd, Uts)
                    self.sock = socket.socket(family, socket.SOCK_DGRAM)
                    self.fqdn = socket.getfqdn()

        def _run_as_root():
            # Capture failures so they can be re-raised in the calling thread.
            try:
                as_user('root', _make_sock)
            except BaseException:
                errors.append(sys.exc_info())

        with (privileges_lock if privileges_lock is not None else threading.RLock()):
            t = threading.Thread(target=_run_as_root)
            t.start()
            t.join()

        if errors:
            six.reraise(*errors[0])

    def run(self):
        """Bind the socket and start the sender/receiver/resender threads.

        Calling it on a running bus is a no-op. After a port-0 bind,
        ``self.port`` holds the port the OS actually assigned.
        """
        if self.running:
            return

        # Bind before flagging the bus as running. If bind() raises, a later
        # run() can retry instead of silently returning.
        self.sock.bind((self.ip, self.port))
        self.port = self.sock.getsockname()[1]
        self.running = True

        for task in (self.__sender, self.__receiver, self.__resender):
            self.spawn(auto_restart(task), _daemon=True)

    def serialize(self, msg):
        raise NotImplementedError

    def unserialize(self, msg):
        raise NotImplementedError

    def shutdown(self):
        """Ask the worker loops to stop (they check ``running`` on every iteration)."""
        if not self.running:
            return
        self.running = False

    def send(self, envelope, addr):
        """Queue *envelope* for asynchronous delivery.

        :param addr: ``(hostid, (host, port))`` or
                     ``(hostid, (host, port), ips)``. The ips are used only when
                     ``(host, port)`` cannot be resolved.
        """
        # addr has struct: (hostid, (host, port))
        # or (hostid, (host, port), ((family, packed_ip), ...))
        self.send_queue.put((self.serialize(envelope),) + tuple(addr[1:]))
        self.counters['sent_envs'] += 1

    def receive(self, block=True, timeout=None):
        """Return the next complete message as ``(msg, addr, iface)``.

        *iface* is None for messages reassembled by this class (the receive
        queue holds 2-tuples).

        :raises MessageBus.Timeout: nothing arrived (non-blocking or timed out).
        """
        try:
            tup = self.recv_queue.get(block=block, timeout=timeout)
            msg = self.unserialize(tup[0])
            addr = tup[1]
            iface = tup[2] if len(tup) > 2 else None
            return msg, addr, iface
        except six.moves.queue.Empty:
            raise MessageBus.Timeout()

    def listenaddr(self):
        """Return ``(ip or fqdn, port)`` this bus is reachable at."""
        return self.ip or self.fqdn, self.port

    def _adjust_timeout(self, rtt):
        """Fold a round-trip sample into the retransmission timeout.

        The RTT is clamped to [MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT]. The timeout
        becomes ``mean + 3 * sqrt(variance)``, both smoothed exponentially.
        """
        rtt = min(MAXIMUM_TIMEOUT, max(MINIMUM_TIMEOUT, rtt))
        m = self.mean.push(rtt)
        v = self.variance.push((m - rtt) ** 2)
        self.timeout = m + 3 * (v ** 0.5)
        self.log.debug('New timeout is %s, rtt is %s', self.timeout, rtt)

    def __sender(self):
        """Worker: split queued envelopes into parts, send them, track them as in flight."""
        while self.running:
            try:
                data = self.send_queue.get(timeout=1.0)
                (msg, dst) = data[:2]
                ips = None if len(data) < 3 else data[2]
            except six.moves.queue.Empty:
                continue

            uuid = genuuid()
            parts = split(msg, 1100)
            count = len(parts)

            try:
                addrs = self.resolver.resolve(dst)
                use_ips = False
            except RuntimeError:
                # Resolution failed: fall back to caller-provided IPs, if any.
                if not ips:
                    self.log.exception("exception during resolving addr: {}".format(dst), exc_info=sys.exc_info())
                    continue
                else:
                    addrs = [(ip, dst[1]) for ip in ips]
                    self.log.debug("failed to resolve addr %s, using provided ips: %s", dst, addrs)
                    use_ips = True

            for sn, part in enumerate(parts):
                part_msg = (uuid, sn, count, part)
                # Registered before sending so the resender retries the part
                # even if this first send fails.
                piece = self._in_flight[(uuid, sn)] = DeliveryPiece(part_msg,
                                                                    addrs[0],
                                                                    dst,
                                                                    ips=addrs if use_ips else None)

                try:
                    self._send_msg(part_msg, addrs[0], piece.send_time)
                    self.counters['sent'] += 1
                except EnvironmentError as e:
                    if e.errno in (errno.EHOSTUNREACH, errno.ENETUNREACH):
                        self.resolver.rotate_host(dst, addrs[0], candidates=addrs if use_ips else None)
                    self.log.error("exception during message sending to {}: {}".format(addrs[0], e))

    def __receiver(self):
        """Worker: read datagrams, ACK data parts, reassemble messages, process ACKs."""
        while self.running:
            if not wait_read(self.sock.fileno(), RECEIVE_TIMEOUT):
                continue

            x = self._recvfrom()
            if len(x) != 2 or len(x[0]) != 3:
                # Legacy packet without a timestamp: ((kind, msg), addr).
                # Data parts get a legacy-format ACK so the peer stops resending.
                # NOTE(review): the legacy payload is never passed to
                # _add_part, i.e. it is dropped -- confirm this is intended.
                (kind, msg), addr = x
                if kind == MSG:
                    (uuid, sn, count, part) = msg
                    self._sendto((ACK, (uuid, sn)), addr)
                continue

            (kind, msg, timestamp), addr = x
            msg = tuple(msg)

            if kind == MSG:
                (uuid, sn, count, part) = msg
                self.log.debug('in  (MSG %s %s/%s)\t← %s', short(uuid), sn, count, addr)

                self._send_ack(uuid, sn, addr, timestamp)
                self._add_part(uuid, sn, count, part, addr)

            elif kind == ACK:
                # The peer echoes our send timestamp, so this is a local-clock RTT.
                self._adjust_timeout(monotime() - timestamp)

                (uuid, sn) = msg
                self.log.debug('in  (ACK %s %s)\t← %s', short(uuid), sn, addr)

                self._in_flight.pop(msg, None)

    def __resender(self):
        """Worker: retransmit overdue parts and expire parts older than TOTAL_TIMEOUT.

        Runs every ``max(timeout, 0.5)`` seconds and logs statistics about
        every 30 seconds. An expired part also drops its destination from the
        resolver cache.
        """
        log_time = monotime()

        while self.running:
            sleep(max(self.timeout, 0.5))

            cur_time = monotime()
            if cur_time - log_time > 30:
                self.log.info(
                    '[statistics] %s in flight %s, sleep time %s, counters %s',
                    self.listenaddr(),
                    len(self._in_flight),
                    self.timeout,
                    self.counters
                )
                log_time = cur_time

            expired = []

            # Iterate over a snapshot: the receiver thread pops ACKed entries concurrently.
            for (key, piece) in list(self._in_flight.items()):
                if cur_time - piece.start_time > TOTAL_TIMEOUT:
                    self.log.debug('(MSG %s) expired to %s', key, piece.dest)

                    expired.append(key)
                    self.resolver.forget(piece.dest)

                elif cur_time - piece.send_time > self.timeout:
                    next_addr = self.resolver.rotate_host(piece.dest, piece.addr, candidates=piece.ips)
                    piece.addr = next_addr
                    piece.send_time = monotime()

                    try:
                        self._send_msg(piece.msg, next_addr, piece.send_time)
                        self.counters['resent'] += 1
                    except EnvironmentError as e:
                        self.log.error('exception during message resending to {}: {}'.format(next_addr, e))

            for key in expired:
                self._in_flight.pop(key, None)
                self.counters['expired'] += 1

    def _add_part(self, uuid, sn, count, part, addr):
        """Store part *sn* of message *uuid*; enqueue the message once all parts are here.

        Parts whose *sn* is outside ``[0, count)`` are dropped. Otherwise a
        malformed or malicious datagram could make reassembly fail with a
        KeyError and kill the receiver loop.
        """
        if not 0 <= sn < count:
            self.log.debug('dropping part %s/%s of %s from %s: out of range', sn, count, short(uuid), addr)
            return

        parts = self.parts.setdefault(uuid, {})
        parts[sn] = part

        if len(parts) == count:
            del self.parts[uuid]

            msg = b''.join([parts[i] for i in six.moves.xrange(count)])

            self.recv_queue.put((msg, addr))
            self.counters['received_envs'] += 1

    def _send_ack(self, uuid, sn, dst, timestamp):
        self._sendto((ACK, (uuid, sn), timestamp), dst)

        self.log.debug('out (ACK %s %s)\t→ %s', short(uuid), sn, dst)

    def _send_msg(self, msg, dst, timestamp):
        self._sendto((MSG, msg, timestamp), dst)

        self.log.debug('out (MSG %s %s/%s)\t→ %s', short(msg[0]), msg[1], msg[2], dst)

    def _sendto(self, msg, addr):
        self.sock.sendto(self.serialize(msg), addr)

    def _recvfrom(self):
        """Receive one datagram; return ``(unserialized packet, normalized addr)``."""
        (msg, addr) = self.sock.recvfrom(8 * 1024)
        self.counters['received'] += 1

        return self.unserialize(msg), self.resolver.fixup_ip(addr)

    def __del__(self):
        # practically never works if run() is called (thread stores reference to __self__)
        # getattr: __init__ may have failed, or never run (PickleBus), before
        # 'running' was assigned.
        if getattr(self, 'running', False):
            self.shutdown()

    class Timeout(Exception):
        pass


class Netlibus(object):
    """Message bus backed by the external ``netlibus`` library.

    It offers the same surface as MessageBus (run/shutdown/send/receive/
    listenaddr, Timeout) but delegates transport to ``netlibus.MsgBus``,
    which is constructed with ``timeout=TOTAL_TIMEOUT``. Payloads are encoded
    with msgpack. Host names are still resolved by a local Resolver. That
    allows falling back to caller-provided IPs, and per-transfer results from
    netlibus (``FAILED``/``RETRIED``) are fed back into the resolver cache.
    """

    def __init__(self, port=0, ip=None, reuse=False, netns_pid=None, privileges_lock=None, log=None):
        self.log = log or _log()
        self.ip = ip
        self.port = port
        self.resolver = Resolver(has_v6=True, log=self.log)

        # Imported here, presumably so this module loads where netlibus is absent.
        import netlibus
        if netns_pid is None:
            self.bus = netlibus.MsgBus(port,
                                       timeout=TOTAL_TIMEOUT,
                                       logger=self.log.getChild('netlibus'),
                                       transfer_result_callback=self.__transfer_result_callback,
                                       )
        # netns support is detected by the presence of MsgBus.inode.
        elif hasattr(netlibus.MsgBus, 'inode'):
            # The bus is constructed as root while holding the privileges lock.
            with (privileges_lock if privileges_lock is not None else threading.RLock()):
                self.bus = as_user('root', netlibus.MsgBus,
                                   port,
                                   timeout=TOTAL_TIMEOUT,
                                   logger=self.log.getChild('netlibus'),
                                   transfer_result_callback=self.__transfer_result_callback,
                                   netns_pid=netns_pid,
                                   )
        else:
            raise RuntimeError("netlibus doesn't support netns yet")

    @property
    def inode(self):
        """Net namespace inode as reported by the underlying netlibus bus."""
        return self.bus.inode

    def run(self):
        self.bus.start()

    def shutdown(self):
        """Detach the result callback and stop the netlibus bus (safe if construction failed)."""
        # we may have encountered exception creating bus
        if getattr(self, 'bus', None) is not None:
            self.bus.transfer_result_callback = None
            self.bus.stop()

    # By the ancient tradition, addr contains hostid.
    def send(self, envelope, addr):
        """Serialize *envelope* and hand it to netlibus.

        :param addr: ``(hostid, (host, port))`` or ``(hostid, (host, port), ips)``.
                     The ips are used only when ``(host, port)`` cannot be resolved.
        Errors are logged, not raised.
        """
        try:
            # Raw getaddrinfo sockaddrs (no IPv4-mapped rewrite) are passed to netlibus.
            dest = self.resolver.resolve(addr[1], fixup=False)
        except (EnvironmentError, RuntimeError):
            if len(addr) > 2:
                ips = addr[2]
                dest = [(ip, addr[1][1]) for ip in ips]
                self.log.debug("failed to resolve addr %s, using provided ips: %s", addr[1], dest)
            else:
                self.log.exception("cannot send envelope to {}".format(addr), exc_info=sys.exc_info())
                return

        msg = self.serialize(envelope)
        l = len(msg)

        try:
            # since netlibus sends everything TOO intensively,
            # we sleep before long messages to smooth network bursts a little
            if l > (1 << 20):
                sleep(5e-9 * l)  # 5ms per 1MB
            self.log.debug("out {} bytes → {}".format(len(msg), dest))
            # dest= carries the original hostname; it is presumably what netlibus
            # passes back as `host` to __transfer_result_callback.
            self.bus.send_ex(msg, dest, dest=addr[1][0])
        except socket.gaierror:
            # netlibus resolves hostname right on the invokation
            self.log.exception("cannot send envelope to {}".format(dest), exc_info=sys.exc_info())

    def receive(self, block=True, timeout=None):
        """Return the next message as ``(msg, (host, port), iface)``.

        netlibus reports peers as ``"host:port"`` strings. *iface* is the local
        receiving address with its port and any ``[...]`` brackets removed, or
        None when netlibus does not report it.

        :raises MessageBus.Timeout: nothing arrived (non-blocking or timed out).
        """
        try:
            tup = self.bus.receive(block=block, timeout=timeout)
            msg = tup[0]
            addr = tup[1].rsplit(':', 1)
            iface = tup[2].rsplit(':', 1)[0] if len(tup) > 2 else None
            if iface and iface.startswith('[') and iface.endswith(']'):
                iface = iface[1:-1]
            self.log.debug("in {} bytes ← {}".format(len(msg), addr))
            return self.unserialize(msg), (addr[0], int(addr[1])), iface
        except (six.moves.queue.Empty, self.bus.Timeout):
            raise MessageBus.Timeout()

    def listenaddr(self):
        return (self.ip or (self.bus.fqdn if hasattr(self.bus, 'fqdn') else fqdn()),
                self.bus.port())

    def serialize(self, msg):
        return msgpack.dumps(msg)

    def unserialize(self, msg):
        return msgpack.loads(msg)

    def __transfer_result_callback(self, host, addr, result):
        """Feed a netlibus transfer result back into the resolver cache.

        FAILED forgets the cached resolution; RETRIED rotates to the next
        address. Any exception is logged and swallowed, because this runs as
        a library callback.
        """
        import netlibus

        try:
            if not host:
                return
            addr = tuple(addr.rsplit(':', 1))

            if result == netlibus.FAILED:
                # NOTE(review): resolve() caches under tuple(addr[1]) == (host, port),
                # but a bare `host` is passed here -- this forget() likely never
                # matches a cache entry. Confirm the key format.
                self.resolver.forget(host)
            elif result == netlibus.RETRIED:
                # NOTE(review): the port here is the str from rsplit and `addr` is a
                # (str ip, str port) pair, while cached keys/sockaddrs presumably use
                # int ports -- the rotation may silently not apply. Verify.
                self.resolver.rotate_host((host, addr[1]), addr)
        except Exception as e:
            self.log.exception("__transfer_result_callback failed: %s", e, exc_info=sys.exc_info())

    def __del__(self):
        self.shutdown()

    Timeout = MessageBus.Timeout


class PickleBus(MessageBus):
    """MessageBus variant that encodes packets with ``pickle``.

    Construction is deliberately disabled. The bus would unpickle datagrams
    received from the network, and unpickling untrusted data can execute
    arbitrary code (see the security warning in the ``pickle`` docs).
    ``pickle`` here is ``ya.skynet.util.pickle``.
    """

    def __init__(self, *args, **kwargs):
        raise RuntimeError("Constructing pickle bus is strictly forbidden")

    def serialize(self, msg):
        encoded = pickle.dumps(msg)
        return encoded

    def unserialize(self, msg):
        # SECURITY: never reachable on live data while __init__ raises; keep it that way.
        decoded = pickle.loads(msg)
        return decoded


class MsgpackBus(MessageBus):
    """MessageBus that encodes packets and envelopes with the project's msgpack helpers."""

    def serialize(self, msg):
        """Encode *msg* into bytes for the wire."""
        encoded = msgpack.dumps(msg)
        return encoded

    def unserialize(self, msg):
        """Decode wire bytes produced by ``serialize``."""
        decoded = msgpack.loads(msg)
        return decoded


class RollingAverage(object):
    """
    Rolling average

    Calculates exponentially weighted average value given the points pushed in:
        Avg(i) = Avg(i - 1) * p + Val(i) * (1 - p)
        Avg(0) = Val(0)
    where:
        `i` is the index of the value,
        `p` is the resistance factor (the greater `p`, the smaller effect will have each new value)
        `Val` is the next point pushed

    If a starting `value` is given to the constructor (including 0), it acts as
    Avg(-1) and the first pushed point is blended into it. Only when no starting
    value is given (None) does the first pushed point become the average as is.
    """

    __slots__ = [
        'value',
        'p',
        'q',
    ]

    def __init__(self, resistance, value=None):
        assert 0 <= resistance <= 1, "resistance factor is out of bounds [0, 1]"
        self.p = resistance
        self.q = 1 - resistance
        self.value = value

    def push(self, value):
        """Blend *value* into the average and return the new average."""
        # Test for None, not truthiness: an average of 0 (e.g. a variance that
        # starts at 0) is real data. It must be blended, not replaced outright.
        if self.value is None:
            self.value = value
        else:
            self.value = self.value * self.p + value * self.q
        return self.value

    def __float__(self):
        # __float__ must return a float; returning an int raises TypeError on Python 3.
        return float(self.value)

    def __str__(self):
        return 'Average [{}]'.format(self.value)

    def __repr__(self):
        return '{}(resistance={!r}, value={!r})'.format(self.__class__.__name__, self.p, self.value)


class Resolver(object):
    """Resolve ``(host, port)`` pairs to socket addresses, with caching and rotation.

    Results are kept in ``Cache(60 * 30)`` (presumably a 30-minute TTL) under
    the ``tuple(host)`` key. List-shaped hosts are normalized to tuples
    everywhere, so they are hashable and hit the same entry. ``rotate_host``
    rotates the cached list by one (via ``rotate_list``) when the failing
    address is at the front, so the next attempt uses another address.
    With *has_v6*, IPv6 results come first. ``fixup_ip`` can rewrite IPv4
    addresses into their IPv4-mapped IPv6 form for a dual-stack socket.
    """

    def __init__(self, has_v6=True, log=None):
        self.log = log or _log()
        self.v6 = has_v6
        if has_v6:
            self.families = (socket.AF_INET6, socket.AF_INET)
        else:
            self.families = (socket.AF_INET,)
        self.resolve_cache = Cache(60 * 30)  # 30 minutes would be sufficient for everyone

    @staticmethod
    def _key(host):
        """Normalize a cache key the way ``resolve`` does (lists become tuples)."""
        return tuple(host) if isinstance(host, list) else host

    def fixup_ip(self, addr):
        """Return *addr* with an IPv4 address rewritten as ``::ffff:a.b.c.d`` when IPv6 is on."""
        ip = addr[0]
        if not self.v6 or ':' in ip:
            return addr
        return ('::ffff:' + ip,) + addr[1:]

    def rotate_host(self, host, addr, candidates=None):
        """Return the address to try after *addr* failed for *host*.

        :param candidates: explicit address list. If given, the element after
                           *addr* is returned (wrapping around), or the first
                           element if *addr* is not among them. The cache is
                           not touched in that case.
        Otherwise the cached resolution of *host* (resolved if missing) is
        rotated when *addr* is at its front, and the new front is returned.
        """
        if candidates:
            try:
                i = (candidates.index(addr) + 1) % len(candidates)
            except ValueError:
                i = 0
            return candidates[i]

        host = self._key(host)
        # resolve() returns the cached list when present; hold a local reference
        # so a TTL expiry between two cache reads cannot raise KeyError.
        addrs = self.resolve(host)
        if addr == addrs[0]:
            addrs = rotate_list(addrs, 1)
            self.resolve_cache[host] = addrs
        return addrs[0]

    def forget(self, host):
        """Drop the cached resolution of *host*, if any."""
        try:
            del self.resolve_cache[self._key(host)]
        except KeyError:
            pass

    def resolve(self, host, fixup=True):
        """Return the (cached) list of socket addresses for ``(hostname, port)``.

        Lookups are attempted up to 3 times. A lookup counts as failed when it
        raises EnvironmentError or yields no usable address.

        :param fixup: pass fresh results through ``fixup_ip`` before caching.
                      An already cached entry is returned as is.
        :raises RuntimeError: no address found after all attempts.
        """
        host = tuple(host)
        if host not in self.resolve_cache:
            addrs = None
            for attempt in range(3):
                try:
                    addrs = self._resolve_with_tweaks(host)
                except EnvironmentError as e:
                    self.log.debug("can't resolve host %s (attempt %d): %s", host, attempt + 1, e)
                    continue
                # _resolve_with_tweaks swallows socket errors and returns [];
                # treat an empty result as a failed attempt too.
                if addrs:
                    break
                self.log.debug("can't resolve host %s (attempt %d): no suitable address", host, attempt + 1)

            if not addrs:
                raise RuntimeError("Can't find suitable address for {!r}".format(host))

            self.resolve_cache[host] = [self.fixup_ip(a) for a in addrs] if fixup else addrs
            self.log.debug("resolve: %s → %s", host, addrs)

        return self.resolve_cache[host]

    def _resolve_with_tweaks(self, host):
        """Resolve via ``getaddrinfo``; return sockaddrs of enabled families, preferred family first.

        If the preferred family yields nothing and the hostname ends with
        ``yandex.ru``, the lookup is retried with ``search.yandex.net``
        substituted. Socket errors are swallowed (the result may be empty).
        The order within one family is arbitrary (addresses are collected in sets).
        """
        hostname, port = host[0] if isinstance(host[0], six.binary_type) else six.b(host[0]), int(host[1])
        addrs = {f: set() for f in self.families}

        def do_resolve(h, p):
            try:
                for addrinfo in getaddrinfo(h, p, 0, socket.SOCK_DGRAM):
                    if addrinfo[0] in self.families:
                        addrs[addrinfo[0]].add(addrinfo[4])
            except socket.error:
                pass

        do_resolve(hostname, port)
        if not addrs[self.families[0]] and hostname.endswith(b'yandex.ru'):
            do_resolve(hostname.replace(b'yandex.ru', b'search.yandex.net'), port)

        ret = []
        for f in self.families:
            ret.extend(addrs[f])

        return ret


class DeliveryPiece(object):
    """Book-keeping record for one sent message part that still awaits an ACK.

    Attributes:
        msg: the part packet body, as built by the sender (``(uuid, sn, count, part)``).
        addr: address the part was last sent to.
        dest: original ``(host, port)`` destination, used for resolver lookups.
        ips: explicit candidate addresses to rotate over, or None to use the resolver.
        start_time: creation time (monotime); used to expire the part.
        send_time: time of the latest (re)transmission (monotime).
    """
    __slots__ = ('msg', 'addr', 'dest', 'start_time', 'send_time', 'ips')

    def __init__(self, msg, addr, dest, ips):
        self.msg, self.addr, self.dest, self.ips = msg, addr, dest, ips
        self.start_time = monotime()
        self.send_time = monotime()


def split(data, size):
    """Cut the sliceable *data* into consecutive chunks of at most *size* items.

    The last chunk may be shorter. Empty *data* yields an empty list.
    """
    chunks = []
    for start in range(0, len(data), size):
        chunks.append(data[start:start + size])
    return chunks


@singleton
def _log():
    """Return the module's ``bus`` child of the root logger (wrapped by ``@singleton``)."""
    parent = root()
    return parent.getChild('bus')
