"""Pull connectivity data from netmon."""

import itertools
import json
import logging
from collections import defaultdict

import attr
import xxhash
from gevent.event import Event
from gevent.pool import Group, Pool

from sepelib.core import config
from sepelib.core.exceptions import Error
from walle.clients import netmon
from walle.expert.constants import NETMON_POLLING_PERIOD, NETMON_CONNECTIVITY_LEVEL, NETMON_CONNECTIVITY_STALE_TIMEOUT
from walle.expert.juggler import update_check_status_mtime, get_effective_timestamp, get_stale_timestamp
from walle.expert.types import CheckType, CheckStatus
from walle.host_health import HealthCheck
from walle.hosts import Host, HostState, HostStatus, HostLocation
from walle.models import timestamp
from walle.stats import stats_manager as stats, Timing, DISTRIBUTED
from walle.util.apscheduler import JitterIntervalTrigger
from walle.util.gevent_tools import gevent_idle_iter, gevent_idle_generator
from walle.util.misc import StopWatch, drop_none, add_interval_job, iter_chunks
from walle.util.mongo import MongoDocument, MongoPartitionerService

# Module logger for polling progress (debug) and polling failures.
log = logging.getLogger(__name__)

# Connectivity levels: 50% loss, 30% loss, 10% loss, 0% loss.
# The value of a level is the share of probes that lost this share of packets or less.
# E.g. a switch has 10 hosts: 8 have icmp (ping) probes with 0% packet loss, 1 has 10% loss, and 1 has 40% loss.
# Such a switch will have _CONN_50 = 1, _CONN_70 = 0.9, _CONN_90 = 0.9, _CONN_100 = 0.8
# (_CONN_N names the level that tolerates (100 - N)% loss; the 10%-loss probe still counts for _CONN_90).
# The levels differ in strictness: _CONN_100 is the strictest connectivity level.

# Level names: used as the top-level keys of each check's metadata and as the level label
# in human-readable status reasons.
_LEVEL_SWITCH = "switch"
_LEVEL_QUEUE = "queue"
_LEVEL_DATACENTER = "datacenter"

# Number of host names sent in one query when loading the previous NETMON checks.
_GET_OLD_CHECKS_CHUNK_SIZE = 50


class NetmonInconsistentDataError(Error):
    """Raised when netmon data lacks an item that other parts of the same data refer to,
    e.g. connectivity data for the queue or datacenter of a switch."""


class NetmonDataStaledException(Error):
    """Base class for errors about netmon data older than the allowed staleness timeout."""


class ConnectivityDataStaledException(NetmonDataStaledException):
    """Raised when netmon's alive/connectivity data is older than the staleness timeout."""

    def __init__(self, data_age):
        """:param data_age: age of the received data, in seconds."""
        message_template = "Wall-E has received staled ({}sec) connectivity data from netmon"
        super().__init__(message_template, data_age)


class SeenHostsDataStaledException(NetmonDataStaledException):
    """Raised when netmon's seen-hosts data is older than the staleness timeout."""

    def __init__(self, data_age):
        """:param data_age: age of the received data, in seconds."""
        message_template = "Wall-E has received staled ({}sec) seen hosts data from netmon"
        super().__init__(message_template, data_age)


def _check_default_threshold_exist(instance, attribute, value):
    if "default" not in value:
        raise ValueError("Attribute {} must be has key 'default'".format(attribute))


@attr.s
class SeenHostsThreshold:
    """Per-level minimum share of seen hosts, in percent (compared with ``seen * 100 / total``)."""

    # Maps queue name -> threshold; the validator requires a "default" entry, which is used
    # for queues without an explicit threshold.
    queue: dict[str, float] = attr.ib(validator=_check_default_threshold_exist)
    # Same as ``queue``, keyed by datacenter name.
    datacenter: dict[str, float] = attr.ib(validator=_check_default_threshold_exist)


@attr.s
class ConnectivityThreshold:
    """Minimum connectivity value for a level to be PASSED (lower values are FAILED).

    The value compared is the entry at ``alive_connectivity_index`` of a node's ``connectivity`` list.
    """

    switch: float = attr.ib()
    queue: float = attr.ib()
    datacenter: float = attr.ib()


@attr.s
class NetmonConfig:
    """Configuration of one connectivity data endpoint (netmon or a nocsla location)."""

    # Service name: passed to the client and used as the prefix of timing/stats keys
    # ("netmon" or "<nocsla type>-<location>").
    service: str = attr.ib()
    # Endpoint host passed to ``netmon.NetmonClient``.
    host: str = attr.ib()

    # Query parameters passed through to ``netmon.NetmonClient`` as-is.
    expression: str = attr.ib()
    network: str = attr.ib()
    protocol: str = attr.ib()

    # Seen-hosts percentage thresholds for queues and datacenters.
    seen_hosts_threshold: SeenHostsThreshold = attr.ib()
    # Levels with fewer total hosts than this always pass the seen-hosts check.
    total_hosts_cut_off: int = attr.ib()
    # Connectivity thresholds for switches, queues and datacenters.
    alive_threshold: ConnectivityThreshold = attr.ib()
    # Index into each node's ``connectivity`` list selecting the level to compare with the thresholds.
    alive_connectivity_index: int = attr.ib()


def _read_common_params(config_part):
    """Build the ``NetmonConfig`` keyword arguments shared by all endpoints of one config section.

    ``config_part`` must provide ``expression``, ``network``, ``protocol``,
    ``seen_hosts.cutoff``, ``seen_hosts.threshold.{queue,datacenter}``,
    ``alive.connectivity_index`` and ``alive.threshold.{switch,queue,datacenter}``.
    Only ``service`` and ``host`` are left for the caller to supply.
    """
    params = {
        "expression": config_part["expression"],
        "network": config_part["network"],
        "protocol": config_part["protocol"],
        "total_hosts_cut_off": int(config_part["seen_hosts"]["cutoff"]),
        "alive_connectivity_index": int(config_part["alive"]["connectivity_index"]),
    }

    alive_limits = config_part["alive"]["threshold"]
    params["alive_threshold"] = ConnectivityThreshold(
        switch=float(alive_limits["switch"]),
        queue=float(alive_limits["queue"]),
        datacenter=float(alive_limits["datacenter"]),
    )

    seen_limits = config_part["seen_hosts"]["threshold"]
    params["seen_hosts_threshold"] = SeenHostsThreshold(
        queue=seen_limits["queue"],
        datacenter=seen_limits["datacenter"],
    )
    return params


def make_netmon_config() -> NetmonConfig:
    """Build the configuration of the classic netmon endpoint (service name ``"netmon"``).

    The host is read from ``switch_connectivity.netmon.host``; the remaining parameters come from
    the ``switch_connectivity.netmon`` section.
    """
    netmon_host = config.get_value("switch_connectivity.netmon.host")
    netmon_section = config.get_value("switch_connectivity.netmon")
    return NetmonConfig(service="netmon", host=netmon_host, **_read_common_params(netmon_section))


def make_nocsla_configs() -> list[NetmonConfig]:
    """Build one ``NetmonConfig`` per endpoint of every configured nocsla section.

    Sections ``switch_connectivity.nocsla`` and ``switch_connectivity.nocsla-cloud`` are optional.
    Each endpoint gets the service name ``"<section>-<location>"`` and its own host; all other
    parameters are shared within a section.
    """
    result: list[NetmonConfig] = []
    for section_name in ("nocsla", "nocsla-cloud"):
        section = config.get_value(f"switch_connectivity.{section_name}", None)
        if not section:
            continue

        shared_params = _read_common_params(section)
        endpoints = config.get_value(f"switch_connectivity.{section_name}.endpoints")
        result.extend(
            NetmonConfig(service=f"{section_name}-{location}", host=endpoint["host"], **shared_params)
            for location, endpoint in endpoints.items()
        )

    return result


class Netmon:
    """Schedule periodic, sharded polling of switch connectivity data.

    One interval job is registered per shard; a job only does work when the partitioner returns
    that shard to this process. Data is taken from every configured nocsla endpoint, with a single
    netmon poll as the fallback when the nocsla round raises.
    """

    def __init__(self) -> None:
        self.__shards_num = config.get_value("switch_connectivity.shards_num")
        self.__partitioner = MongoPartitionerService("cron-netmon")
        # Greenlets of running polling jobs; killed in stop().
        self.__pool = Pool()
        # Set by stop() so that later scheduler ticks no longer spawn jobs.
        self.__stopped_event = Event()

    def start(self, scheduler) -> None:
        """Start the partitioner and register one polling job per shard on ``scheduler``."""
        self.__partitioner.start()

        def run_poll(conf: NetmonConfig, current_shard, shards_num):
            # One complete fetch-and-store round against a single endpoint configuration.
            poller = NetmonPoller(conf, current_shard, shards_num)
            poller.poll()

        def _make_job(job_shard_id):
            # Factory so that every scheduled callable is bound to its own shard id.
            def _job():
                shard = self.__partitioner.get_shard(job_shard_id)
                if not shard:
                    log.debug("Not fetching connectivity data from #%s shard: not my shard.", job_shard_id)
                    return

                try:
                    pool = Group()
                    nocsla_configs = make_nocsla_configs()
                    # NOTE(review): when no nocsla section is configured the list is empty, nothing is
                    # polled, and the netmon fallback below is not reached either - confirm intended.
                    with shard.lock:
                        # All nocsla endpoints are polled concurrently while the shard lock is held;
                        # join(raise_error=True) re-raises a poller's error so the fallback runs.
                        for nocsla_config in nocsla_configs:
                            pool.spawn(run_poll, nocsla_config, job_shard_id, self.__shards_num)
                        pool.join(raise_error=True)
                except Exception as e:
                    # TODO: send event? Or use error booster only...
                    # NOTE(review): endpoints that succeeded before the failure have already saved
                    # their checks; the netmon round below then writes checks for the same shard.
                    log.exception("Failed to update switch connectivity from nocsla: %s", e)

                    try:
                        netmon_config = make_netmon_config()
                        with shard.lock:
                            run_poll(netmon_config, job_shard_id, self.__shards_num)
                    except Exception as e:
                        log.exception("Failed to update switch connectivity from netmon: %s", e)

            def _spawn_job():
                # Scheduler entry point: run the job in the pool unless stop() has been called.
                if self.__stopped_event.is_set():
                    return
                self.__pool.spawn(_job)

            return _spawn_job

        for shard_id in range(self.__shards_num):
            add_interval_job(
                scheduler,
                _make_job(shard_id),
                name="Netmon poller shard #{}".format(shard_id),
                interval=NETMON_POLLING_PERIOD,
                trigger=JitterIntervalTrigger(seconds=NETMON_POLLING_PERIOD, jitter=20),
            )

    def stop(self):
        """Stop spawning new jobs, kill running ones, then stop the partitioner."""
        self.__stopped_event.set()
        self.__pool.kill()
        self.__partitioner.stop()


class NetmonPoller:
    """Fetch data from one netmon/nocsla endpoint and store NETMON checks for one shard of switches."""

    def __init__(self, conf: NetmonConfig, current_shard, shards_num):
        """
        :param conf: endpoint configuration (service name, host, query parameters, thresholds).
        :param current_shard: index of the switch shard handled by this poller.
        :param shards_num: total number of switch shards.
        """
        self.config = conf
        self.current_shard = current_shard
        self.shards_num = shards_num
        self.client = netmon.NetmonClient(
            self.config.service, self.config.host, self.config.expression, self.config.network, self.config.protocol
        )

    def poll(self):
        """Run one polling round for ``self.current_shard``.

        Fetches alive metrics and seen-hosts data concurrently, builds one check per switch of this
        shard, loads the hosts of those switches and their previous NETMON checks, and bulk-saves
        the new checks.

        :raises NetmonDataStaledException: if the fetched data is older than the staleness timeout.
        Errors raised by the fetch greenlets are re-raised by ``pool.join(raise_error=True)``.
        """
        log.debug("Fetching switch connectivity data for shard %s...", self.current_shard)
        timing = Timing(self.config.service, StopWatch())

        pool = Group()
        fetched_data = {}

        # Both requests run concurrently; each one is timed separately.
        @pool.spawn
        @timing.measure("fetch-alive-data")
        def _get_alive_metrics():
            fetched_data["alive_metrics"] = self.client.get_alive_metrics(NETMON_CONNECTIVITY_LEVEL)

        @pool.spawn
        @timing.measure("fetch-seen-hosts")
        def _get_seen_hosts():
            fetched_data["seen_hosts_data"] = self.client.get_seen_hosts()

        pool.join(raise_error=True)
        alive_metrics = fetched_data["alive_metrics"]
        seen_hosts_data = fetched_data["seen_hosts_data"]

        netmon_switch_checks = self._get_netmon_switch_checks(alive_metrics, seen_hosts_data)

        with timing.measure("fetch-hosts"):
            host_names, hosts_by_switch = _get_hosts(list(netmon_switch_checks))

        old_checks = {}
        with timing.measure("fetch-old-checks"):
            # Previous checks are loaded in chunks to keep every ``$in`` host list bounded.
            for hosts_batch in iter_chunks(host_names, _GET_OLD_CHECKS_CHUNK_SIZE):
                old_checks.update(_get_old_checks(hosts_batch))

        with timing.measure("save-checks"):
            _save_netmon_data(netmon_switch_checks, alive_metrics["generated"], hosts_by_switch, old_checks)

        timing.split("db_store_time")
        timing.reset("collection_time")
        log.debug("Connectivity data fetched and stored for shard %s.", self.current_shard)

    def _get_netmon_switch_checks(self, alive_metrics, seen_hosts_data):
        """Validate data age and return a ``{switch_name: check}`` dict for this shard's switches."""
        self._check_data_age(alive_metrics, seen_hosts_data)
        shard_switches = self._filter_switches(alive_metrics)
        checks_generator = self.netmon_checks_generator(shard_switches, alive_metrics, seen_hosts_data)
        return dict(self._count_switch_stats(checks_generator))

    def _check_data_age(self, connectivity_data, seen_hosts_data):
        """Report data ages and reject data older than ``NETMON_CONNECTIVITY_STALE_TIMEOUT``.

        On stale data every switch-status counter is reset to 0 before the exception propagates.

        :raises ConnectivityDataStaledException: if the alive data is too old.
        :raises SeenHostsDataStaledException: if the seen-hosts data is too old.
        """
        current_timestamp = timestamp()

        connectivity_age = current_timestamp - connectivity_data["generated"]
        volume_data_age = current_timestamp - seen_hosts_data["timestamp"]

        stats.add_sample((self.config.service, "connectivity_age"), connectivity_age)
        stats.add_sample((self.config.service, "volume_data_age"), volume_data_age)

        try:
            if connectivity_age > NETMON_CONNECTIVITY_STALE_TIMEOUT:
                raise ConnectivityDataStaledException(connectivity_age)

            if volume_data_age > NETMON_CONNECTIVITY_STALE_TIMEOUT:
                raise SeenHostsDataStaledException(volume_data_age)
        except NetmonDataStaledException:
            for status in CheckStatus.ALL:
                stats.set_counter_value((self.config.service, "switch_status", status), 0, DISTRIBUTED)
            raise

    def _filter_switches(self, alive_data):
        """Yield the switches of ``alive_data["switches"]`` that belong to this poller's shard.

        A switch's shard is the xxh64 digest of its name modulo ``shards_num``; a ``None`` name is
        hashed as empty bytes.
        """

        def _shard_for_switch_name(name, num=self.shards_num):
            return xxhash.xxh64_intdigest(name.encode() if name is not None else b'') % num

        for switch in alive_data["switches"]:
            if _shard_for_switch_name(switch["name"]) == self.current_shard:
                yield switch

    @gevent_idle_generator()
    def _count_switch_stats(self, netmon_checks_generator):
        """Pass ``(switch_name, check)`` pairs through and publish per-status switch counters at the end.

        Every status in ``CheckStatus.ALL`` is published, so a status that no longer occurs is reset
        to 0 rather than keeping the count from an earlier round (this matches how
        ``_check_data_age`` resets the same counters on stale data).
        """
        counts = defaultdict(int, dict.fromkeys(CheckStatus.ALL, 0))
        for switch_name, check in netmon_checks_generator:
            yield switch_name, check

            counts[check["status"]] += 1

        for status, value in counts.items():
            stats.set_counter_value((self.config.service, "switch_status", status), value, DISTRIBUTED)

    def netmon_checks_generator(self, shard_switches, connectivity_data, volume_data):
        """Yield ``(switch_name, check)`` for every switch in ``shard_switches`` using this poller's config."""
        check_generator = make_check_generator(self.config, connectivity_data, volume_data)

        for switch in shard_switches:
            yield switch["name"], check_generator.make_check(switch)


def _save_netmon_data(netmon_switch_checks, connectivity_timestamp, hosts_by_switch, old_checks):
    """Bulk-save each switch's check for all hosts attached to that switch.

    Consumes ``hosts_by_switch`` (entries are popped per switch); switches without hosts save nothing.
    """
    with HealthCheck.bulk_update() as savers:
        for switch_name in netmon_switch_checks:
            check = netmon_switch_checks[switch_name]
            switch_hosts = hosts_by_switch.pop(switch_name, [])
            _save_check(savers, check, switch_hosts, old_checks, connectivity_timestamp)


def _save_check(bulk, switch_check, hosts, old_checks, connectivity_timestamp):
    for host in hosts:
        prev_state = old_checks.pop(host.name, None)
        host_check = switch_check.copy()
        if prev_state:
            update_check_status_mtime(host_check, prev_state)
        host_check["fqdn"] = host.name

        changed = (
            not prev_state
            or prev_state["status"] != host_check["status"]
            or prev_state["status_mtime"] != host_check["status_mtime"]
        )
        bulk.save_checks(host.name, [(host_check, changed)], connectivity_timestamp)


def _get_old_checks(host_names):
    """Load the stored NETMON checks of ``host_names``.

    Only the fqdn, type, status, status_mtime and timestamp fields are fetched.

    :return: dict mapping fqdn to the raw check document.
    """
    collection = HealthCheck.get_collection()

    query = {
        HealthCheck.type.db_field: CheckType.NETMON,
        HealthCheck.fqdn.db_field: {"$in": host_names},
    }
    projection = (
        HealthCheck.fqdn.db_field,
        HealthCheck.type.db_field,
        HealthCheck.status.db_field,
        HealthCheck.status_mtime.db_field,
        HealthCheck.timestamp.db_field,
    )

    cursor = collection.find(query, projection)
    return {document["fqdn"]: document for document in gevent_idle_iter(cursor)}


def _get_hosts(switch_names):
    host_collection = Host.get_collection()
    host_document = MongoDocument.for_model(Host)
    switch_field = Host.location.db_field + "." + HostLocation.switch.db_field

    query = {
        Host.state.db_field: {"$in": HostState.ALL_ASSIGNED},
        Host.status.db_field: {"$ne": HostStatus.INVALID},
    }

    if switch_names:
        query[switch_field] = {"$in": switch_names}
    else:
        query[switch_field] = {"exists": True}

    host_docs = [
        host_document(h)
        for h in gevent_idle_iter(
            host_collection.find(
                query,
                (Host.name.db_field, Host.project.db_field, switch_field),
            )
        )
    ]

    def groupby_key(h):
        return h.location.switch

    host_docs.sort(key=groupby_key)

    hosts_by_switch = {}
    host_names = []

    for switch, hosts in itertools.groupby(host_docs, key=groupby_key):
        hosts_by_switch[switch] = list(hosts)
        host_names.extend(host.name for host in hosts_by_switch[switch] if host.name)

    return host_names, hosts_by_switch


class MetadataIndex:
    def __init__(self, converter, data):
        """
        Map item names to wall-e check metadata converted from netmon data.

        :param converter: converter that turns a netmon item into wall-e check metadata
        :param data: list of netmon items taken straight from netmon, e.g. alive_data["switches"];
            every item must have a "name" key
        :type converter: MetadataConverter
        :type data: list[dict]
        """
        self.converter = converter
        self.data = self._index_metadata(data)

    def get(self, name):
        """Return metadata for ``name``.

        Names absent from the initial data are converted once with ``node=None`` and the result
        is cached. The converter decides what missing data means: it may build MISSING metadata
        or raise.
        """
        try:
            return self.data[name]
        except KeyError:
            pass

        converted = self.converter.get_metadata(name, None)
        self.data[name] = converted
        return converted

    def _index_metadata(self, data):
        # Later items with a duplicate name overwrite earlier ones.
        convert = self.converter.get_metadata
        indexed = {}
        for item in data:
            item_name = item["name"]
            indexed[item_name] = convert(item_name, item)
        return indexed


class MetadataConverter:
    """Interface of the converters used by ``MetadataIndex``."""

    def get_metadata(self, name, node):
        """Return check metadata for the item ``name``.

        ``node`` is the raw netmon item, or ``None`` when no data was provided for ``name``.
        Subclasses must override this.
        """
        raise NotImplementedError()


class ConnectivityMetadata(MetadataConverter):
    """Convert a netmon connectivity node (switch, queue or datacenter) into check metadata.

    The node's ``connectivity`` list is indexed with ``connectivity_index``. A value at or above
    ``connectivity_threshold`` is PASSED, a lower value is FAILED, and an empty list is MISSING.
    """

    def __init__(self, level, connectivity_index, connectivity_threshold):
        self.level = level
        self.connectivity_index = connectivity_index
        self.connectivity_threshold = connectivity_threshold

    def get_metadata(self, name, node):
        """Return status/reason/alive/connectivity metadata, filtered through ``drop_none``.

        :raises NetmonInconsistentDataError: if ``node`` is ``None``.
        """
        if node is None:
            raise NetmonInconsistentDataError("Connectivity data for {} {} is missing.".format(self.level, name))

        level_connectivity = self._get_level_connectivity(node)
        status, reason = self._get_connectivity_status(name, level_connectivity)

        metadata = {
            "status": status,
            "reason": reason,
            "alive": node["alive"],
            "connectivity": level_connectivity,
        }
        return drop_none(metadata)

    def _get_level_connectivity(self, data):
        values = data["connectivity"]
        return values[self.connectivity_index] if values else None

    def _get_connectivity_status(self, name, connectivity):
        if connectivity is None:
            return CheckStatus.MISSING, self.missing_data_reason(name)

        meets_threshold = connectivity >= self.connectivity_threshold
        if not meets_threshold:
            return CheckStatus.FAILED, self.status_reason(name, connectivity, "below")
        return CheckStatus.PASSED, self.status_reason(name, connectivity, "above")

    def missing_data_reason(self, name):
        return "Connectivity data is missing for {} {}.".format(self.level, name)

    def status_reason(self, name, connectivity, result):
        return "Connectivity for {} {} ({:0.3f}) is {} threshold {}.".format(
            self.level, name, connectivity, result, self.connectivity_threshold
        )


class VolumeMetadata(MetadataConverter):
    """Convert netmon seen-hosts data for one level item into check metadata.

    An item passes when at least its threshold percentage of hosts is seen, or when it has fewer
    total hosts than the cutoff (or none at all). Otherwise it is SUSPECTED. Absent data is MISSING.
    """

    def __init__(self, level, seen_hosts_threshold, total_hosts_cutoff):
        """
        :param level: level name used in reasons (e.g. "queue").
        :param seen_hosts_threshold: mapping of item name -> minimum seen percentage; the
            "default" entry applies to names without their own threshold.
        :param total_hosts_cutoff: items with fewer total hosts than this always pass.
        """
        self.level = level
        self.seen_hosts_threshold = seen_hosts_threshold
        self.total_hosts_cut_off = total_hosts_cutoff

    def get_metadata(self, name, node):
        volume_status, volume_reason = self._get_volume_status(name, node)

        return {
            "status": volume_status,
            "reason": volume_reason,
        }

    def _get_volume_status(self, name, seen_hosts):
        if seen_hosts is None:
            return CheckStatus.MISSING, self.missing_data_reason(name)

        seen = seen_hosts["seen"]
        total = seen_hosts["total"]

        # ``total <= 0`` protects the percentage below from ZeroDivisionError when the cutoff is
        # configured as 0 or lower; an item without hosts has "too few hosts" by definition.
        if total < self.total_hosts_cut_off or total <= 0:
            return CheckStatus.PASSED, self.few_total_hosts_reason(name, total)

        seen_percent = (seen * 100.0) / total
        threshold = self._get_seen_hosts_threshold(name)

        if seen_percent >= threshold:
            return CheckStatus.PASSED, self.status_reason(name, seen, total, seen_percent, result="above")
        else:
            return CheckStatus.SUSPECTED, self.status_reason(name, seen, total, seen_percent, result="below")

    def _get_seen_hosts_threshold(self, name):
        default_threshold = self.seen_hosts_threshold["default"]
        return self.seen_hosts_threshold.get(name, default_threshold)

    def missing_data_reason(self, name):
        return "No seen hosts for {} {}.".format(self.level, name)

    def few_total_hosts_reason(self, name, total):
        return "Too few total hosts for {} {} ({}), always consider Ok.".format(self.level, name, total)

    def status_reason(self, name, seen, total, seen_percent, result):
        return (
            "{percent:0.1f}% of active hosts for {level} {name} ({seen} from {total}) which is {result}"
            " threshold {threshold}.".format(
                level=self.level,
                name=name,
                percent=seen_percent,
                seen=seen,
                total=total,
                threshold=self._get_seen_hosts_threshold(name),
                result=result,
            )
        )


class NetmonCheckGenerator:
    def __init__(
        self,
        queue_connectivity_metadata,
        datacenter_connectivity_metadata,
        queue_volume_metadata,
        datacenter_volume_metadata,
        switch_metadata_converter,
        data_timestamp,
    ):
        """
        Build wall-e NETMON checks for switches from prepared metadata indices.

        :type queue_connectivity_metadata: MetadataIndex
        :type datacenter_connectivity_metadata: MetadataIndex
        :type queue_volume_metadata: MetadataIndex
        :type datacenter_volume_metadata: MetadataIndex
        :type switch_metadata_converter: MetadataConverter
        :type data_timestamp: int
        """
        self.queue_connectivity_metadata = queue_connectivity_metadata
        self.queue_volume_metadata = queue_volume_metadata

        self.datacenter_connectivity_metadata = datacenter_connectivity_metadata
        self.datacenter_volume_metadata = datacenter_volume_metadata
        self.switch_converter = switch_metadata_converter
        self.data_timestamp = data_timestamp

    def make_check(self, switch):
        """Return the NETMON check dict for one netmon switch item."""
        data_ts = self.data_timestamp
        metadata = self.get_metadata(switch)
        status = self.get_check_status(metadata)
        # Added only after the status is computed: get_check_status() expects every value to be a level dict.
        metadata["timestamp"] = data_ts

        serialized_metadata = json.dumps(metadata)
        effective_timestamp = get_effective_timestamp(
            CheckType.NETMON,
            status,
            status_mtime=data_ts,
            receive_timestamp=data_ts,
            metadata=metadata,
        )
        stale_timestamp = get_stale_timestamp(
            CheckType.NETMON,
            status,
            receive_timestamp=data_ts,
            metadata=metadata,
        )
        return {
            "type": CheckType.NETMON,
            "status": status,
            "status_mtime": data_ts,
            "timestamp": data_ts,
            "metadata": serialized_metadata,
            "effective_timestamp": effective_timestamp,
            "stale_timestamp": stale_timestamp,
        }

    @staticmethod
    def get_check_status(metadata):
        """Pick the overall status: FAILED, then SUSPECTED, then MISSING on any level wins;
        otherwise the switch level's status is used."""
        present = {level["status"] for level in metadata.values()}
        by_priority = (CheckStatus.FAILED, CheckStatus.SUSPECTED, CheckStatus.MISSING)
        worst = next((candidate for candidate in by_priority if candidate in present), None)
        if worst is not None:
            return worst
        return metadata[_LEVEL_SWITCH]["status"]

    def get_metadata(self, switch):
        """Return per-level metadata (switch, queue, datacenter) for a netmon switch item."""
        return {
            _LEVEL_SWITCH: self._get_switch_metadata(switch),
            _LEVEL_QUEUE: self._get_queue_metadata(switch["queue"]),
            _LEVEL_DATACENTER: self._get_datacenter_metadata(switch["dc"]),
        }

    def _get_switch_metadata(self, switch):
        switch_connectivity = self.switch_converter.get_metadata(switch["name"], switch)
        return self._get_level_metadata(switch_connectivity)

    def _get_queue_metadata(self, name):
        return self._get_indexed_level_metadata(self.queue_connectivity_metadata, self.queue_volume_metadata, name)

    def _get_datacenter_metadata(self, name):
        return self._get_indexed_level_metadata(
            self.datacenter_connectivity_metadata, self.datacenter_volume_metadata, name
        )

    def _get_indexed_level_metadata(self, connectivity_index, volume_index, name):
        # Connectivity is looked up before volume, as both lookups may convert and cache.
        connectivity = connectivity_index.get(name)
        volume = volume_index.get(name)
        return self._get_level_metadata(connectivity, volume)

    @staticmethod
    def _get_level_metadata(connectivity_metadata, volume_metadata=None):
        """Combine connectivity and optional volume metadata for one level.

        Volume status only matters when connectivity PASSED.
        """
        level = {
            "status": connectivity_metadata["status"],
            "connectivity": connectivity_metadata,
        }
        if volume_metadata:
            if level["status"] == CheckStatus.PASSED:
                level["status"] = volume_metadata["status"]
            level["volume"] = volume_metadata

        return level


def netmon_checks_generator(shard_switches, connectivity_data, volume_data, conf=None):
    """Yield ``(switch_name, check)`` for every switch in ``shard_switches``.

    Module-level counterpart of ``NetmonPoller.netmon_checks_generator``. ``make_check_generator``
    requires a ``NetmonConfig`` as its first argument. When ``conf`` is not given, the classic
    netmon configuration (``make_netmon_config()``) is used, so three-argument calls work instead
    of failing with a TypeError.

    :param conf: optional ``NetmonConfig`` providing thresholds and connectivity index.
    """
    if conf is None:
        conf = make_netmon_config()
    check_generator = make_check_generator(conf, connectivity_data, volume_data)

    for switch in shard_switches:
        yield switch["name"], check_generator.make_check(switch)


def make_check_generator(conf: NetmonConfig, connectivity_data, volume_data):
    """Build a ``NetmonCheckGenerator`` for one round of netmon data.

    :param conf: endpoint configuration supplying thresholds, cutoff and connectivity index.
    :param connectivity_data: alive metrics. Its "queues" and "datacenters" lists are indexed by
        name, and "generated" becomes the checks' timestamp.
    :param volume_data: seen-hosts data. Its "queue" and "dc" lists are indexed by name.
    """
    connectivity_index = conf.alive_connectivity_index
    alive_limits = conf.alive_threshold
    seen_limits = conf.seen_hosts_threshold
    cutoff = conf.total_hosts_cut_off

    queue_connectivity = MetadataIndex(
        ConnectivityMetadata(_LEVEL_QUEUE, connectivity_index, alive_limits.queue),
        connectivity_data["queues"],
    )
    datacenter_connectivity = MetadataIndex(
        ConnectivityMetadata(_LEVEL_DATACENTER, connectivity_index, alive_limits.datacenter),
        connectivity_data["datacenters"],
    )
    queue_volume = MetadataIndex(
        VolumeMetadata(_LEVEL_QUEUE, seen_limits.queue, cutoff),
        volume_data["queue"],
    )
    datacenter_volume = MetadataIndex(
        VolumeMetadata(_LEVEL_DATACENTER, seen_limits.datacenter, cutoff),
        volume_data["dc"],
    )

    return NetmonCheckGenerator(
        queue_connectivity_metadata=queue_connectivity,
        datacenter_connectivity_metadata=datacenter_connectivity,
        queue_volume_metadata=queue_volume,
        datacenter_volume_metadata=datacenter_volume,
        switch_metadata_converter=ConnectivityMetadata(_LEVEL_SWITCH, connectivity_index, alive_limits.switch),
        data_timestamp=connectivity_data["generated"],
    )
