from __future__ import print_function, division

import bisect
import gevent
import itertools
import logging
import msgpack
import threading
import random
import sys
import time

from gevent import socket
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network


# Process-wide switch. Tracker.__init__ sets it to True when blocked networks
# are configured (nothing in this file resets it to False). While it is True,
# HashInfo.get_round_robin_population skips peers whose ``blocked`` flag is set.
HAS_BLOCKED_NETS = False

class TrackerError(Exception):
    """ Base tracker error.

    Tracker.handle_packet re-raises TrackerError subclasses unchanged and
    wraps every other exception into MallformedPacket.
    """


class EntropyError(TrackerError):
    """ Unable to generate unique id (raised by PeerConnections.create) """


class MallformedPacket(TrackerError):
    """ Unable to parse incoming packet """


class TTLOutOfWindow(Exception):
    """ Expiry timestamp falls outside the range covered by a TTLWindow.

    NOTE: derives from Exception, not TrackerError, so when it escapes from
    Tracker.handle_packet it is reported wrapped in MallformedPacket.
    """


class TTLWindow(object):
    """
    Expiry index that groups objects into fixed-size time blocks.

    The window spans ``seconds`` seconds split into ``seconds // length``
    blocks of ``length`` seconds each. The sets in ``self.blocks`` are reused
    cyclically: absolute block number ``n`` (``timestamp // length``) lives in
    slot ``n % blocks_count``. ``self.start`` is the first absolute block that
    has not been expired yet and ``self.idx`` is its slot
    (``start % blocks_count``).
    """

    def __init__(self, seconds, length, now):
        """
        Arguments:
            seconds:    length of ttl window in seconds (multiple of ``length``)
            length:     length of one block in seconds (int)
            now:        current timestamp
        """
        assert seconds % length == 0
        assert isinstance(length, int)

        blocks = int(seconds // length)

        self.length = length
        self.blocks = tuple(set() for _ in range(blocks))  # all blocks
        self.blocks_count = blocks                         # count of blocks
        self.objs = {}                                     # obj => block (set) holding it

        self.start = int(now // length)                    # first non-expired absolute block
        self.idx = self.start % self.blocks_count          # slot of self.start

    def remove(self, obj):
        """Forget ``obj``; return True if it was tracked."""
        block = self.objs.pop(obj, None)
        if block is None:
            return False
        block.discard(obj)
        return True

    def update(self, obj, ttl):
        """
        (Re)schedule ``obj`` to expire at timestamp ``ttl``.

        Raises:
            TTLOutOfWindow: ``ttl`` falls before the first live block or at or
                past the end of the window.
        """
        block_idx = int(ttl) // self.length
        assert isinstance(block_idx, int)

        if block_idx < self.start:
            raise TTLOutOfWindow('Attempted to update block %d while we start at %d' % (
                block_idx, self.start
            ))

        if block_idx >= (self.start + self.blocks_count):
            raise TTLOutOfWindow('Attempted to update block %d, while we have %d' % (
                block_idx, self.start + self.blocks_count
            ))

        # Detach obj from its previous block *before* adding it to the new one.
        # Doing it afterwards discards the freshly added entry whenever old and
        # new block are the same set, and the object would never expire.
        self.remove(obj)

        block = self.blocks[block_idx % self.blocks_count]
        block.add(obj)
        self.objs[obj] = block

    def iterate_outdated(self, now):
        """
        Yield, and forget, every object whose block ended before ``now``.

        The generator must be consumed completely: ``start``/``idx`` only
        advance after the last outdated block has been processed. Do not call
        update()/remove() while consuming it, as they may mutate the set that
        is currently being iterated.
        """
        now_block = int(now // self.length)
        outdated_blocks = now_block - self.start

        if outdated_blocks <= 0:
            # Nothing expired yet, or the wall clock stepped backwards. Keep
            # start/idx as they are; moving start back without idx would make
            # later sweeps clear the wrong slots.
            return

        # Say, we have 10 outdated blocks, 4 total and 2 current:
        # we outdate slots 2, 3, 0 and 1. A slot is never visited twice, even
        # when more than a whole window has elapsed.
        for step in range(min(outdated_blocks, self.blocks_count)):
            block = self.blocks[(self.idx + step) % self.blocks_count]
            for obj in block:
                yield obj
                self.objs.pop(obj, None)
            block.clear()

        self.start = now_block
        self.idx = now_block % self.blocks_count


class Peer(object):
    """
    A tracker client identified by ``uid``.

    Every address is stored twice: as text (``ip``, ``ip6``, ``fbip``,
    ``fbip6``) and packed with ``socket.inet_pton`` (``ipp``, ``ip6p``,
    ``fbipp``, ``fbip6p``); ``get_state`` serialises the packed forms.
    ``ip``/``fbip`` hold one IPv4 address; the IPv6 fields hold either one
    address or a tuple of addresses. Missing addresses are ``None``.
    ``ip*`` are backbone ("bb") addresses, ``fbip*`` fastbone ("fb") ones.
    """
    TTL = 172800  # 48 hrs

    __slots__ = (
        'uid ttl ip ip6 fbip fbip6 ipp ip6p fbipp fbip6p blocked port '
        'has_fb has_bb seeds leechs dfs peer_types extended_announce_counters'.split()
    )

    def __init__(self, uid, ip, ip6, fbip, fbip6, blocked_nets, blocked_nets6, port, now):
        self.uid = uid
        self.update(ip, ip6, fbip, fbip6, blocked_nets, blocked_nets6, port, now)
        self.seeds = 0
        self.leechs = 0
        self.extended_announce_counters = False
        self.dfs = False
        self.peer_types = False

    # -- address helpers -------------------------------------------------

    @staticmethod
    def _pack_addr(family, addr, multi=False):
        """
        Return ``(text, packed)`` for ``addr``; ``(None, None)`` when empty.

        With ``multi`` a list/tuple of addresses is accepted and both results
        are tuples (used for the IPv6 fields only).
        """
        if not addr:
            return None, None
        if multi and isinstance(addr, (list, tuple)):
            return tuple(addr), tuple(socket.inet_pton(family, part) for part in addr)
        return addr, socket.inet_pton(family, addr)

    @staticmethod
    def _unpack_addr(family, packed, multi=False):
        """Inverse of ``_pack_addr``: return ``(text, packed)`` from packed bytes."""
        if not packed:
            return None, None
        if multi and isinstance(packed, (list, tuple)):
            return tuple(socket.inet_ntop(family, part) for part in packed), tuple(packed)
        return socket.inet_ntop(family, packed), packed

    @staticmethod
    def _as_tuple(addr):
        """Normalise an address field (None, one address or a tuple) to a tuple."""
        if not addr:
            return ()
        if type(addr) is tuple:
            return addr
        return (addr, )

    def _with_port(self, addr):
        """``[(address, port), ...]`` for a single address or a tuple of them."""
        return [(one, self.port) for one in self._as_tuple(addr)]

    def _refresh_flags(self, blocked_nets, blocked_nets6):
        """Recompute has_bb / has_fb / blocked from the current addresses."""
        self.has_bb = bool(self.ip or self.ip6)
        self.has_fb = bool(self.fbip or self.fbip6)
        self.blocked = self.is_blocked(blocked_nets, blocked_nets6)

    # -- public API -------------------------------------------------------

    def update_ttl(self, now):
        self.ttl = now + self.TTL

    def is_blocked(self, blocked_nets, blocked_nets6):
        """
        Return True if any backbone or fastbone address of this peer lies in
        one of ``blocked_nets`` (IPv4 networks) or ``blocked_nets6`` (IPv6
        networks).
        """
        if blocked_nets and (self.ip or self.fbip):
            addrs = [IPv4Address(unicode(addr)) for addr in (self.ip, self.fbip) if addr]
            if any(addr in net for net in blocked_nets for addr in addrs):
                return True

        if blocked_nets6 and (self.ip6 or self.fbip6):
            addrs6 = [
                IPv6Address(unicode(addr))
                for addr in self._as_tuple(self.ip6) + self._as_tuple(self.fbip6)
            ]
            if any(addr in net for net in blocked_nets6 for addr in addrs6):
                return True

        return False

    def update(self, ip, ip6, fbip, fbip6, blocked_nets, blocked_nets6, port, now):
        """Refresh ttl, addresses, derived flags and port from a connect request."""
        self.update_ttl(now)

        self.ip, self.ipp = self._pack_addr(socket.AF_INET, ip)
        self.ip6, self.ip6p = self._pack_addr(socket.AF_INET6, ip6, multi=True)
        self.fbip, self.fbipp = self._pack_addr(socket.AF_INET, fbip)
        self.fbip6, self.fbip6p = self._pack_addr(socket.AF_INET6, fbip6, multi=True)

        self._refresh_flags(blocked_nets, blocked_nets6)

        self.port = port

    def set_extensions(self, extensions):
        """Reset the protocol extension flags to those named in ``extensions``."""
        self.extended_announce_counters = False
        self.peer_types = False
        self.dfs = False

        for ext in extensions:
            if ext == 'extended_announce_counters':
                self.extended_announce_counters = True
            elif ext == 'peer_types':
                self.peer_types = True
            elif ext == 'dfs':
                self.dfs = True

    @classmethod
    def from_state(cls, state, blocked_nets, blocked_nets6):
        """
        Rebuild a Peer from a ``get_state()`` tuple.

        Accepted layouts: the legacy unversioned 7-tuple
        ``(uid, ttl, ipp, ip6p, fbipp, fbip6p, port)``, version 1 (adds
        ``extended_announce_counters``) and version 2 (also ``dfs`` and
        ``peer_types``). Swarm counters (seeds/leechs) start at zero.

        Raises:
            Exception: unknown state version.
        """
        dfs = peer_types = extended_announce_counters = False

        if len(state) == 7:
            uid, ttl, ipp, ip6p, fbipp, fbip6p, port = state
        elif state[0] == 1:
            _, uid, ttl, ipp, ip6p, fbipp, fbip6p, port, extended_announce_counters = state
        elif state[0] == 2:
            _, uid, ttl, ipp, ip6p, fbipp, fbip6p, port, dfs, peer_types, extended_announce_counters = state
        else:
            # Format the message here: passing ``state`` as a second
            # Exception argument left the %r placeholder unfilled.
            raise Exception('Unknown state version: %r' % (state, ))

        peer = cls.__new__(cls)
        peer.uid = uid
        peer.ttl = ttl
        peer.port = port
        peer.seeds = 0
        peer.leechs = 0
        peer.dfs = dfs
        peer.peer_types = peer_types
        peer.extended_announce_counters = extended_announce_counters

        peer.ip, peer.ipp = cls._unpack_addr(socket.AF_INET, ipp)
        peer.ip6, peer.ip6p = cls._unpack_addr(socket.AF_INET6, ip6p, multi=True)
        peer.fbip, peer.fbipp = cls._unpack_addr(socket.AF_INET, fbipp)
        peer.fbip6, peer.fbip6p = cls._unpack_addr(socket.AF_INET6, fbip6p, multi=True)

        # Same derivation as update(): has_bb / has_fb are real booleans.
        peer._refresh_flags(blocked_nets, blocked_nets6)

        return peer

    def ips(self, typ, allow_v6):
        """
        Return ``[(address, port), ...]`` to advertise for this peer.

        ``typ`` is 'bb' (backbone) or 'fb' (fastbone); IPv6 addresses are
        preferred when ``allow_v6``. A 'dfs' peer without the requested
        fastbone address falls back to its backbone address.
        """
        if typ == 'bb':
            if allow_v6 and self.ip6:
                return self._with_port(self.ip6)
            if self.ip:
                return self._with_port(self.ip)

        elif typ == 'fb':
            if allow_v6:
                if self.fbip6:
                    return self._with_port(self.fbip6)
                if self.dfs and self.ip6:
                    # dfs peer with NO fb6 but with bb6 -- return bb6 instead
                    return self._with_port(self.ip6)
                if self.fbip:
                    return self._with_port(self.fbip)
                # NOTE(review): unlike the allow_v6=False branch, a dfs peer
                # with only a bb IPv4 address returns nothing here -- confirm
                # this asymmetry is intended.
            else:
                if self.fbip:
                    return self._with_port(self.fbip)
                if self.dfs and self.ip:
                    # dfs peer with NO fb but with bb -- return bb instead
                    return self._with_port(self.ip)

        return []

    def __repr__(self):
        return '<Peer %s>' % (self.uid.encode('hex'), )

    def get_state(self):
        """Serialise as a version-2 state tuple (see from_state)."""
        return (
            2, self.uid, self.ttl,
            self.ipp, self.ip6p, self.fbipp, self.fbip6p,
            self.port,
            self.dfs, self.peer_types,
            self.extended_announce_counters
        )


class PeerConnection(object):
    """
    A connection id handed out to a Peer by a v1/v2 connect request.

    ``ttl`` is an absolute expiry timestamp; every ``update(now)`` moves it to
    ``now + TTL``.
    """
    TTL = 300

    __slots__ = ('cid', 'peer', 'ttl')

    def __init__(self, cid, peer, now):
        self.cid = cid
        self.peer = peer
        self.update(now)

    @classmethod
    def from_state(cls, state, peers):
        """
        Rebuild a connection from a ``(cid, peer_uid, ttl)`` tuple.

        Returns None when ``peer_uid`` is not registered in ``peers.peers``.
        """
        cid, peer_uid, ttl = state

        owner = peers.peers.get(peer_uid)
        if not owner:
            return None

        restored = cls.__new__(cls)
        restored.cid, restored.peer, restored.ttl = cid, owner, ttl
        return restored

    def update(self, now):
        """Extend the connection lifetime to ``now + TTL``."""
        self.ttl = now + self.TTL

    def get_state(self):
        """Serialise as ``(cid, peer_uid, ttl)``."""
        return self.cid, self.peer.uid, self.ttl


class PeerConnections(object):
    """
    Live connections indexed both by connection id (``by_cid``) and by Peer
    object (``by_peer``).
    """
    __slots__ = 'by_cid', 'by_peer', 'log'

    def __init__(self, logger):
        self.by_cid = {}
        self.by_peer = {}
        self.log = logger

    def create(self, peer, now):
        """
        Return the connection of ``peer``, creating it if needed, and refresh
        its ttl.

        Connection ids are derived deterministically from the peer uid (its
        hex digits read as an integer), so a peer always gets the same cid.

        Raises:
            EntropyError: the derived cid is owned by a peer with another uid.
        """
        conn = self.by_peer.get(peer)
        if conn is not None:
            conn.update(now)
            return conn

        cid = int(peer.uid.encode('hex'), 16)

        stale = self.by_cid.get(cid)
        if stale is not None:
            if stale.peer.uid != peer.uid:
                raise EntropyError('Unable to generate random connection id')
            # Same uid, different Peer object: the earlier Peer is no longer
            # registered (Peers.add_raw only builds a new Peer for an unknown
            # uid) but its connection is still alive. Rebind the cid to the
            # current object instead of refusing the reconnect.
            self.remove(stale)

        conn = PeerConnection(cid, peer, now)
        self.by_cid[cid] = conn
        self.by_peer[peer] = conn
        return conn

    def remove(self, conn):
        """Forget ``conn`` in both indexes (missing entries are ignored)."""
        self.by_cid.pop(conn.cid, None)
        self.by_peer.pop(conn.peer, None)

    def clean(self, now):
        """Drop every connection whose ttl is before ``now``."""
        drop = [conn for conn in self.by_cid.itervalues() if conn.ttl < now]

        for conn in drop:
            self.remove(conn)

        if drop:
            self.log.debug('Cleaned %d connections (%d left)', len(drop), len(self.by_cid))

    def get_state(self):
        """Serialise every connection (see PeerConnection.get_state)."""
        return [conn.get_state() for conn in self.by_cid.itervalues()]

    def set_state(self, state, peers):
        """Restore connections; entries whose peer is unknown are skipped."""
        for conn_state in state:
            conn = PeerConnection.from_state(conn_state, peers)
            if not conn:
                continue
            self.by_cid[conn.cid] = conn
            self.by_peer[conn.peer] = conn


class Peers(object):
    """
    Registry of Peer objects keyed by ``uid``, plus the parsed blocked
    networks each Peer is checked against.
    """
    __slots__ = 'peers', 'blocked_nets', 'blocked_nets6', 'log'

    def __init__(self, blocked_nets, logger):
        self.peers = {}
        self.blocked_nets = []
        self.blocked_nets6 = []
        for spec in blocked_nets:
            # A colon only appears in IPv6 notation.
            if ':' in spec:
                target, network_cls = self.blocked_nets6, IPv6Network
            else:
                target, network_cls = self.blocked_nets, IPv4Network
            target.append(network_cls(unicode(spec)))
        self.log = logger

    def add(self, peer):
        """Register ``peer`` under its uid (replacing any previous one); return it."""
        self.peers[peer.uid] = peer
        return peer

    def add_raw(self, uid, ip, ip6, fbip, fbip6, port, now, extensions):
        """Create or refresh peer ``uid`` from connect-request fields; return it."""
        existing = self.peers.get(uid)
        if existing is not None:
            existing.update(ip, ip6, fbip, fbip6, self.blocked_nets, self.blocked_nets6, port, now)
            existing.set_extensions(extensions)
            return existing

        fresh = Peer(uid, ip, ip6, fbip, fbip6, self.blocked_nets, self.blocked_nets6, port, now)
        fresh.set_extensions(extensions)
        return self.add(fresh)

    def remove(self, peer):
        """Unregister ``peer`` by uid; return False if the uid was unknown."""
        if peer.uid not in self.peers:
            return False
        del self.peers[peer.uid]
        return True

    def clean(self, now):
        """
        Drop expired peers (ttl before ``now``) and idle ones: ttl last
        refreshed more than 300 seconds ago and no seeding/leeching hash.
        """
        expired, idle = [], []
        for peer in self.peers.itervalues():
            if peer.ttl < now:
                expired.append(peer)
                continue
            last_refresh = peer.ttl - Peer.TTL
            if last_refresh + 300 < now and not (peer.seeds or peer.leechs):
                idle.append(peer)

        def drop_all(batch):
            removed = 0
            for victim in batch:
                if self.remove(victim):
                    removed += 1
            return removed, len(batch) - removed

        dropped, missing_expired = drop_all(expired)
        qdropped, missing_idle = drop_all(idle)

        if expired or idle:
            self.log.debug(
                'Cleaned %d peers and %d quick (%d missing, %d left)',
                dropped, qdropped, missing_expired + missing_idle, len(self.peers)
            )

    def get_state(self):
        """Serialise every peer (see Peer.get_state)."""
        return [peer.get_state() for peer in self.peers.itervalues()]

    def set_state(self, state):
        """Restore peers from serialised states; return how many are blocked."""
        blocked = 0
        for peer_state in state:
            restored = Peer.from_state(peer_state, self.blocked_nets, self.blocked_nets6)
            self.add(restored)
            if restored.blocked:
                blocked += 1
        return blocked


class HashState(object):
    """Announce states a peer can report for a hash."""

    LEECHING, SEEDING, STOPPED = 1, 2, 3


class HashInfo(object):
    """
    Swarm of a single hash: the peers seeding it and the peers leeching it.

    Both lists are kept sorted via ``bisect`` so lookups are binary searches;
    this relies on the peers' natural ordering (Peer defines no comparison
    methods, i.e. Python 2 default object ordering). Membership in a list is
    mirrored in the peer's ``seeds`` / ``leechs`` counters. Each list has a
    round-robin cursor so consecutive announces hand out different peers.
    """
    __slots__ = 'hashdata peers_seeding peers_leeching ptr_seeding ptr_leeching last_leecher_ts'.split()

    def __init__(self, hashdata):
        self.hashdata = hashdata  # key under which Hashes stores this swarm
        self.peers_seeding = []
        self.peers_leeching = []
        self.ptr_seeding = 0
        self.ptr_leeching = 0
        self.last_leecher_ts = 0  # timestamp we added/updated last leecher

    def add_peer(self, peer, state):
        """Move ``peer`` into the list matching ``state``; other states are ignored."""
        if state == HashState.LEECHING:
            if self._bisect_add(self.peers_leeching, peer):
                peer.leechs += 1
            if self._bisect_remove(self.peers_seeding, peer):
                peer.seeds -= 1

            # Once we update peers_leeching struct with new peer - update last access time
            self.last_leecher_ts = int(time.time())
        elif state == HashState.SEEDING:
            if self._bisect_add(self.peers_seeding, peer):
                peer.seeds += 1
            if self._bisect_remove(self.peers_leeching, peer):
                peer.leechs -= 1

    def _bisect_remove(self, lst, item):
        """Remove ``item`` from sorted ``lst``; return True if it was present."""
        idx = bisect.bisect_left(lst, item)
        if len(lst) > idx and lst[idx] == item:
            del lst[idx]
            return True
        return False

    def _bisect_add(self, lst, item):
        """Insert ``item`` into sorted ``lst`` unless present; return True if inserted."""
        idx = bisect.bisect_left(lst, item)
        if len(lst) == idx or lst[idx] != item:
            lst.insert(idx, item)
            return True
        return False

    def remove_peer(self, peer):
        """Drop ``peer`` from both lists, keeping its counters in sync."""
        if self._bisect_remove(self.peers_seeding, peer):
            peer.seeds -= 1
        if self._bisect_remove(self.peers_leeching, peer):
            peer.leechs -= 1

    def get_round_robin_population(self, population, k, ptr):
        """
        Return ``(chosen, new_ptr)``: ``k`` items of ``population`` starting at
        index ``ptr`` and wrapping around to the beginning.

        When the module flag HAS_BLOCKED_NETS is set, blocked peers are
        filtered out first. If ``k`` exceeds the (filtered) population, the
        whole population is returned and the cursor resets to 0.
        """
        if HAS_BLOCKED_NETS:
            population = [peer for peer in population if not peer.blocked]

        population_len = len(population)
        if k > population_len:
            return population, 0

        if isinstance(population, set):
            population = list(population)

        if ptr > population_len:
            # This could happen if our population reduced in size and pointer
            # does not fit it. Crop pointer to 0 in that case.
            ptr = 0

        newptr = ptr + k

        result = population[ptr:newptr]
        if newptr >= population_len:
            newptr -= population_len
            result += population[0:newptr]

        return result, newptr

    def get_announce_peers(self, peer, num_seeders, num_leechers):
        """
        Pick peers to return to ``peer``: ``(seeders, leechers)`` sets, with
        ``peer`` itself excluded. Missing leechers are compensated with extra
        seeders.
        """
        # return more seeders if we have no enough leechers
        if num_leechers > len(self.peers_leeching):
            num_seeders += num_leechers - len(self.peers_leeching)

        return_seeders, self.ptr_seeding = self.get_round_robin_population(
            self.peers_seeding, num_seeders, self.ptr_seeding
        )
        return_leechers, self.ptr_leeching = self.get_round_robin_population(
            self.peers_leeching, num_leechers, self.ptr_leeching
        )

        seeders = set(return_seeders)
        leechers = set(return_leechers)

        seeders.discard(peer)
        leechers.discard(peer)

        return seeders, leechers

    def get_state(self):
        """Serialise as ``(raw_hash, seeder_uids, leecher_uids, last_leecher_ts)``."""
        return (
            self.hashdata.decode('hex'),
            tuple(p.uid for p in self.peers_seeding),
            tuple(p.uid for p in self.peers_leeching),
            self.last_leecher_ts
        )

    @classmethod
    def from_state(cls, state, peers):
        """
        Rebuild a swarm from ``get_state()`` output; uids unknown to
        ``peers.peers`` are skipped.
        """
        hashinfo = cls(state[0].encode('hex'))
        for uid in state[1]:
            peer = peers.peers.get(uid, None)
            # Count only real insertions: a duplicated uid in the state must
            # not inflate the peer's counter.
            if peer and hashinfo._bisect_add(hashinfo.peers_seeding, peer):
                peer.seeds += 1
        for uid in state[2]:
            peer = peers.peers.get(uid, None)
            if peer and hashinfo._bisect_add(hashinfo.peers_leeching, peer):
                peer.leechs += 1
        if len(state) >= 4:
            hashinfo.last_leecher_ts = state[3]
        return hashinfo


class Hashes(object):
    """Registry of HashInfo swarms keyed by ``hashdata``."""

    def __init__(self):
        self.hashes = {}

    def update(self, peer, hashdata, state):
        """
        Apply an announce of ``peer`` with ``state`` to swarm ``hashdata``.

        Returns the affected HashInfo, or None for a STOPPED announce on an
        unknown hash. A swarm that ends up empty is not kept in the registry.
        """
        if state == HashState.STOPPED:
            hashinfo = self.hashes.get(hashdata)
            if hashinfo is not None:
                self.remove(hashinfo, peer)
            return hashinfo

        hashinfo = self.hashes.get(hashdata)
        if hashinfo is None:
            hashinfo = self.hashes[hashdata] = HashInfo(hashdata)

        hashinfo.add_peer(peer, state)

        if not (hashinfo.peers_seeding or hashinfo.peers_leeching):
            # add_peer() only acts on LEECHING / SEEDING; any other
            # (client-supplied) state would otherwise leave an empty swarm
            # registered forever.
            self._forget(hashinfo)

        return hashinfo

    def _forget(self, hashinfo):
        """Unregister ``hashinfo`` if it is the swarm stored under its key."""
        if self.hashes.get(hashinfo.hashdata) is hashinfo:
            del self.hashes[hashinfo.hashdata]

    def remove(self, hashinfo, peer):
        """Unregister ``peer`` from ``hashinfo``; drop the swarm once it is empty."""
        hashinfo.remove_peer(peer)
        if not (hashinfo.peers_seeding or hashinfo.peers_leeching):
            # Identity check: never delete a different swarm that happens to
            # be registered under the same key.
            self._forget(hashinfo)

    def get_state(self):
        """Serialise every swarm (see HashInfo.get_state)."""
        return [hashinfo.get_state() for hashinfo in self.hashes.itervalues()]

    def set_state(self, state, peers):
        """Restore swarms; return ``(loaded, skipped_empty)`` counts."""
        loaded, skipped = 0, 0
        for hashinfo_state in state:
            hashinfo = HashInfo.from_state(hashinfo_state, peers)
            if hashinfo.peers_leeching or hashinfo.peers_seeding:
                self.hashes[hashinfo.hashdata] = hashinfo
                loaded += 1
            else:
                skipped += 1

        return loaded, skipped


class Tracker(object):
    """
        Connect:
        REQ: (cid, action, tid, uid, ip4, ip6, fbip4, fbip6, port)  # action = 0
        REP: (tid, action, cid)                                     # action = 0

        Announce:
        REQ: (cid, action, tid, state, hash, net)                   # action = 10
        REP: (tid, action, interval, seeders, leechers, (peers))    # action = 10, for LEECHING
        REP: (tid, action, interval)                                # action = 10, for SEEDING
        REP: (tid, action)                                          # action = 10, for STOPPED

        States: start(1), stop(2), leeching(3), seeding(4)
    """

    INITIAL_CID = 1736820976
    INITIAL_CID2 = 1964596860  # supports extended announce response with fb peers counting
    INITIAL_CID3 = 4079332072  # protocol version 3 :)

    ACTION_CONNECT = 0
    ACTION_ANNOUNCE = 10
    ACTION_STARTUP = 20
    ACTION_SHUTDOWN = 30

    RESULT_OK = 0
    RESULT_ERR = 1

    NET_AUTO = 0
    NET_BB_ONLY = 1
    NET_FB_ONLY = 2

    CLEAN_PEERS_EVERY = 1800
    CLEAN_PEERS_CONNS_EVERY = 600

    def __init__(self, interval, interval_leech, return_peers, return_seeders_ratio, blocked_nets, logger):
        """
        Arguments:
            interval:              (min, max) announce interval in seconds; the
                                   interval sent to seeders is min plus a random
                                   offset below max - min
            interval_leech:        (min, max) pair; announce() sends
                                   interval_leech[0] to leechers
                                   (interval_leech_rnd is computed but not used
                                   by announce)
            return_peers:          number of peers returned per announce
            return_seeders_ratio:  fraction of return_peers taken from seeders
            blocked_nets:          iterable of IPv4/IPv6 network strings; peers
                                   inside them are filtered out of announce
                                   peer lists
            logger:                logger; child loggers are made via getChild
        """
        assert isinstance(interval, (list, tuple))
        assert len(interval) == 2

        self.msg_packer = msgpack.Packer()

        self.interval = interval[0]
        self.interval_rnd = interval[1] - interval[0]
        self.interval_leech = interval_leech[0]
        self.interval_leech_rnd = interval_leech[1] - interval_leech[0]

        self.return_peers = return_peers                                 # e.g. 25
        self.return_seeders = int(return_seeders_ratio * return_peers)   # e.g. 25 * 0.12 == 3
        self.return_leechers = self.return_peers - self.return_seeders   # e.g. 25 - 3 == 22

        # Expiry index of (hashdata, peer uid) registrations: a window of
        # interval[1] + 120 seconds split into 60-second blocks.
        # NOTE(review): leecher TTLs use interval_leech[0]; a leech interval
        # much larger than interval[1] would make TTLWindow.update() raise
        # TTLOutOfWindow -- confirm the configuration keeps it smaller.
        self.ttl = TTLWindow(interval[1] + 120, 60, int(time.time()))

        self.log = logger
        self.peers = Peers(blocked_nets=blocked_nets, logger=self.log.getChild('peers'))
        self.conns = PeerConnections(logger=self.log.getChild('peerconns'))
        self.hashes = Hashes()

        # Time of the last TTL sweep in announce(); 0 makes the first
        # announce sweep immediately.
        self.ttl_clean_ts = 0

        self._stats = [0, 0, 0, 0]  # pings, connects, announces, invalid cid

        # Cached integer wall-clock time, refreshed once per second by a
        # daemon thread running _now_maker.
        self.now = int(time.time())
        # self._now_maker_grn = gevent.spawn(self._now_maker)
        self._now_maker_thr = threading.Thread(target=self._now_maker)
        self._now_maker_thr.daemon = True
        self._now_maker_thr.start()

        self.peers_clean_ts = self.now
        self.conns_clean_ts = self.now

        self.blocked_nets = blocked_nets
        if self.blocked_nets:
            # Process-wide flag read by HashInfo.get_round_robin_population;
            # it is never reset to False here.
            global HAS_BLOCKED_NETS
            HAS_BLOCKED_NETS = True

    def _now_maker(self):
        """
        Keep ``self.now`` (integer wall-clock seconds) fresh, forever.

        Runs in the daemon thread started by __init__; handle_packet reads the
        cached ``self.now`` instead of calling time.time() per packet.
        """
        while True:
            self.now = int(time.time())
            # Plain time.sleep: this loop runs in a threading.Thread (the
            # gevent greenlet variant is commented out in __init__).
            time.sleep(1)

    def get_ips_for_peer(self, peer, net, peers, is_seeders, peer_types):
        ret = []

        ips_type, allow_v6 = None, True

        # We prefer v4 only if peer has v4 address and no v6
        # In all other cases (has both, has none, has only v6) -- we prefer v6

        if net == self.NET_BB_ONLY or (net == self.NET_AUTO and not peer.has_fb):
            # Choose backbone peers
            ips_type = 'bb'

            if peer.ip and not peer.ip6:
                # Prefer ipv4 if peer has known v4 address and no v6
                allow_v6 = False

        elif net == self.NET_FB_ONLY or (net == self.NET_AUTO and peer.has_fb):
            # Choose fastbone peers
            ips_type = 'fb'

            if peer.fbip and not peer.fbip6:
                # Prefer ipv4 if peer has known v4 address and no v6
                allow_v6 = False

        else:
            return ret

        for p in peers:
            ips = p.ips(ips_type, allow_v6)
            if not peer_types:
                ret.extend(ips)
            else:
                ret.extend({'ips': ip, 'dfs': p.dfs, 'seeder': is_seeders} for ip in ips)

        return ret

    def clean_outdated(self, now):
        for hashdata, peeruid in self.ttl.iterate_outdated(now):
            try:
                hashinfo = self.hashes.hashes[hashdata]
                peer = self.peers.peers[peeruid]
            except KeyError:
                continue
            self.hashes.remove(hashinfo, peer)

    def announce(self, tid, peer, state, hash, net, now, v3=False):
        """
        Apply one announce of ``peer`` for ``hash`` and build the reply tuple.

        Arguments:
            tid:    client transaction id, echoed as the first reply element
            peer:   announcing Peer
            state:  HashState.LEECHING / SEEDING / STOPPED
            hash:   swarm key; STOPPED with hash 'CLEAN' is acknowledged only
            net:    NET_AUTO / NET_BB_ONLY / NET_FB_ONLY (see get_ips_for_peer)
            now:    current timestamp
            v3:     protocol v3 request: refreshes the peer ttl and answers
                    with RESULT_OK instead of ACTION_ANNOUNCE

        Replies:
            LEECHING: ``(tid, code, interval, n_seeders, n_leechers, peers)``,
                      with ``('peer_types',)`` inserted after interval for
                      peer_types peers, or the returned seeder/leecher counts
                      inserted before peers for extended_announce_counters
                      peers (v2 only)
            SEEDING:  ``(tid, code, interval)``
            STOPPED:  ``(tid, code)``

        Raises:
            MallformedPacket: ``state`` is not a known HashState.
        """
        if v3:
            peer.update_ttl(now)

        if state == HashState.STOPPED and hash == 'CLEAN':
            self.log.info('Peer %s: clean all', peer.uid.encode('hex'))

            # Bulk removal of all of the peer's registrations is deliberately
            # disabled (https://st.yandex-team.ru/SKYDEV-2322); the request is
            # only acknowledged. Disabled code:
            #
            # for hashinfo in self.hashes.hashes.values():
            #     ttl_obj = (hashinfo.hashdata, peer.uid)
            #     self.ttl.remove(ttl_obj)
            #     self.hashes.remove(hashinfo, peer)

            return tid, self.RESULT_OK

        # Expire outdated registrations *before* applying this announce. Doing
        # it afterwards removed a peer that re-announced just after its
        # previous entry expired, right after it had been re-added, leaving it
        # out of the swarm for a whole interval.
        if now - self.ttl_clean_ts > 1:
            self.clean_outdated(now)
            self.ttl_clean_ts = now

        hashinfo = self.hashes.update(peer, hash, state)

        # TODO: put conns/peers clean times into config
        # These sweeps stay after hashes.update(), so a peer whose
        # registrations just expired counts as seeding/leeching again before
        # Peers.clean looks for idle peers.
        if self.peers_clean_ts + self.CLEAN_PEERS_EVERY < now:
            self.peers.clean(now)
            self.peers_clean_ts = now

        if self.conns_clean_ts + self.CLEAN_PEERS_CONNS_EVERY < now:
            self.conns.clean(now)
            self.conns_clean_ts = now

        if self.interval_rnd:
            interval = self.interval + int(random.random() * self.interval_rnd)
        else:
            interval = self.interval

        if hashinfo is None:
            # Hashes.update() returns None only for STOPPED on an unknown hash.
            assert state == HashState.STOPPED, 'Got empty hashinfo!'
            ttl_obj = None
        else:
            ttl_obj = (hashinfo.hashdata, peer.uid)

        if state == HashState.LEECHING:
            # Leechers get the fixed leech interval, not the randomised one.
            interval = self.interval_leech
            n_leechers = len(hashinfo.peers_leeching)
            n_seeders = len(hashinfo.peers_seeding)

            seeders, leechers = hashinfo.get_announce_peers(peer, self.return_seeders, self.return_leechers)

            grab_peer_types = peer.peer_types or v3  # grab peer type if asked to or always for proto v3

            seed_peers = self.get_ips_for_peer(peer, net, seeders, is_seeders=True, peer_types=grab_peer_types)
            leech_peers = self.get_ips_for_peer(peer, net, leechers, is_seeders=False, peer_types=grab_peer_types)

            return_peers = seed_peers + leech_peers

            self.ttl.update(ttl_obj, now + interval)

            if v3:
                return (
                    tid, self.RESULT_OK, interval,
                    n_seeders, n_leechers,
                    return_peers
                )
            elif peer.peer_types:
                return (
                    tid, self.ACTION_ANNOUNCE, interval,
                    ('peer_types', ),
                    n_seeders, n_leechers,
                    return_peers
                )
            elif peer.extended_announce_counters:
                nr_seeders = len(seed_peers)
                nr_leechers = len(leech_peers)

                return (
                    tid, self.ACTION_ANNOUNCE, interval,
                    n_seeders, n_leechers,
                    nr_seeders, nr_leechers,
                    return_peers
                )
            else:
                return (
                    tid, self.ACTION_ANNOUNCE, interval,
                    n_seeders, n_leechers,
                    return_peers
                )
        elif state == HashState.SEEDING:
            self.ttl.update(ttl_obj, now + interval)
            if v3:
                return tid, self.RESULT_OK, interval
            else:
                return tid, self.ACTION_ANNOUNCE, interval
        elif state == HashState.STOPPED:
            if hashinfo is not None:
                # Hashes.update() has already unregistered the peer from the
                # swarm; only its expiry entry is left.
                self.ttl.remove(ttl_obj)

            if v3:
                return tid, self.RESULT_OK
            else:
                return tid, self.ACTION_ANNOUNCE
        else:
            # Explicit error instead of ``assert``: the state comes from the
            # client and asserts vanish under ``python -O``.
            raise MallformedPacket('Invalid state %r' % (state, ))

    def _msgpack_load(self, data):
        return msgpack.loads(data)

    def _msgpack_pack(self, data):
        return self.msg_packer.pack(data)

    def hash_info(self, hash, req):
        hashstr = hash.encode('hex')

        hashinfo = self.hashes.hashes.get(hashstr, None)
        if hashinfo is None:
            return {
                'v': 1,
                'req': req,
                'info': {}
            }

        return {
            'v': 1,
            'req': req,
            'info': {
                hash: {
                    'leechers': len(hashinfo.peers_leeching),
                    'seeders': len(hashinfo.peers_seeding),
                    'atime': hashinfo.last_leecher_ts
                }
            }
        }

    def handle_packet(self, data, ip, port):
        stats = self._stats

        pip = ip

        try:
            now = self.now
            msg = data

            cid = msg[0]
            action = msg[1]

            if cid == 'PI' and action == 'NG':
                stats[0] += 1
                return 'PONG'

            if cid == 'IN' and action == 'FO':
                return self.hash_info(msg[2], msg[3])

            if cid in (self.INITIAL_CID, self.INITIAL_CID2):
                stats[1] += 1
                assert action == self.ACTION_CONNECT, (
                    'Action should be 0 with initial conn packet, we got %r' % (action, )
                )
                tid, uid, ip, ip6, fbip, fbip6, port = msg[2:9]
                if cid == self.INITIAL_CID2:
                    extensions = msg[9]
                else:
                    extensions = []

                if not ip and not ip6:
                    if ':' in pip:
                        if pip.startswith('::ffff:'):
                            ip = pip[7:]
                        else:
                            ip6 = pip
                    else:
                        ip = pip

                peer = self.peers.add_raw(uid, ip, ip6, fbip, fbip6, port, now, extensions)
                conn = self.conns.create(peer, now)
                return tid, action, conn.cid

            elif cid in (self.INITIAL_CID3, ):
                # Full proto v3 here
                if action == self.ACTION_CONNECT:
                    stats[1] += 1

                    tid, pid, (ip, ip6, fbip, fbip6), port = msg[2:6]
                    extensions = msg[6]

                    if not ip and not ip6:
                        if ':' in pip:
                            if pip.startswith('::ffff:'):
                                ip = pip[7:]
                            else:
                                ip6 = pip
                        else:
                            ip = pip

                    peer = self.peers.add_raw(pid, ip, ip6, fbip, fbip6, port, now, extensions)
                    return tid, self.RESULT_OK, 3600
                elif action == self.ACTION_ANNOUNCE:
                    stats[2] += 1
                    tid, pid, hash, state, net = msg[2:]

                    if pid in self.peers.peers:
                        peer = self.peers.peers[pid]
                    else:
                        return tid, self.RESULT_ERR, 'Unknown peer'

                    return self.announce(tid, peer, state, hash, net, now, v3=True)
                else:
                    self.log.warning('Invalid action %r, ignoring', action)
                    tid = 0
                    return tid, self.RESULT_ERR, 'Invalid action %r' % (action, )

            elif action == self.ACTION_ANNOUNCE:
                try:
                    conn = self.conns.by_cid[cid]
                except KeyError:
                    stats[3] += 1
                    return

                stats[2] += 1

                conn.update(now)

                tid, state, hash, net = msg[2:]
                return self.announce(tid, conn.peer, state, hash, net, now)
            else:
                self.log.warning('Invalid action %r, ignoring', action)
                return

        except Exception as ex:
            self.log.warning('Failed to handle packet from %s:%s' % (pip, port), exc_info=ex)
            if not isinstance(ex, TrackerError):
                ei = sys.exc_info()
                raise MallformedPacket, '%s: %s' % (type(ex), str(ex)), ei[2]  # noqa
            raise

    def get_state(self):
        """Snapshot the tracker into a plain dictionary.

        The snapshot holds a ``version`` marker (1) plus the ``peers``,
        ``connections`` and ``hashes`` sections. Each section is produced by
        the ``get_state()`` of the corresponding container. Elapsed time and
        section sizes are logged at INFO level.
        """
        self.log.debug('Generating state...')

        started = time.time()

        snapshot = dict(
            version=1,
            peers=self.peers.get_state(),
            connections=self.conns.get_state(),
            hashes=self.hashes.get_state(),
        )

        elapsed = time.time() - started
        self.log.info(
            'Generated state in %0.4fs: %d peers, %d connections, %d hashes',
            elapsed,
            len(snapshot['peers']),
            len(snapshot['connections']),
            len(snapshot['hashes'])
        )

        return snapshot

    def set_state(self, state):
        """Restore tracker state previously produced by ``get_state()``.

        Arguments:
            state:  dictionary with ``version`` == 1 and the ``peers``,
                    ``connections`` and ``hashes`` sections

        Returns True when the state was loaded. Returns False when it was
        rejected (not a dict, unknown version, or a required section missing);
        in that case nothing is restored.

        After the containers are restored, ``self.ttl.iterate_outdated(now)``
        is drained. Then every peer listed on every restored hash gets a TTL
        entry ``(hashdata, peer uid)`` expiring at
        ``now + interval (+ interval_rnd)``.
        """
        self.log.debug('Restoring state...')

        if not isinstance(state, dict):
            self.log.warning('State should be a dictionary')
            return False

        if state.get('version', None) != 1:
            self.log.warning('Dont know how to load state version %r', state.get('version', None))
            return False

        # Validate up front so a truncated state cannot leave peers restored
        # but connections/hashes missing (previously a KeyError mid-restore).
        missing = [key for key in ('peers', 'connections', 'hashes') if key not in state]
        if missing:
            self.log.warning('State is missing sections: %s', ', '.join(missing))
            return False

        ts1 = time.time()

        blocked = self.peers.set_state(state['peers'])
        self.log.debug('Restored %d peers, blocked %d peers in %0.4fs', len(state['peers']), blocked, time.time() - ts1)

        ts2 = time.time()
        self.conns.set_state(state['connections'], self.peers)
        self.log.debug('Restored %d connections in %0.4fs', len(state['connections']), time.time() - ts2)

        ts3 = time.time()
        loaded, skipped = self.hashes.set_state(state['hashes'], self.peers)
        self.log.debug('Restored %d hashes, skipped %d hashes in %0.4fs', loaded, skipped, time.time() - ts3)

        # Restored entries get the longest interval a peer could be given
        # (base interval plus the random extra, when configured).
        if self.interval_rnd:
            interval = self.interval + self.interval_rnd
        else:
            interval = self.interval

        now = int(time.time())
        # Only the side effects of consuming the generator matter here
        # (presumably evicting entries that are already outdated).
        for _ in self.ttl.iterate_outdated(now):
            pass

        ts4 = time.time()
        ttl_objects = 0
        expires = now + interval
        for hashinfo in self.hashes.hashes.itervalues():
            for peer in itertools.chain(hashinfo.peers_leeching, hashinfo.peers_seeding):
                self.ttl.update((hashinfo.hashdata, peer.uid), expires)
                ttl_objects += 1

        self.log.debug('Restored %d ttl objects in %0.4fs', ttl_objects, time.time() - ts4)
        self.log.info('Restored state in %0.4fs', time.time() - ts1)

        return True

    def get_stats(self, describe=False):
        """Return current tracker metrics, or their descriptions.

        With ``describe=True`` the result is a tuple of description rows
        ``(name, 'total'|'counter', 'count'|'rps', windows)``. Otherwise it is
        a tuple of values in the same order: number of peers, connections and
        hashes; the sums of ``seeds`` and ``leechs`` over all peers; then the
        four packet counters ``self._stats[0..3]``.
        """
        if describe:
            gauges = ('peers', 'conns', 'hashes', 'ltseeds', 'ltleechs')
            counters = ('pings', 'connect attempts', 'announces', 'invalid cid')
            gauge_rows = tuple((name, 'total', 'count', [1]) for name in gauges)
            counter_rows = tuple((name, 'counter', 'rps', [1, 60, 300, 3600]) for name in counters)
            return gauge_rows + counter_rows

        registry = self.peers.peers
        seeds_sum = sum(p.seeds for p in registry.values())
        leechs_sum = sum(p.leechs for p in registry.values())

        counts = (
            len(registry),
            len(self.conns.by_cid),
            len(self.hashes.hashes),
            seeds_sum,
            leechs_sum,
        )
        counters = tuple(self._stats[i] for i in range(4))
        return counts + counters
