#!/usr/bin/env python

"""Checks correctness of our network configuration detection algorithm."""

from __future__ import print_function
from __future__ import unicode_literals

import errno
import functools
import itertools
import json
import time

import xmltodict

from sepelib.core import config

from walle.clients import bot, inventory, racktables, yasubr
from walle.clients.racktables import RacktablesError, PersistentRacktablesError
from walle.errors import InvalidHostNameError, InvalidHostConfiguration, NoInformationError
from walle.util import net

# Load the configuration file at import time; the path is relative to the current working directory.
config.load("bin/walle.conf.yaml")


def cached(cache_path):
    """Build a decorator that persists a function's result as JSON in ``cache_path``.

    Every call of the decorated function first consults the cache file through
    ``load_cache``. The wrapped function is invoked only when that yields
    ``None``; its result is then written with ``store_cache`` and returned.
    Call arguments are not part of any cache key: one file holds one value.
    """

    def wrap(target):
        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            cached_value = load_cache(cache_path)
            if cached_value is not None:
                return cached_value

            # Cache miss: compute the value and persist it for later calls.
            fresh_value = target(*args, **kwargs)
            store_cache(cache_path, fresh_value)
            return fresh_value

        return wrapper

    return wrap


def load_cache(cache_path):
    """Read a JSON document from ``cache_path``.

    Returns the decoded value, or ``None`` when the file does not exist.
    Any other environment (I/O) error is propagated to the caller.
    """
    try:
        with open(cache_path) as source:
            payload = json.load(source)
    except EnvironmentError as error:
        # A missing file simply means that nothing has been cached yet.
        if error.errno == errno.ENOENT:
            return None
        raise

    return payload


def store_cache(cache_path, value):
    """Serialize ``value`` as JSON into ``cache_path``, replacing any previous content."""
    with open(cache_path, mode="w") as sink:
        json.dump(obj=value, fp=sink)


def get_switches():
    """Return the set of non-empty switch names from the RackTables network layout export."""
    raw_layout = racktables.request("/export/net-layout.xml", limited=False)
    layout = xmltodict.parse(raw_layout)

    names = set()
    # NOTE(review): xmltodict yields a single dict rather than a list when an element
    # occurs only once, so this loop assumes the export lists several switches.
    for entry in layout["layout"]["switches"]["switch"]:
        name = entry["@name"]
        if name:
            names.add(name)

    return names


def get_port_statuses(switch):
    """Fetch the per-port VLAN configuration of ``switch`` from RackTables.

    Returns the decoded JSON response, or ``None`` when RackTables reports that
    the switch is unknown or is not VLAN-enabled.

    A ``PersistentRacktablesError`` with any other message is re-raised at once.
    Any other ``RacktablesError`` is retried up to three more times with a
    three second pause, after which the last error is re-raised.
    """
    max_retries = 3

    # One initial attempt plus `max_retries` retries.
    for retries_left in range(max_retries, -1, -1):
        try:
            return racktables.json_request(
                "/export/get-vlanconfig-by-port.php", params={"switch": switch}, limited=True)
        except PersistentRacktablesError as e:
            print(e)
            message = unicode(e)
            if "cant find switch by name" not in message and "is not vlan-enabled" not in message:
                raise
            # The switch has no VLAN configuration to report.
            return None
        except RacktablesError as e:
            print(e, "Retrying...")
            if retries_left == 0:
                raise
            time.sleep(3)


@cached("netmap.json")
def get_netmap():
    """Build the VLAN configuration map of every switch port known to RackTables.

    Returns a dict ``{switch name: {port: configuration}}``. A port whose raw
    configuration is ``"A"`` maps to an empty list; any other port maps to
    ``list(racktables._parse_vlan_specification(config))``. Switches for which
    ``get_port_statuses`` returns ``None`` are present with an empty dict.

    The result is persisted in ``netmap.json`` by the ``cached`` decorator, so
    RackTables is only queried when that file does not exist yet.

    Raises ``RacktablesError`` for errors that are not recognised as transient.
    """
    print("Updating netmap...")

    netmap = {}
    switches = get_switches()

    for switch_id, switch in enumerate(switches):
        print("> {} ({}/{})...".format(switch, switch_id + 1, len(switches)))

        switch_status = netmap.setdefault(switch, {})

        # Retry without a limit while RackTables keeps failing in a known transient way.
        while True:
            try:
                port_statuses = get_port_statuses(switch)
            except RacktablesError as e:
                # RackTables just resets connections sometimes in various ways.
                # The substrings must be searched in the error text: an exception
                # object is not a string, so `in e` never does a substring search.
                error = unicode(e)
                if "Connection aborted" in error or "Server returned an error: 429" in error:
                    print(e, "Sleeping and retry...")
                    time.sleep(10)
                else:
                    raise
            else:
                break

        if port_statuses is None:
            # Unknown or not VLAN-enabled switch: keep it in the map with no ports.
            continue

        for port_status in port_statuses:
            # TODO: A temporary hackaround for https://st.yandex-team.ru/NOCDEV-213
            if port_status["config"] == "A":
                switch_status[port_status["port"]] = []
            else:
                switch_status[port_status["port"]] = list(racktables._parse_vlan_specification(port_status["config"]))

    print("Netmap has been updated.")

    return netmap


@cached("network_locations.json")
def get_network_location_map():
    """Map host names to their RackTables location as derived from BOT data.

    Hosts without a name, and hosts for which ``bot.get_rt_location`` raises
    ``NoInformationError``, are left out. The result is persisted in
    ``network_locations.json`` by the ``cached`` decorator.
    """
    locations = {}

    for info in bot.get_host_location_info():
        name = info.name
        if name is not None:
            try:
                rt_location = bot.get_rt_location(info.location)
            except NoInformationError:
                # No location data for this host: skip it.
                continue

            locations[name] = rt_location

    return locations


# In-memory copy of "vlan_networks.json"; loaded lazily on the first lookup.
VLAN_NETWORKS = None


def get_vlan_networks(switch, vlan):
    """Return the networks of ``vlan`` on ``switch``, using a JSON file as a cache.

    The cache lives in ``vlan_networks.json`` and is mirrored in the module-level
    ``VLAN_NETWORKS`` dict (``{switch: {vlan: networks}}``). On a miss the value
    is requested from ``racktables.get_vlan_networks`` and the whole cache file
    is rewritten.

    :param switch: switch name.
    :param vlan: VLAN id; an integer and its string form address the same entry.
    :return: the cached value, or the fresh RackTables result on a miss.
    """
    cache_path = "vlan_networks.json"

    global VLAN_NETWORKS
    if VLAN_NETWORKS is None:
        VLAN_NETWORKS = load_cache(cache_path) or {}

    # JSON object keys are always strings, so an integer VLAN id would never match
    # an entry loaded from the file. Use the textual form as the key consistently.
    vlan_key = "{}".format(vlan)

    try:
        return VLAN_NETWORKS[switch][vlan_key]
    except KeyError:
        pass

    networks = racktables.get_vlan_networks(switch, vlan)
    VLAN_NETWORKS.setdefault(switch, {})[vlan_key] = networks
    store_cache(cache_path, VLAN_NETWORKS)

    return networks


def check_host_fqdns(host, v6, switch, native_vlan, vlans):
    """Check that the backbone and fastbone addresses of ``host`` are consistent.

    The checks, in order:

    1. ``fb-<host>`` must resolve, to exactly one IPv6 address and no IPv4 ones.
    2. Both addresses must be splittable by ``net.split_eui_64_address`` and
       carry the same MAC part.
    3. ``vlans`` without ``native_vlan`` must leave exactly one fastbone VLAN.
    4. The network parts must be among the networks that ``get_vlan_networks``
       reports for the backbone and fastbone VLANs on ``switch``.

    :param host: host FQDN.
    :param v6: collection holding exactly one backbone IPv6 address of the host.
    :param switch: name of the switch the host is connected to.
    :param native_vlan: VLAN treated as the backbone VLAN.
    :param vlans: all VLANs expected for the host, the native one included.
    :return: ``True`` when every check passes; ``False`` after printing the reason.
    :raises AssertionError: when the number of resolved addresses is unexpected.
    """
    assert len(v6) == 1
    bb_ip = list(v6)[0]

    try:
        fbv4, fbv6 = net.get_host_ips("fb-" + host)
    except InvalidHostNameError:
        print("{}: no fasbone FQDN.".format(host))
        return False

    assert not fbv4 and len(fbv6) == 1
    fb_ip = list(fbv6)[0]

    try:
        bb_network, bb_mac = net.split_eui_64_address(bb_ip)
        fb_network, fb_mac = net.split_eui_64_address(fb_ip)
    except net.InvalidEui64AddressError as e:
        print("{}: {}".format(host, e))
        return False

    if bb_mac != fb_mac:
        print("{}: backbone/fastbone MAC addresses aren't equal: {} {}.".format(host, bb_ip, fb_ip))
        return False

    bb_vlan = native_vlan
    fb_vlans = set(vlans) - {native_vlan}
    if len(fb_vlans) != 1:
        # VLAN ids may be integers, so convert them to text before joining.
        print("{}: got invalid number of fasbone VLANs: {}.".format(
              host, ",".join("{}".format(vlan) for vlan in sorted(fb_vlans))))
        return False
    fb_vlan = list(fb_vlans)[0]

    expected_bb_networks = get_vlan_networks(switch, bb_vlan)
    expected_fb_networks = get_vlan_networks(switch, fb_vlan)
    if bb_network not in expected_bb_networks or fb_network not in expected_fb_networks:
        print("{}: is in invalid networks: BB {} (not in {}), FB {} (not in {}), VLANs: {} {} of {}.".format(
              host, bb_network, expected_bb_networks, fb_network, expected_fb_networks,
              bb_vlan, fb_vlan, switch))
        return False

    return True


def check_network_configuration(check_vlans, check_dns, show_progress):
    """Compare the expected network configuration of every listed host with the observed one.

    Host names are read from ``hosts.list`` (one per line) and the optional
    per-host active VLAN overrides from ``vlan_map.json``; both paths are
    relative to the current working directory.

    For a host with an IPv4 address the expected VLANs come from
    ``yasubr.get_host_vlans``. For any other host they are derived from its
    network location and the RackTables L3 segments, with native VLAN 604.

    :param check_vlans: compare the expected VLANs with the switch port configuration from ``get_netmap``.
    :param check_dns: run ``check_host_fqdns``; applied only to hosts without an IPv4 address.
    :param show_progress: print the counters whenever the number of processed hosts is a multiple of 100.
    :return: ``None``; all findings are printed.
    """
    with open("hosts.list") as hosts_file:
        hosts = hosts_file.read().strip().split("\n")

    with open("vlan_map.json") as vlan_map_file:
        active_vlans_map = json.load(vlan_map_file)

    network_location_map = get_network_location_map()
    l3_segments_map = racktables.get_l3_segments_map()

    errors = 0  # hosts with at least one failed check
    checked = 0  # hosts that reached the checks
    processed = 0  # hosts whose name could be resolved
    netmap = get_netmap()

    for host in hosts:
        try:
            v4, v6 = net.get_host_ips(host)
        except InvalidHostNameError:
            # Retry *.yandex.ru names under the *.search.yandex.net suffix; skip anything else.
            if not host.endswith(".yandex.ru"):
                continue

            host = host[:-len(".yandex.ru")] + ".search.yandex.net"

            try:
                v4, v6 = net.get_host_ips(host)
            except InvalidHostNameError:
                continue

        processed += 1

        if v4:
            assert len(v4) == 1

            try:
                expected_native_vlan, expected_fb_vlan = yasubr.get_host_vlans(list(v4)[0])
            except InvalidHostConfiguration as e:
                print("{}: {}".format(host, e))
                continue

            expected_vlans = sorted({expected_native_vlan, expected_fb_vlan})
        else:
            try:
                network_location = network_location_map[host]
            except KeyError:
                print("{}: Can't determine network location.".format(host))
                continue

            expected_native_vlan = 604
            # NOTE(review): a missing (location, VLAN) key is not handled here - confirm the map always has it.
            l3_segments = l3_segments_map[(network_location, expected_native_vlan)]
            expected_vlans = [expected_native_vlan] + sorted(set(itertools.chain.from_iterable(
                l3_segment.fb_vlans for l3_segment in l3_segments)))

        # NOTE(review): racktables.get_network_map() is called once per host; if it is not cached
        # internally, consider fetching it once before the loop.
        switch_info = inventory.get_host_switch_from_network_map(racktables.get_network_map(), host)
        if switch_info is None:
            continue

        ok = True
        checked += 1

        if check_dns and not v4:
            ok &= check_host_fqdns(host, v6, switch_info.switch, expected_native_vlan, expected_vlans)

        if check_vlans:
            # NOTE(review): get_netmap() keeps an empty dict for switches without port statuses and an
            # empty list for ports with config "A", so this lookup/unpacking can raise for such hosts.
            vlans, native_vlan = netmap[switch_info.switch][switch_info.port]
            if host in active_vlans_map:
                # Prefer the VLANs listed for the host in vlan_map.json over the port configuration.
                active_vlans = active_vlans_map[host]
                vlans = sorted({native_vlan} | set(active_vlans))

            if (native_vlan, vlans) != (expected_native_vlan, expected_vlans):
                print("{}: {} instead of {}".format(host, (expected_native_vlan, expected_vlans), (native_vlan, vlans)))
                ok = False

        if not ok:
            errors += 1

        # Reached only by hosts that were not skipped by one of the `continue` statements above.
        if show_progress and processed % 100 == 0:
            print("{} failed from {} processed of {}.".format(errors, checked, processed))


# Script entry point: only the DNS check is run, the VLAN check is switched off.
if __name__ == "__main__":
    check_network_configuration(check_vlans=False, check_dns=True, show_progress=True)
