import collections
import random
import time

import gevent.queue

from .component import Component

from .peer import Peer, PeerAddress
from .peer_collection import PeerCollection
from .peer_connection import PeerConnection


class Swarm(Component):
    """Peer connections for the swarm identified by ``uid``.

    Tracks connections that are still being set up (``conns_halfopen``),
    performs connect / handshake / activation for each new connection so
    that exactly one connection per remote peer stays active, and keeps
    per-second byte counters used for rough speed reporting.
    """

    # Number of one-second buckets kept for speed estimation.
    SPEED_WINDOW = 5

    # Per-stage timeouts (seconds) used by connect_addr(). They are class
    # attributes so subclasses or tests can tune them.
    CONNECT_TIMEOUT = 1
    HANDSHAKE_TIMEOUT = 5
    ACTIVATE_TIMEOUT = 10

    class ConnectFailed(Exception):
        """Base class for connect_addr() failures."""

    class ConnectedToSelf(ConnectFailed):
        """The remote side reported our own world uid."""

    class AlreadyInHalfOpen(ConnectFailed):
        """A connection to the same address is already being set up."""

    class HandshakeFailed(ConnectFailed):
        """Handshake or activation did not complete."""

    class NoSkybit(HandshakeFailed):
        pass

    class Abort(BaseException):
        """Abort a connection worker.

        Derives from BaseException, so the ``except Exception`` handlers in
        this class let it propagate untouched.
        """

    class Deactivate(Abort):
        """Drop connection ``conn`` in favour of another one to the same peer.

        Raised directly, or injected into another worker greenlet via
        ``kill()`` in activate_connection().
        """

        def __init__(self, message, conn):
            self.conn = conn
            super(Swarm.Deactivate, self).__init__(message)

    def __init__(self, handle, uid, parent=None):
        super(Swarm, self).__init__(logname='swm', parent=parent)

        self.handle = handle

        self.conns_halfopen = {}     # addr => conn, while connect_addr() runs
        self.conns_active = {}       # conn => peer (not referenced elsewhere in this class)

        self.uid = uid

        self.votenum = random.randint(0, 2 ** 32 - 1)

        self.peers = PeerCollection(parent=self)

        self.log.info('New swarm created')

        # Sliding windows of [unix_second, byte_count] buckets.
        self._recv_window = collections.deque(maxlen=self.SPEED_WINDOW)
        self._send_window = collections.deque(maxlen=self.SPEED_WINDOW)
        self._recv_window_payload = collections.deque(maxlen=self.SPEED_WINDOW)
        self._send_window_payload = collections.deque(maxlen=self.SPEED_WINDOW)

    def add_candidate(self, ip, port, weight):
        """Register ``ip:port`` with the given ``weight`` as a connect candidate."""
        addr = PeerAddress(ip, port)
        addr.weight = weight
        self.add_candidate_address(addr)

    def add_candidate_address(self, addr, defer=None):
        """Register a PeerAddress as a connect candidate (``defer`` is passed through)."""
        self.peers.add_address(addr, defer)

    def collect_peer(self, peer):
        """Hand ``peer`` to the peer collection's ``collect`` and return its result."""
        return self.peers.collect(peer)

    def get_connect_candidate(self):
        """Return the next address the handle should try to connect to.

        Calls ``peers.get_next(timeout=3)``. A known Peer yields its outgoing
        address, and a bare PeerAddress is returned as is. Returns None when
        nothing was available or the collection returned an unexpected object.
        """
        candidate = self.peers.get_next(timeout=3)

        if candidate is None:
            self.peers.log_colstats('nomore')  # regular colstats printing with 6-char status
            return None

        if isinstance(candidate, Peer):
            assert candidate.ou_address is not None
            return candidate.ou_address

        if isinstance(candidate, PeerAddress):
            return candidate

        self.log.error('Got invalid candidate from peers: %r', candidate)
        return None

    def connect_addr(self, addr, sock, payload=None):
        """Create a PeerConnection, then connect, handshake and activate it.

        :param addr: PeerAddress of the remote side. It is also the key in
            ``conns_halfopen`` while this call runs.
        :param sock: the accepted socket for an incoming connection, or a
            falsy value for an outgoing one (a new socket is then created and
            connected).
        :param payload: passed through to ``conn.read_handshake``.
        :returns: the Peer for the remote side, in state NO_SKYBIT,
            NO_RESOURCE or CONNECTED.
        :raises AlreadyInHalfOpen: ``addr`` is already being set up.
        :raises ConnectFailed: connect/start failed or timed out.
        :raises HandshakeFailed: the handshake timed out, returned no uid or
            failed unexpectedly, or activation timed out.
        :raises ConnectedToSelf: the remote world uid equals our own.
        :raises Deactivate: propagated from activate_connection() when another
            connection to the same peer wins.
        """
        if addr in self.conns_halfopen:
            self.log.debug('Do not connect address %s: already halfopen', addr)
            raise self.AlreadyInHalfOpen()

        outgoing = not sock
        connected = not outgoing
        if outgoing:
            sock = self.handle.world.make_socket(uds=bool(self.handle.world.uds_proxy))

        # Address recorded on the peer only for connections we initiated.
        peer_addr = addr if outgoing else None

        conn = PeerConnection(
            self, (addr.ip, addr.port),
            sock, connected,
            log=self.log.getChild('conn')
        )

        self.conns_halfopen[addr] = conn

        try:
            # Stage 1: connect (outgoing only) and start the connection.
            # NOTE(review): conn is not closed on the ConnectFailed paths below;
            # presumably the caller or GC releases it -- confirm.
            with gevent.Timeout(self.CONNECT_TIMEOUT) as tout:
                try:
                    if not connected:
                        conn.log.debug('Connecting...')
                        conn.connect()
                    conn.start()

                except gevent.Timeout as ex:
                    # Only our own timeout is handled here; an outer one must propagate.
                    if ex is not tout:
                        raise
                    conn.log.warning('Unable to connect: %s %s: timed out', addr.ip, addr.port)
                    raise self.ConnectFailed()

                except Exception as ex:
                    conn.log.warning('Unable to connect: %s %s: %s', addr.ip, addr.port, ex)
                    raise self.ConnectFailed()

            # Stage 2: exchange handshakes.
            try:
                with gevent.Timeout(self.HANDSHAKE_TIMEOUT) as tout:
                    try:
                        conn.write_handshake(
                            self.uid, self.handle.world.uid, self.handle.world.desc, dfs=self.handle.dfs
                        )
                        uid, wuid, wdesc = conn.read_handshake(payload)
                    except gevent.Timeout as ex:
                        if ex is not tout:
                            raise

                        conn.log.debug('timed out while reading handshake: %s', ex)
                        conn.close()
                        raise self.HandshakeFailed('Timed out')

                if not uid:
                    conn.log.debug('no uid after handshake')
                    conn.close()
                    raise self.HandshakeFailed('Got no uid')

            except self.ConnectFailed:
                # Our own HandshakeFailed from above: pass it through unchanged.
                raise

            except Exception as ex:
                # NOTE(review): conn is left open here -- confirm who closes it.
                conn.log.error('handshake: unhandled exception: %s', ex)
                raise self.HandshakeFailed('Got unhandled exception: %s' % (ex, ))

            if wuid == self.handle.world.uid:
                conn.log.error('handshake: connected to ourselves')
                # Close like every other post-handshake rejection path does.
                conn.close()
                raise self.ConnectedToSelf()

            if 'skybit' not in conn.capabilities:
                peer = self.get_peer(wuid, wdesc, peer_addr)
                peer.set_state(Peer.NO_SKYBIT)
                conn.close()
                self.collect_peer(peer)
                return peer

            # Stage 3: make this the single active connection to the peer.
            with gevent.Timeout(self.ACTIVATE_TIMEOUT) as tout:
                try:
                    peer = self.activate_connection(wuid, wdesc, conn, peer_addr)
                except gevent.Timeout as ex:
                    if ex is not tout:
                        raise
                    conn.log.error('handshake: connection activation timed out')
                    raise self.HandshakeFailed()
                except Exception as ex:
                    conn.log.error('handshake: unhandled exception: %s', ex)
                    raise

            if uid != self.uid:
                # Handshake uid differs from ours: keep the peer record, drop the connection.
                peer.set_state(Peer.NO_RESOURCE)
                conn.close()
                if peer.conn == conn:
                    peer.conn = None
            else:
                peer.set_state(Peer.CONNECTED)
            self.collect_peer(peer)

            return peer
        finally:
            self.conns_halfopen.pop(addr)

    def get_peer(self, uid, desc, outgoing_addr):
        """Return the Peer for ``uid``, creating and registering one if absent.

        A new Peer starts in CONNECTING state. When ``outgoing_addr`` is given,
        it is stored as the peer's outgoing address and its weight is copied.
        """
        peer = self.peers.get(uid)
        if not peer:
            peer = Peer(
                uid=uid, desc=desc, state=Peer.CONNECTING, state_ts=time.time(),
                ou_address=outgoing_addr
            )
            # forward address weight if it was set to peer
            # (e.g. seeders will have more weight if added later)
            if outgoing_addr is not None:
                peer.weight = outgoing_addr.weight
            self.peers.add(peer)
            self.collect_peer(peer)
        return peer

    def activate_connection(self, wuid, wdesc, conn, outgoing_addr):
        """Activate exactly one connection per remote peer.

        Both sides may connect to each other at the same time. One connection
        is chosen as follows:

        1. If the peer already has an active worker, this connection is
           deactivated immediately.
        2. The outgoing side picks a random vote number and sends it in an
           ACTIVATE message. The incoming side waits for that message, so
           both ends know the same number for the same connection.
        3. If another connection activated in the meantime, the vote numbers
           are compared. This connection wins only with a strictly higher
           number. The loser is deactivated.
        4. The remote peer runs the same procedure, so both ends keep one and
           the same connection.

        On return the peer is in ``self.peers``. ``peer.worker`` is the
        current greenlet and ``peer.conn`` is ``conn``.

        :raises Deactivate: a better connection exists, ACTIVATE could not be
            sent (broken pipe), or a non-ACTIVATE message was received.
        """
        peer = self.get_peer(wuid, wdesc, outgoing_addr)

        if peer.worker:
            conn.log.debug('Closing in favour of existing active worker')
            raise self.Deactivate('found better connection (existing, activated)', conn)

        peer.set_state(Peer.CONNECTING)

        # Forcibly collect peer, so we don't forget to do that in any "bad" case
        self.collect_peer(peer)

        # Proceed to the activation procedure: send ACTIVATE if we initiated the
        # connection, otherwise wait for the remote side's ACTIVATE.
        clog = conn.log
        clog.debug('Activating...')

        if outgoing_addr:
            conn.votenum = random.randint(1, 2 ** 32 - 1)
            try:
                conn.send_message('ACTIVATE', conn.votenum)
            except conn.EPIPE:
                raise self.Deactivate('Broken pipe during sending ACTIVATE msg', conn)
        else:
            msgtype, payload = conn.get_message()
            if msgtype != 'ACTIVATE':
                clog.warning('activate: expected ACTIVATE message, got %r', msgtype)
                raise self.Deactivate('Got invalid message: %r' % (msgtype, ), conn)
            conn.votenum = payload

        # Well, we have votenum for this connection now
        if not peer.worker:
            # Yay! We are first to create activated worker on this peer. Do it now
            peer.worker = gevent.getcurrent()
            peer.conn = conn
            conn.log.debug('activate: first conn, activated')
        else:
            # Oops, other connection created worker while we were sending/receiving activation
            # messages. Compare our votenum:

            assert peer.conn is not None

            while True:
                old_worker = peer.worker
                old_conn = peer.conn

                if conn.votenum > peer.conn.votenum:
                    # Next, we need to kill worker and wait until it dies
                    # Worker will catch this message and send either NO_ACTIVATE or DE_ACTIVATE
                    # message depending on its stage

                    # IMPORTANT: do not log anything before killing or we could be switched out and
                    # worker could change (if 3rd connection tries to activate in the same time)

                    # Old worker may be in the middle of receiving something, so first acquire no_interrupt lock
                    with old_conn.no_interrupt_rcv:
                        pass

                    # Now if we wakeup other worker greenlet it will not be in critical non-interruptable
                    # section

                    assert not old_conn.no_interrupt_rcv.locked()

                    old_worker.kill(self.Deactivate('found better connection (new)', conn))

                    if peer.worker == old_worker:
                        assert peer.conn is None

                        peer.worker = gevent.getcurrent()
                        peer.conn = conn
                        conn.log.debug('activate: activated (dropped old connection)')
                        break
                    # Otherwise another connection took over meanwhile: vote again against it.
                else:
                    # No luck. Other connection is better. So just kill ourselves.
                    raise self.Deactivate('found better connection (existing)', conn)

        return peer

    def count_sent_bytes(self, cnt, payload=False):
        """Account ``cnt`` sent bytes. Payload bytes also count toward the total."""
        if payload:
            self._apply_avg_count(cnt, self._send_window_payload)
        self._apply_avg_count(cnt, self._send_window)

    def count_recv_bytes(self, cnt, payload=False):
        """Account ``cnt`` received bytes. Payload bytes also count toward the total."""
        if payload:
            self._apply_avg_count(cnt, self._recv_window_payload)
        self._apply_avg_count(cnt, self._recv_window)

    def _apply_avg_count(self, cnt, window):
        """Add ``cnt`` to the current-second bucket of ``window``.

        Buckets are ``[unix_second, count]`` lists. When a new bucket is
        started, buckets older than SPEED_WINDOW seconds are dropped.
        """
        now = int(time.time())

        if window and window[-1][0] == now:
            window[-1][1] += cnt
            return

        window.append([now, cnt])
        while window[0][0] < now - self.SPEED_WINDOW:
            window.popleft()

    def _get_speed(self, window):
        """Return the average count per second over the buckets in ``window``.

        Stale buckets are evicted first. Each bucket covers one second, so the
        divisor is the inclusive number of seconds from the oldest to the
        newest bucket. One bucket of N and several adjacent buckets of N each
        therefore both report N/s. The newest bucket may still be filling, so
        the value can lag slightly low.
        """
        now = int(time.time())
        while window and window[0][0] < now - self.SPEED_WINDOW:
            window.popleft()

        if not window:
            return 0.0

        # Clamp to 1 so a wall-clock step backwards (buckets sharing or
        # reversing seconds) cannot give a zero or negative divisor.
        span = max(window[-1][0] - window[0][0] + 1, 1)
        return sum(count for _, count in window) / float(span)

    def get_dl_speed(self, payload=False):
        """Return recent download speed in bytes/second (payload-only if ``payload``)."""
        return self._get_speed(self._recv_window_payload if payload else self._recv_window)

    def get_ul_speed(self, payload=False):
        """Return recent upload speed in bytes/second (payload-only if ``payload``)."""
        return self._get_speed(self._send_window_payload if payload else self._send_window)
