import datetime
import email.utils
import operator
import time
from collections import namedtuple

import cachetools
import requests
import six
from requests import RequestException

import walle.clients.utils
import walle.hosts
from sepelib.core import constants
from walle.clients import bot
from walle.clients.network.racktables_client import RacktablesClient
from walle.clients.racktables import (
    log,
    _NETMAP_INACCURACY,
    RacktablesError,
    RacktablesSwitchInfo,
    _handle_error,
    _url_to_racktables,
)
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.misc import format_long_list_for_logging, timetuple_to_time
from walle.util.net import mac_to_int, mac_from_int


class Netmap:
    """Resolve host switch/port locations from RackTables netmap data and BOT MAC inventory data.

    Both data sources are loaded in the constructor. Use the instance as a context manager: on exit it logs an
    aggregated error when the number of hosts that resolved to several unexplained network locations reaches
    `several_locations_error_threshold`.
    """

    # Single entry 'data' -> ((mac_to_switch, switch_to_mac), actualization_time), shared by all instances.
    # The entry expires 30 minutes after it was last stored.
    _mac_to_switch_cache = cachetools.TTLCache(maxsize=1, ttl=30 * constants.MINUTE_SECONDS)

    def __init__(self, several_locations_error_threshold, max_error_hosts_in_log=50):
        """Load `MAC -> inventory number` data from BOT and `MAC <-> switch/port` data from RackTables.

        :param several_locations_error_threshold: number of hosts with several network locations at which
            `__exit__` logs an error.
        :param max_error_hosts_in_log: limit passed to `format_long_list_for_logging` for lists in error logs
            (MAC lists use twice this limit).
        """
        self.__several_locations_error_threshold = several_locations_error_threshold
        self.__max_error_hosts_in_log = max_error_hosts_in_log

        # (switch, port) pairs already reported as suspicious: each one is logged only once per instance.
        self.__suspicious_ports = set()
        # Human-readable names of hosts that resolved to several network locations.
        self.__several_locations = set()

        self.__mac_to_inv = self._get_mac_to_inv_mapping()
        self.__mac_to_switch, self._switch_to_mac = self._get_mac_switch_mappings()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Report hosts with several network locations in one aggregated message. Returning None means exceptions
        # raised inside the `with` block are not suppressed.
        if len(self.__several_locations) >= self.__several_locations_error_threshold:
            log.error(
                "Got several network locations for %s hosts: %s",
                len(self.__several_locations),
                format_long_list_for_logging(self.__several_locations, limit=self.__max_error_hosts_in_log),
            )

    def get_host_switch(self, host_inv, host_macs, host_project, host_name=None):
        """Return the netmap location of a host, or None if it can't be determined reliably.

        The result is an entry of the MAC -> switch mapping (a `RacktablesSwitchInfo` as built by
        `_parse_mac_to_switch_mappings`). Locations older than the newest one by more than `2 * _NETMAP_INACCURACY`
        are dropped first. Returns None when none of the MACs is in the netmap, when a candidate port looks
        suspicious (more than two physical nodes behind it), or when several locations remain.

        :param host_inv: host inventory number.
        :param host_macs: iterable of the host's MAC addresses.
        :param host_project: not used by this method.
        :param host_name: optional host name, used only to build log messages.
        """
        host_macs = set(host_macs)

        # Keep the freshest record for every distinct (switch, port) the host's MACs were seen on.
        locations = {}

        for mac in host_macs:
            location = self.__mac_to_switch.get(mac_to_int(mac))
            if location is None:
                continue

            location_id = (location.switch, location.port)
            if location_id not in locations or locations[location_id].timestamp < location.timestamp:
                locations[location_id] = location

        if not locations:
            return None

        locations = list(locations.values())

        if len(locations) > 1:
            # Try to filter out old host locations
            locations.sort(key=operator.attrgetter("timestamp"))
            min_timestamp = locations[-1].timestamp - 2 * _NETMAP_INACCURACY
            locations = [candidate for candidate in locations if candidate.timestamp >= min_timestamp]

        # Sometimes NOC fail to unwrap all branches and leafs of actual network topology and generate netmap where a lot
        # of hosts are placed behind an aggregator switch instead of their actual switches. Try to detect the situation
        # here.
        for location in locations:
            port_macs = {mac_from_int(mac) for mac in self._switch_to_mac[(location.switch, location.port)]}
            extra_macs = port_macs - host_macs

            if not extra_macs:
                # Our host is the only citizen on the port - all is OK.
                continue

            port_host_invs = {host_inv}
            port_host_macs = [mac_from_int(location.int_mac)]

            port_ipmi_invs = set()
            port_ipmi_macs = []

            for extra_mac in extra_macs:
                try:
                    extra_inv, ipmi_mac = self.__mac_to_inv[extra_mac]
                except KeyError:
                    # Filter out virtual machines and MACs that aren't used already
                    continue

                if ipmi_mac:
                    port_ipmi_invs.add(extra_inv)
                    port_ipmi_macs.append(extra_mac)
                else:
                    port_host_invs.add(extra_inv)
                    port_host_macs.append(extra_mac)

            # For the case when host's IPMI MAC == Network MAC. It's a very rare configuration which shouldn't be used
            # and we create tickets for such hosts to fix it, but here we should workaround it.
            port_ipmi_invs.discard(host_inv)

            # Count number of nodes behind the port. Service processors (IPMI) in most cases are connected to a simple
            # switch (not commutator) which is connected to a single smart switch (commutator) port. That is why we
            # count all service processors as one physical node here.
            nodes_on_port = len(port_host_invs) + min(len(port_ipmi_invs), 1)

            if nodes_on_port == 1:
                # Our host is the only citizen on the port - all is OK.
                pass
            elif nodes_on_port == 2:
                # Assume that it's a side effect from body change or moving hosts between switches
                pass
            else:
                switch_port = (location.switch, location.port)
                if switch_port not in self.__suspicious_ports:
                    self.__suspicious_ports.add(switch_port)

                    log.error(
                        "Got a suspicious netmap data for %s/%s port: "
                        "it has %s (%s) nodes with (%s) MACs and %s (%s) IPMI nodes with (%s) MACs.",
                        location.switch,
                        location.port,
                        len(port_host_invs),
                        format_long_list_for_logging(
                            [str(inv) for inv in sorted(port_host_invs)], limit=self.__max_error_hosts_in_log
                        ),
                        format_long_list_for_logging(sorted(port_host_macs), limit=self.__max_error_hosts_in_log * 2),
                        len(port_ipmi_invs),
                        format_long_list_for_logging(
                            [str(inv) for inv in sorted(port_ipmi_invs)], limit=self.__max_error_hosts_in_log
                        ),
                        format_long_list_for_logging(sorted(port_ipmi_macs), limit=self.__max_error_hosts_in_log * 2),
                    )

                # Location data for this host can't be trusted.
                return None

        if len(locations) == 1:
            return locations[0]

        if len(locations) > 1:
            mac_addresses = {mac_from_int(loc.int_mac) for loc in locations}
            if len(mac_addresses) == len(locations) and mac_addresses.issubset(host_macs):
                # FIXME: host have several NICs connected to different switches/ports.
                # It is probably ok, but we don't really know what to do with them.
                # The current solution is to ignore location data, but don't complain.
                few_switches_is_ok = True
            else:
                few_switches_is_ok = False
                self.__several_locations.add(walle.hosts.get_host_human_name(host_inv, host_name))

            (log.debug if few_switches_is_ok else log.warning)(
                "%s: Got several network locations: %s.",
                walle.hosts.get_host_human_id(host_inv, host_name),
                ", ".join(map(repr, locations)),
            )

        return None

    @classmethod
    def _get_mac_switch_mappings(cls):
        """Return `(mac_to_switch, switch_to_mac)`, refreshing the cached copy when RackTables has newer data.

        A conditional request is made with the cached actualization time. The cached mappings are reused when
        `_get_netmap` reports that nothing changed (no response returned).
        """
        mappings, actualization_time = cls._mac_to_switch_cache.get('data', (None, None))

        response, actualization_time = _get_netmap("L12_walle.fvt", if_modified_since=actualization_time)
        if response is not None:
            log.info("Updating RackTables MAC to Switch/Port mapping cache...")
            try:
                mappings = _parse_mac_to_switch_mappings(response)
            finally:
                # The response is streamed: release the connection even if parsing fails half-way.
                response.close()
            log.info(
                "RackTables MAC to Switch/Port mapping cache has been successfully updated (%s items).",
                len(mappings[0]),
            )

        # Storing the entry again also restarts its TTL when the data was confirmed as still current.
        cls._mac_to_switch_cache['data'] = (mappings, actualization_time)

        return mappings

    @staticmethod
    def _get_mac_to_inv_mapping():
        """Return a dict mapping every MAC reported by BOT to `(inventory_number, is_ipmi_mac)`."""
        mac_to_inv = {}

        log.info("Collecting `MAC -> Inventory number` mapping from BOT...")

        for host in gevent_idle_iter(bot.iter_hosts_info()):
            # Always store IPMI MAC first to override it by network MAC when host's IPMI MAC == Network MAC.
            # It's a very rare configuration which shouldn't be used and we create tickets for such hosts to fix it,
            # but here we should workaround it.
            if "ipmi_mac" in host:
                mac_to_inv[host["ipmi_mac"]] = (host["inv"], True)

            for mac in host["macs"]:
                mac_to_inv[mac] = (host["inv"], False)

        log.info("Collected info about %s MACs.", len(mac_to_inv))

        return mac_to_inv


def _parse_mac_to_switch_mappings(response):
    """Build MAC <-> switch/port lookup tables from a RackTables netmap export.

    Lines are parsed by `_parse_racktables_data` into four columns: MAC, an ignored column, "switch/port" and a
    timestamp. Rows on interconnect switches are skipped. When a MAC appears on several rows, the row with the newest
    timestamp wins (a later row wins on ties).

    :param response: streamed HTTP response with the netmap export (read via `iter_lines()`).
    :return: `(mac_to_switch, switch_to_mac)`: `mac_to_switch` maps integer MAC -> `RacktablesSwitchInfo`,
        `switch_to_mac` maps `(switch, port)` -> tuple of integer MACs located on that port.
    :raises RacktablesError: on malformed rows.
    """
    mac_to_switch = {}

    try:
        # Loop-invariant: a single client serves every row.
        racktables_client = RacktablesClient()

        for row in _parse_racktables_data(response.iter_lines(), fields=["mac", "_", "switch_port", "timestamp"]):
            try:
                int_mac = mac_to_int(row.mac)
                switch, port = row.switch_port.split("/", 1)
                # Shift the timestamp back by the netmap inaccuracy.
                timestamp = int(row.timestamp) - _NETMAP_INACCURACY
            except Exception as e:
                log.error("Got an invalid data from RackTables (%r): %s", row.line, e)
                raise RacktablesError("Got an invalid data from RackTables.") from e

            if racktables_client.is_interconnect_switch(switch):
                continue

            old_switch = mac_to_switch.get(int_mac)
            if old_switch is not None and old_switch.timestamp > timestamp:
                continue

            mac_to_switch[int_mac] = RacktablesSwitchInfo(switch, port, int_mac, timestamp)
    except RequestException as e:
        raise _handle_error(e)

    # Group MACs by port into lists first: repeated tuple concatenation is quadratic for ports with many MACs.
    port_to_macs = {}
    for switch, port, int_mac, _timestamp in gevent_idle_iter(mac_to_switch.values()):
        port_to_macs.setdefault((switch, port), []).append(int_mac)

    switch_to_mac = {switch_port: tuple(int_macs) for switch_port, int_macs in port_to_macs.items()}

    return mac_to_switch, switch_to_mac


def _parse_racktables_data(data, fields=None, delimiter=" "):
    """Lazily parse RackTables text export lines into namedtuple rows.

    Each row has a `line` attribute holding the stripped source line, followed by one attribute per name in
    `fields`. Names that are not valid identifiers (e.g. "_") are renamed by `namedtuple(rename=True)` to their
    positional form (e.g. "_2"). Empty lines are skipped.

    :param data: iterable of lines (bytes are decoded as UTF-8).
    :param fields: non-empty list of column names.
    :param delimiter: column separator.
    :raises ValueError: if `fields` is empty (raised on first iteration, as this is a generator).
    :raises RacktablesError: if a line has an unexpected number of columns.
    """
    if not fields:
        raise ValueError("fields can't be empty")

    row_class_name = "RowClass_" + "".join(map(str.capitalize, fields))
    # `line` is prepended to the requested fields so callers can log the raw line on errors.
    RowClass = namedtuple(row_class_name, ["line"] + fields, rename=True)
    # Compare against `fields`, not `RowClass._fields`: the latter also contains the extra `line` field.
    expected_columns = len(fields)

    try:
        for line in gevent_idle_iter(data):
            line = six.ensure_str(line, "utf-8").strip()
            if not line:
                continue

            columns = line.split(delimiter)
            if len(columns) != expected_columns:
                log.error("Got an invalid number of columns from RackTables (%s), expected %s.", columns, fields)
                raise RacktablesError("Got an invalid number of columns from RackTables.")

            yield RowClass(line, *columns)
    except RequestException as e:
        raise _handle_error(e)


def _get_netmap(name, if_modified_since=None):
    """Download a RackTables netmap export.

    https://racktables.yandex.net/export/netmap/L123* provides aggregated information collected from switches about
    machines:
    * If a MAC address shows up on some switch, it appears in the export within 1-4 hours. If the MAC stops showing
      up, it disappears from the export after about a week.
    * Under no circumstances can one MAC be listed under different ports in the export: if two switches report that
      they saw MAC X, the export contains only the port where it was seen most recently.
    * Unlike MAC addresses, the same IP address may be listed under different ports. This can happen if:
      * The server hardware ("body") was replaced (the IP is now used by another MAC).
      * Several machines share the IP address (load balancing).

    :param name: export name, appended to "/export/netmap/".
    :param if_modified_since: optional UNIX timestamp of the data the caller already has; sent as If-Modified-Since.
    :return: `(None, if_modified_since)` on HTTP 304, otherwise `(response, actualization_time)` where `response` is
        the streamed response (the caller must read/close it) and `actualization_time` comes from
        `_get_actualization_time` (None if it could not be determined).
    """
    url = _url_to_racktables("/export/netmap/" + name)

    headers = {}
    if if_modified_since is not None:
        # RFC 7231 IMF-fixdate. Unlike strftime("%a"/"%b"), formatdate always uses English day/month names,
        # regardless of the process locale.
        headers["If-Modified-Since"] = email.utils.formatdate(if_modified_since, usegmt=True)

    try:
        response = walle.clients.utils.request(
            "racktables", "GET", url, stream=True, headers=headers, check_status=False
        )

        if response.status_code == requests.codes.not_modified:
            # Nothing to read: release the streamed connection.
            response.close()
            return None, if_modified_since

        if response.status_code != requests.codes.ok:
            response.close()
            raise RequestException(response.reason)

        return response, _get_actualization_time(response)
    except RequestException as e:
        raise _handle_error(e)


def _get_actualization_time(response):
    """Return the time of the response's Last-Modified header, or None if it is missing or unparseable.

    The value is in seconds since the epoch (it is compared against `time.time()`). An error is logged when the
    header is missing or invalid, or when the data is more than five hours old.
    """
    headers = response.headers
    if "Last-Modified" not in headers:
        log.error("RackTables response for %s doesn't contain Last-Modified header.", response.url)
        return None
    header_value = headers["Last-Modified"]

    try:
        parsed_date = email.utils.parsedate(header_value)
        if parsed_date is None:
            # parsedate signals unparseable input by returning None; route it to the same error path.
            raise ValueError
        modified_at = timetuple_to_time(parsed_date)
    except ValueError:
        log.error("Got an invalid Last-Modified header from RackTables for %s: %s.", response.url, header_value)
        return None

    age_seconds = int(time.time() - modified_at)
    max_age_seconds = 5 * constants.HOUR_SECONDS
    if age_seconds > max_age_seconds:
        log.error("Got too old data from RackTables (%s minutes).", age_seconds // constants.MINUTE_SECONDS)

    return modified_at
