# coding: utf-8
from __future__ import print_function

import itertools
import collections
import datetime
import gzip
import bisect
import socket
import operator
import struct
import gc

import msgpack
import requests
import pymongo
import netaddr
from tabulate import tabulate

from agent import application
from agent import topology
from agent import rpc

# Heartbeat replica set: five members, database "heartbeat", with explicit
# connect/socket timeouts (milliseconds).
MONGO_URI = (
    'mongodb://'
    'hb-dbs01.search.yandex.net,'
    'hb-dbs02.search.yandex.net,'
    'hb-dbs03.search.yandex.net,'
    'hb-dbs04.search.yandex.net,'
    'hb-dbs05.search.yandex.net'
    '/heartbeat'
    '?connectTimeoutMS=30000&socketTimeoutMS=60000'
)

# RackTables tab-separated exports consumed by iter_racktables_response().
RACKTABLES_VRF_BACKBONE_URL = 'https://ro.racktables.yandex.net/export/networklist.php?report=vrf-bb'
RACKTABLES_VRF_FASTBONE_URL = 'https://ro.racktables.yandex.net/export/networklist.php?report=vrf-fb'
RACKTABLES_VLAN_URL = 'https://ro.racktables.yandex.net/export/networklist.php?report=vlans'
RACKTABLES_NETS_URL = 'https://racktables.yandex.net/export/nets-by-project.php'
RACKTABLES_MTN_URL = 'https://racktables.yandex.net/export/vm-projects2.txt'


def iter_hosts():
    """Yield hostnames that reported to heartbeat within the last 7 days.

    Queries ``heartbeat.hostinfo`` on a secondary of the ``heartbeat``
    replica set and yields the ``host`` field of every document whose
    ``last_update`` is newer than the cutoff.

    The Mongo client is closed once the generator is exhausted or closed,
    so each call does not leave a connection pool behind.
    """
    client = pymongo.MongoReplicaSetClient(
        MONGO_URI,
        replicaSet='heartbeat',
        read_preference=pymongo.ReadPreference.SECONDARY
    )
    try:
        # NOTE(review): naive local time is used here; pymongo interprets
        # naive datetimes as UTC, so the cutoff is shifted by the host's UTC
        # offset unless last_update is stored in local time — confirm.
        since = datetime.datetime.now() - datetime.timedelta(days=7)
        cursor = client['heartbeat']['hostinfo'].find(
            {'last_update': {'$gt': since}},
            {'host': 1, '_id': 0}
        )
        for doc in cursor:
            yield doc['host']
    finally:
        client.close()


def iter_hosts_data():
    """Stream host records from the gzip-compressed msgpack topology dump.

    The dump location is ``topology.get_path()``; every object decoded by
    ``msgpack.Unpacker`` from the decompressed stream is yielded as-is.
    """
    dump_path = str(topology.get_path())
    with open(dump_path, "rb") as raw:
        decompressed = gzip.GzipFile(fileobj=raw)
        for record in msgpack.Unpacker(decompressed):
            yield record


def iter_racktables_response(url):
    """Yield the tab-separated rows of a RackTables text export at ``url``.

    Empty lines are skipped; every other line is split on ``'\\t'`` and
    yielded as a list of fields.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never parsed as data (callers unpack a fixed number
            of fields and would otherwise fail with a confusing ValueError).
    """
    resp = requests.get(url, stream=True)
    try:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if line:
                yield line.split('\t')
    finally:
        # With stream=True the connection is held until the body is fully
        # read or the response is closed; a consumer may stop early.
        resp.close()


class IPNetwork(object):
    """Tiny CIDR helper mapping IPv4/IPv6 networks to integer address ranges.

    Addresses are converted with ``socket.inet_pton``; the family is detected
    by trying IPv6 first and falling back to IPv4. Only the operations this
    script needs are provided.
    """

    IPV4_WIDTH = 32
    IPV4_MAX_INT = 2 ** IPV4_WIDTH - 1

    IPV6_WIDTH = 128
    IPV6_MAX_INT = 2 ** IPV6_WIDTH - 1

    def __init__(self, network):
        """Parse ``"addr/prefix"``, or a bare ``"addr"`` taken as /32 or /128.

        Raises:
            socket.error: if the address is neither valid IPv6 nor IPv4.
            ValueError: if the prefix is not an integer within
                ``[0, address width]``.
        """
        self._network_int, self._prefix_int, self._width, self._max_int = self._parse_network(network)
        # Mask covering the host bits, e.g. 0xFF for an IPv4 /24.
        self._hostmask_int = (1 << (self._width - self._prefix_int)) - 1

    @staticmethod
    def _ipv4_to_int(addr):
        """Convert a dotted-quad string to an int; socket.error if invalid."""
        return struct.unpack("!I", socket.inet_pton(socket.AF_INET, addr))[0]

    @staticmethod
    def _ipv6_to_int(addr):
        """Convert a textual IPv6 address to a 128-bit int; socket.error if invalid."""
        hi, lo = struct.unpack('!QQ', socket.inet_pton(socket.AF_INET6, addr))
        return (hi << 64) | lo

    def _parse_network(self, network):
        """Return ``(addr_int, prefix_len, width, max_int)`` for ``network``."""
        addr, has_prefix, prefix = network.partition("/")
        try:
            addr_int, width, max_int = self._ipv6_to_int(addr), self.IPV6_WIDTH, self.IPV6_MAX_INT
        except socket.error:
            addr_int, width, max_int = self._ipv4_to_int(addr), self.IPV4_WIDTH, self.IPV4_MAX_INT
        prefix_int = int(prefix) if has_prefix else width
        if not 0 <= prefix_int <= width:
            # A negative prefix would silently produce a bogus range, and one
            # above the width would fail later with an obscure shift error.
            raise ValueError("invalid prefix length in network %r" % (network,))
        return addr_int, prefix_int, width, max_int

    @classmethod
    def addr_to_int(cls, addr):
        """Convert an IPv6 or IPv4 address string to its integer value."""
        try:
            return cls._ipv6_to_int(addr)
        except socket.error:
            return cls._ipv4_to_int(addr)

    @classmethod
    def extract_mtn_project(cls, addr):
        """Return the 16-bit field at bits 32..47 of ``addr`` as lowercase hex.

        This is the sixth 16-bit group of an IPv6 address, rendered without
        ``0x`` and without leading zeros, e.g.
        ``'2a02:6b8:c0c:1:0:5678::1'`` -> ``'5678'``. The caller uses it as
        the MTN project id.
        """
        return '%x' % ((cls.addr_to_int(addr) >> 32) & 0xFFFF)

    @classmethod
    def get_addr_family(cls, addr):
        """Return ``socket.AF_INET6`` if ``addr`` parses as IPv6, else ``socket.AF_INET``."""
        try:
            cls._ipv6_to_int(addr)
        except socket.error:
            return socket.AF_INET
        else:
            return socket.AF_INET6

    @property
    def first(self):
        """Integer value of the lowest address in the network."""
        return self._network_int & (self._max_int ^ self._hostmask_int)

    @property
    def last(self):
        """Integer value of the highest address in the network."""
        return self._network_int | self._hostmask_int

    def to_range(self):
        """Return ``(first, last)`` as integers, both inclusive."""
        return (self.first, self.last)


class NetworkSet(object):
    """Lookup table from an IP address to the value of its most specific network.

    Built from ``(network_string, value)`` pairs using a two-level structure:

    * the root holds disjoint, sorted "super ranges" obtained by merging all
      overlapping networks, so the right bucket is found with ``bisect``;
    * each bucket holds the original ranges merged into it, sorted by start
      address ascending and then by end address descending. It is scanned
      from the end, so for nested CIDR networks the most specific one wins.

    NOTE(review): IPv4 and IPv6 ranges share one integer space, so IPv6
    networks inside ``::/96`` would collide with IPv4 ones — presumably such
    networks are not present in the input.
    """

    Entry = collections.namedtuple("Entry", ("range_table", "value_table"))
    Range = collections.namedtuple("Range", ("start", "end"))

    def __init__(self, network_list):
        self._root = self._build_root(network_list)

    @classmethod
    def _build_root(cls, network_list):
        """Build the two-level root ``Entry`` from ``(network, value)`` pairs."""
        root = cls.Entry(range_table=[], value_table=[])
        networks = [
            (cls.Range(*IPNetwork(network).to_range()), value)
            for network, value in network_list
        ]
        # For equal starts put the wider network first, so narrower (more
        # specific) networks come later and win in _scan_position's reverse
        # scan. Sorting by start alone let e.g. a /40 shadow a /64 that
        # begins at the same address.
        networks.sort(key=lambda item: (item[0].start, -item[0].end))
        for next_range, next_value in networks:
            if not root.range_table or root.range_table[-1].end < next_range.start:
                # Disjoint from the previous super range (or first one): open
                # a new bucket.
                root.range_table.append(next_range)
                root.value_table.append(cls.Entry(range_table=[], value_table=[]))
            else:
                # Overlaps the current super range: widen it if needed.
                current = root.range_table[-1]
                root.range_table[-1] = cls.Range(current.start, max(current.end, next_range.end))
            bucket = root.value_table[-1]
            bucket.range_table.append(next_range)
            bucket.value_table.append(next_value)
        return root

    @staticmethod
    def _belong_to_range(addr_int, addr_range):
        """Return True if ``addr_int`` lies within ``addr_range`` (inclusive)."""
        return addr_range.start <= addr_int <= addr_range.end

    @classmethod
    def _find_position(cls, addr_int, range_table):
        """Binary-search a sorted table of disjoint ranges; return index or None."""
        # The 1-tuple (addr_int,) sorts before any Range starting at addr_int,
        # so the containing range is either just before or at ``pos``.
        pos = bisect.bisect_left(range_table, (addr_int, ))
        if pos > 0 and cls._belong_to_range(addr_int, range_table[pos - 1]):
            return pos - 1
        if pos < len(range_table) and cls._belong_to_range(addr_int, range_table[pos]):
            return pos
        return None

    @classmethod
    def _scan_position(cls, addr_int, range_table):
        """Return the index of the last range containing ``addr_int``, or None."""
        for pos in reversed(range(len(range_table))):
            if cls._belong_to_range(addr_int, range_table[pos]):
                return pos
        return None

    @classmethod
    def _get_value(cls, addr_int, entry, scan=False):
        """Return the value paired with the range of ``entry`` containing ``addr_int``.

        ``scan=True`` uses a linear reverse scan (for overlapping bucket
        ranges); otherwise a binary search over disjoint ranges is used.
        Returns None when no range matches.
        """
        if scan:
            position = cls._scan_position(addr_int, entry.range_table)
        else:
            position = cls._find_position(addr_int, entry.range_table)
        return entry.value_table[position] if position is not None else None

    def get(self, addr):
        """Return the value of the most specific network containing ``addr``.

        ``addr`` may be an address string or a ``netaddr.IPAddress``. Returns
        None when no network contains it.
        """
        if isinstance(addr, netaddr.IPAddress):
            addr = str(addr)
        addr_int = IPNetwork.addr_to_int(addr)
        bucket = self._get_value(addr_int, self._root)
        if bucket is None:
            return None
        return self._get_value(addr_int, bucket, scan=True)


class Context(object):
    """Collects host, VLAN/VRF and network-usage data for the reports.

    Call :meth:`collect` once. Afterwards the ``print_*_report`` functions
    read the attributes populated here.
    """

    UsedNetwork = collections.namedtuple('UsedNetwork', ('network', 'vrf', 'vlan', 'macro', 'mtn'))

    def __init__(self):
        self.app = application.Application()
        self.app.register(rpc.RpcClient())

        # Hostnames whose interfaces belong to a host that did ("found") or
        # did not ("missing") report to heartbeat, split by bare/virtual.
        self.bare_missing_hosts = set()
        self.bare_found_hosts = set()
        self.virtual_missing_hosts = set()
        self.virtual_found_hosts = set()

        # vlan / vrf -> set of found hostnames with an interface in it.
        self.bare_vlans = collections.defaultdict(set)
        self.virtual_vlans = collections.defaultdict(set)
        self.bare_vrfs = collections.defaultdict(set)
        self.virtual_vrfs = collections.defaultdict(set)

        # found hostname -> list of (ip_version, netaddr.IPAddress).
        self.addresses_on_host = {}

        # project id (as listed in the RackTables MTN export) -> macro name.
        self.macro_per_project_id = {}
        # NetworkSet mapping addresses to UsedNetwork records.
        self.network_index = None

        # UsedNetwork records observed on at least one found host.
        self.used_networks = set()

    @staticmethod
    def _record_found_iface(iface, found_hosts, vrfs, vlans):
        """Mark ``iface.host`` as found and file it under its VRF, else its VLAN."""
        found_hosts.add(iface.host)
        if iface.vrf:
            vrfs[iface.vrf].add(iface.host)
        elif iface.vlan:
            vlans[iface.vlan].add(iface.host)

    def _on_bare_iface_exists(self, iface):
        self._record_found_iface(iface, self.bare_found_hosts, self.bare_vrfs, self.bare_vlans)

    def _on_virtual_iface_exists(self, iface):
        self._record_found_iface(iface, self.virtual_found_hosts, self.virtual_vrfs, self.virtual_vlans)

    def _on_bare_iface_missing(self, iface):
        self.bare_missing_hosts.add(iface.host)

    def _on_virtual_iface_missing(self, iface):
        self.virtual_missing_hosts.add(iface.host)

    def _collect_hosts_from_heartbeat(self):
        """Classify topology interfaces by whether their host is in iter_hosts()."""
        topology_tree = self.app.run_sync(lambda: topology.tree(self.app))
        hosts = set(iter_hosts())
        for iface in topology_tree.interfaces():
            if iface.host in hosts:
                if iface.is_virtual:
                    self._on_virtual_iface_exists(iface)
                else:
                    self._on_bare_iface_exists(iface)
                self.addresses_on_host[iface.host] = []
            elif iface.is_virtual:
                self._on_virtual_iface_missing(iface)
            else:
                self._on_bare_iface_missing(iface)

    def _collect_addresses_on_hosts(self):
        """Fill addresses_on_host for found hosts from the topology dump."""
        for host_data in iter_hosts_data():
            addresses = self.addresses_on_host.get(host_data["fqdn"])
            if addresses is None:
                continue
            for interface_data in host_data["interfaces"]:
                candidates = ((4, interface_data["ipv4addr"]), (6, interface_data["ipv6addr"]))
                addresses.extend(
                    (version, netaddr.IPAddress(address))
                    for version, address in candidates
                    if address != "unknown"
                )

    def _collect_from_racktables(self):
        """Build network_index and macro_per_project_id from RackTables exports."""
        network_map = collections.defaultdict(lambda: {
            "vlan": None,
            "vrf": None,
            "macro": None,
            "network": None,
            "mtn": None
        })

        # VRF membership. Networks whose text starts with 'fd00:' are skipped.
        # Fastbone is read after backbone, so it wins for a shared network.
        vrf_networks = netaddr.IPSet()
        for url in (RACKTABLES_VRF_BACKBONE_URL, RACKTABLES_VRF_FASTBONE_URL):
            for network, vrf in iter_racktables_response(url):
                if not network.startswith('fd00:'):
                    vrf_networks.add(network)
                    network_map[network].update(
                        vrf=vrf,
                        network=netaddr.IPNetwork(network)
                    )

        # VLANs, only for networks not already covered by a VRF.
        for network, vlan in iter_racktables_response(RACKTABLES_VLAN_URL):
            if not network.startswith('fd00:') and network not in vrf_networks:
                network_map[network].update(
                    vlan=int(vlan),
                    network=netaddr.IPNetwork(network)
                )

        for network, macro, _, _, _, _ in iter_racktables_response(RACKTABLES_NETS_URL):
            if network == "CIDR":
                continue  # header row
            network_map[network].update(
                macro=macro,
                network=netaddr.IPNetwork(network)
            )

        for macro, project_id, _ in iter_racktables_response(RACKTABLES_MTN_URL):
            self.macro_per_project_id[project_id] = macro

        # A network is MTN when it is IPv6 and lies entirely inside one of
        # these blocks.
        mtn_blocks = (
            IPNetwork('2a02:6b8:c00::/40').to_range(),
            IPNetwork('2a02:6b8:fc00::/48').to_range(),
        )
        for info in network_map.values():
            net = info['network']
            info['mtn'] = net.version == 6 and any(
                low <= net.first <= net.last <= high for low, high in mtn_blocks
            )

        self.network_index = NetworkSet((network, self.UsedNetwork(**kwargs))
                                        for network, kwargs in network_map.items())

    def _collect_used_networks(self):
        """Fill used_networks from addresses of found hosts in any VRF or VLAN."""
        mtn_project_ids = set()

        # Iterate .values() instead of indexing one defaultdict with keys
        # taken from another: indexing inserted empty sets, which then showed
        # up as bogus zero-host rows in the VLAN/VRF reports.
        hosts = set()
        for host_map in (self.bare_vrfs, self.virtual_vrfs, self.bare_vlans, self.virtual_vlans):
            for members in host_map.values():
                hosts.update(members)

        for host in hosts:
            for version, address in self.addresses_on_host.get(host, ()):
                info = self.network_index.get(address)
                if info is None:
                    continue
                if not info.mtn:
                    self.used_networks.add(info)
                elif version == 6:
                    # MTN usage is reported per project macro, not per network.
                    mtn_project_ids.add(IPNetwork.extract_mtn_project(str(address)))

        for project_id in mtn_project_ids:
            macro = self.macro_per_project_id.get(project_id)
            if macro is not None:
                self.used_networks.add(self.UsedNetwork(None, None, None, macro, True))

    def collect(self):
        """Run all collection steps in dependency order."""
        self._collect_hosts_from_heartbeat()
        self._collect_addresses_on_hosts()
        self._collect_from_racktables()
        self._collect_used_networks()


def print_host_report(context):
    found_hosts = context.bare_found_hosts | context.virtual_found_hosts
    missing_hosts = context.bare_missing_hosts | context.virtual_missing_hosts

    total_bare_hosts = len(context.bare_found_hosts) + len(context.bare_missing_hosts)
    total_virtual_hosts = len(context.virtual_found_hosts) + len(context.virtual_missing_hosts)
    total_hosts = total_bare_hosts + total_virtual_hosts
    assert len(found_hosts) + len(missing_hosts) == total_hosts

    print("Hosts report:")
    print(tabulate([
        [
            'bare',
            len(context.bare_found_hosts),
            len(context.bare_found_hosts) * 100.0 / total_bare_hosts,
            len(context.bare_missing_hosts),
            len(context.bare_missing_hosts) * 100.0 / total_bare_hosts,
            total_bare_hosts
        ],
        [
            'virtual',
            len(context.virtual_found_hosts),
            len(context.virtual_found_hosts) * 100.0 / total_virtual_hosts,
            len(context.virtual_missing_hosts),
            len(context.virtual_missing_hosts) * 100.0 / total_virtual_hosts,
            total_virtual_hosts
        ],
        [
            'total',
            len(found_hosts),
            len(found_hosts) * 100.0 / total_hosts,
            len(missing_hosts),
            len(missing_hosts) * 100.0 / total_hosts,
            total_hosts
        ]
    ], headers=['type', 'found hosts', '%', 'missing hosts', '%', 'total']))
    print('')


def print_vlan_report(context):
    """Print the number of found hosts per VLAN, for bare and virtual interfaces.

    Rows are ``[type, vlan, found hosts]`` ordered descending by type, then
    host count, then VLAN.
    """
    print("VLAN report (where skynet exists):")
    rows = []
    for kind, vlan_hosts in (("bare", context.bare_vlans), ("virtual", context.virtual_vlans)):
        rows.extend([kind, vlan, len(hosts)] for vlan, hosts in vlan_hosts.items())
    rows.sort(key=lambda row: (row[0], row[2], row[1]), reverse=True)
    print(tabulate(rows, headers=['type', 'vlan', 'found hosts']))
    print('')


def print_vrf_report(context):
    """Print the number of found hosts per VRF, for bare and virtual interfaces.

    Rows are ``[type, vrf, found hosts]`` ordered descending by type, then
    host count, then VRF name.
    """
    print("VRF report (where skynet exists):")

    def rows_for(kind, vrf_hosts):
        return ([kind, vrf, len(hosts)] for vrf, hosts in vrf_hosts.items())

    rows = sorted(
        itertools.chain(rows_for("bare", context.bare_vrfs), rows_for("virtual", context.virtual_vrfs)),
        key=lambda row: (row[0], row[2], row[1]),
        reverse=True,
    )
    print(tabulate(rows, headers=['type', 'vrf', 'found hosts']))
    print('')


def print_network_report(context):
    """Print the used networks, then the distinct RackTables macros among them.

    The first table has one row per used network: vrf, vlan, network and
    macro, with missing fields shown blank. The second table lists the
    sorted distinct non-None macros. Both titles warn that results may be
    inaccurate because of cached IPs.
    """
    print("Network report (may be inaccurate because of cached IP):")
    print(tabulate(
        [
            [info.vrf or '', info.vlan or '', info.network or '', info.macro or '']
            for info in sorted(context.used_networks)
        ],
        headers=['vrf', 'vlan', 'network', 'macro']
    ))
    print('')

    actual_macros = {info.macro for info in context.used_networks if info.macro is not None}
    print("Macro report (may be inaccurate because of cached IP):")
    print(tabulate(sorted((macro,) for macro in actual_macros), headers=['macro']))


def main():
    """Collect all data once, then print every report in a fixed order."""
    context = Context()
    context.collect()

    reports = (print_host_report, print_vlan_report, print_vrf_report, print_network_report)
    for report in reports:
        report(context)


if __name__ == "__main__":
    # NOTE(review): the cyclic garbage collector is disabled for the whole
    # run, presumably to avoid GC pauses while many small objects are built.
    # Reference counting still frees non-cyclic objects — confirm intent.
    gc.disable()
    main()
