import collections
import contextlib
import concurrent.futures
import multiprocessing
import functools

from infra.rtc_sla_tentacles.backend.lib.api.utils import form_metric_name
from infra.rtc_sla_tentacles.backend.lib.api.handlers import features
from infra.rtc_sla_tentacles.backend.lib.harvesters_snapshots import manager as harvester_snapshots_manager
from infra.rtc_sla_tentacles.backend.lib.metrics import metrics_provider, checks as metrics_checks


class MetricStorage:
    """Request counters and monitoring-signal caches shared between uWSGI workers.

    Each process counts requests locally in ``stats_holder`` (see ``stats_collector``) and merges
    those counters into ``shared_dict["stat_metrics"]`` when the flush signal is delivered.
    YASM signal caches built from harvester snapshots (missing nodes, YP Lite pods manager,
    allocation zones) are refreshed by periodic uWSGI timers and stored in a
    ``multiprocessing.Manager`` dict so that all worker processes read the same data.
    Access to ``shared_dict`` is serialized with ``shared_dict_lock``.
    """

    FLUSH_LOCAL_METRICS_TO_SHARED_MEMORY_SIGNAL = 51
    # uWSGI signal numbers of the periodic cache refresh handlers registered in start_collect().
    UPDATE_PODS_MANAGER_CACHE_SIGNAL = 57
    UPDATE_MISSING_NODES_CACHE_SIGNAL = 58
    UPDATE_ALLOCATION_ZONE_CACHE_SIGNAL = 59
    # Period, in seconds, of the uWSGI timers that raise the cache refresh signals.
    CACHE_UPDATE_INTERVAL_SECONDS = 60

    def __init__(self, config_interface, mongo_client, logger):
        """Create the shared state and the harvester snapshot reader.

        :param config_interface: configuration accessor; ``get_env_name()`` and
            ``get_allocation_zones()`` are used.
        :param mongo_client: client handed to ``HarvesterSnapshotManager`` to read snapshots.
        :param logger: logger for progress and error messages.
        """
        self.config_interface = config_interface
        self.env_name = self.config_interface.get_env_name()
        self.logger = logger
        # A Manager-backed dict proxy is used so the data can be shared between processes.
        manager = multiprocessing.Manager()
        self.shared_dict = manager.dict({
            "stat_metrics": {},
            "yasm_missing_nodes_metrics": [],
            "yasm_pods_manager_metrics": [],
            "yasm_allocation_zone_metrics": [],
        })

        self.shared_dict_lock = multiprocessing.Lock()
        # Per-process request counters; merged into shared_dict by _put_to_shared_stats().
        self.stats_holder = collections.Counter()
        self.stats_holder[form_metric_name("total")] = 0
        self.harvester_snapshot_manager = harvester_snapshots_manager.HarvesterSnapshotManager(mongo_client)

    def stats_collector(self, metric):
        """Return a decorator that counts calls (and failed calls) of the wrapped function.

        Every call increments ``form_metric_name(metric, "total")`` and the global
        ``form_metric_name("total")`` counter; a call that raises also increments
        ``form_metric_name(metric, "errors")`` before the exception is re-raised.
        Both per-metric counters are (re)initialised to 0 when this method is called.
        """
        errors_key = form_metric_name(metric, "errors")
        total_key = form_metric_name(metric, "total")
        global_total_key = form_metric_name("total")
        self.stats_holder[errors_key] = 0
        self.stats_holder[total_key] = 0

        def decorator(f):

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                self.stats_holder[total_key] += 1
                self.stats_holder[global_total_key] += 1
                try:
                    return f(*args, **kwargs)
                except Exception:
                    self.stats_holder[errors_key] += 1
                    raise

            return wrapper

        return decorator

    def _update_missing_nodes_yasm_cache(self, _worker_id):
        """Copy YASM signals from the latest "missing_nodes_monitoring" snapshot into the shared cache.

        Nothing is updated when the snapshot data is empty.

        :param _worker_id: the positional argument uWSGI passes to signal handlers; unused.
        """
        self.logger.info("Updating missing nodes cache")
        # TODO(rocco66): TENTACLES-327
        if missing_nodes_metrics := self.harvester_snapshot_manager.read_last_snapshot_data("missing_nodes_monitoring"):
            with self.shared_dict_lock:
                self.shared_dict["yasm_missing_nodes_metrics"] = missing_nodes_metrics["yasm_missing_nodes_metrics"]

    def _update_yp_lite_pods_manager_monitoring_yasm_cache(self, _worker_id):
        """Copy YASM signals from the latest "yp_lite_pods_monitoring" snapshot into the shared cache.

        Nothing is updated when the snapshot data is empty; a snapshot without the
        ``yasm_pods_manager_metrics`` key clears the cache to an empty list.

        :param _worker_id: the positional argument uWSGI passes to signal handlers; unused.
        """
        self.logger.info("Updating YP Lite pods manager monitoring cache")
        # TODO(rocco66): TENTACLES-327
        if pods_metrics := self.harvester_snapshot_manager.read_last_snapshot_data("yp_lite_pods_monitoring"):
            with self.shared_dict_lock:
                self.shared_dict["yasm_pods_manager_metrics"] = pods_metrics.get("yasm_pods_manager_metrics", [])

    def _save_allocation_zone_metrics_cache(self, all_zones_metrics):
        """Flatten per-zone YASM signals and store them in the shared cache.

        Each signal is stored as ``["az__<zone_id>__<signal_name>", value]``.

        :param all_zones_metrics: mapping of allocation zone id to an object whose
            ``get_yasm_signals()`` yields ``(signal_name, value)`` pairs.
        """
        self.logger.info("Updating allocation zones metrics cache")
        yasm_prefix = "az"
        yasm_signals = []

        for allocation_zone_id, zone_metrics in all_zones_metrics.items():
            for signal_name, signal_value in zone_metrics.get_yasm_signals():
                yasm_signals.append([f"{yasm_prefix}__{allocation_zone_id}__{signal_name}", signal_value])

        # Replace the whole top-level value under the lock: in-place changes to objects nested in a
        # Manager dict proxy are not propagated, only assignments to its keys are.
        with self.shared_dict_lock:
            self.shared_dict["yasm_allocation_zone_metrics"] = yasm_signals

    def _fetch_additional_daemonset_features(self, allocation_zone_id, all_zones_metrics):
        """Add unused-nodes and scheduling-errors data to one zone's serialized metrics, in place.

        Reads the latest "yp_unused_nodes_monitoring" and "yp_lite_allocation" snapshots for
        ``allocation_zone_id``. A missing snapshot is logged and skipped; empty data is ignored.

        :raises KeyError: if ``allocation_zone_id`` is not a key of ``all_zones_metrics`` or a
            snapshot lacks one of the fields read below.
        """
        allocation_zone_metrics = all_zones_metrics[allocation_zone_id]
        try:
            unused_nodes_snapshot = self.harvester_snapshot_manager.read_last_snapshot_data(
                "yp_unused_nodes_monitoring", allocation_zone_id,
            )
        except harvester_snapshots_manager.SnapshotNotFoundError:
            self.logger.error(f"Can't find yp_unused_nodes_monitoring/{allocation_zone_id} snapshot")
        else:
            if unused_nodes_snapshot:
                allocation_zone_metrics[metrics_checks.SloType.unused_nodes] = unused_nodes_snapshot["metric"]

        try:
            scheduling_errors_snapshot = self.harvester_snapshot_manager.read_last_snapshot_data(
                "yp_lite_allocation", allocation_zone_id,
                projection={
                    "data.nodes_up_scheduling_errors": True,
                    "data.nodes_up_count": True,
                    "data.monitoring_min_percent_of_scheduling_errors": True,
                    "_id": False,
                },
            )
        except harvester_snapshots_manager.SnapshotNotFoundError:
            self.logger.error(f"Can't find yp_lite_allocation/{allocation_zone_id} snapshot")
        else:
            if scheduling_errors_snapshot:
                allocation_zone_metrics[metrics_checks.SloType.scheduling_errors] = {
                    "scheduling_errors_on_up_nodes_count": scheduling_errors_snapshot["nodes_up_scheduling_errors"],
                    "total_up_nodes_count": scheduling_errors_snapshot["nodes_up_count"],
                    "monitoring_min_percent_of_scheduling_errors":
                        scheduling_errors_snapshot["monitoring_min_percent_of_scheduling_errors"],
                }

    def fetch_all_zones_metrics(self, allocation_zone_id=None) -> features.AllocationZonesMetrics:
        """Load the latest "metrics_calculation" snapshot and enrich daemonset zones with extra features.

        :param allocation_zone_id: if given, only this zone is enriched (and only when it is a YP
            daemonset zone); otherwise every daemonset zone is. Metrics of all zones present in
            the snapshot are returned in either case.
        :returns: ``features.AllocationZonesMetrics`` built from the per-zone metrics and the
            snapshot label timestamp.
        """
        metrics_calculation_snapshot = self.harvester_snapshot_manager.read_last_snapshot("metrics_calculation")
        all_zones_metrics = metrics_calculation_snapshot.data

        daemonset_allocation_zones = self.config_interface.get_allocation_zones(yp_daemonsets_only=True)
        if allocation_zone_id:
            if allocation_zone_id in daemonset_allocation_zones:
                target_allocation_zones = [allocation_zone_id]
            else:
                target_allocation_zones = []
        else:
            target_allocation_zones = daemonset_allocation_zones

        if target_allocation_zones:
            # Each task mutates only its own zone's entry of all_zones_metrics.
            with concurrent.futures.ThreadPoolExecutor(max_workers=len(target_allocation_zones)) as executor:
                futures = {
                    executor.submit(self._fetch_additional_daemonset_features, zone_id, all_zones_metrics): zone_id
                    for zone_id in target_allocation_zones
                }
                # Enrichment is best-effort, but failures must not vanish silently inside the futures.
                for future in concurrent.futures.as_completed(futures):
                    try:
                        future.result()
                    except Exception as e:
                        self.logger.error("Error fetching additional daemonset features for %s: %s",
                                          futures[future],
                                          e,
                                          exc_info=True)

        return features.AllocationZonesMetrics(
            {
                zone_id: metrics_provider.AllZoneMetrics.from_dict(serialized_zone_metrics)
                for zone_id, serialized_zone_metrics in all_zones_metrics.items()
            },
            metrics_calculation_snapshot.label.ts,
        )

    def _update_allocation_zone_metrics_yasm_cache(self, _worker_id):
        """Recompute all zones' metrics and store their YASM signals in the shared cache.

        :param _worker_id: the positional argument uWSGI passes to signal handlers; unused.
        """
        self._save_allocation_zone_metrics_cache(self.fetch_all_zones_metrics().metrics)

    def _put_to_shared_stats(self, _id):
        """Add this process's local counters to ``shared_dict["stat_metrics"]``.

        Handler of FLUSH_LOCAL_METRICS_TO_SHARED_MEMORY_SIGNAL. ``stats_holder`` is not reset here,
        so each flush contributes the process's totals accumulated so far; get_shared_stats()
        zeroes the shared values before requesting the next flush.

        :param _id: the positional argument uWSGI passes to signal handlers; unused.
        """
        with self.shared_dict_lock:
            # Work on a copy and reassign: nested updates inside a Manager dict are not propagated.
            stat_metrics = self.shared_dict["stat_metrics"].copy()
            for metric, count in self.stats_holder.items():
                stat_metrics[metric] = stat_metrics.get(metric, 0) + count
            self.shared_dict["stat_metrics"] = stat_metrics

    def start_collect(self):
        """Register uWSGI signal handlers and cache refresh timers, and warm the shared caches.

        Must run inside uWSGI: it imports the ``uwsgi`` module.
        """
        import uwsgi
        uwsgi.register_signal(self.FLUSH_LOCAL_METRICS_TO_SHARED_MEMORY_SIGNAL, "workers", self._put_to_shared_stats)

        # (signal number, cache update handler, log template used if the startup warm-up fails)
        cache_updates = (
            (self.UPDATE_PODS_MANAGER_CACHE_SIGNAL,
             self._update_yp_lite_pods_manager_monitoring_yasm_cache,
             "Error updating YP Lite pods manager monitoring cache at launch: %s"),
            (self.UPDATE_MISSING_NODES_CACHE_SIGNAL,
             self._update_missing_nodes_yasm_cache,
             "Error updating missing nodes cache at launch: %s"),
            (self.UPDATE_ALLOCATION_ZONE_CACHE_SIGNAL,
             self._update_allocation_zone_metrics_yasm_cache,
             "Error updating allocation zone metrics cache at launch: %s"),
        )
        # The "worker" target hands each refresh signal to one worker; the result is shared via shared_dict.
        for signal_number, update_cache, _ in cache_updates:
            uwsgi.register_signal(signal_number, "worker", update_cache)

        # Fill the shared caches once at startup instead of waiting for the first timer tick.
        # Only where uwsgi.worker_id() == 0 (presumably the master / pre-fork context — confirm),
        # so the warm-up is not repeated in every worker.
        if uwsgi.worker_id() == 0:
            for _, update_cache, error_message in cache_updates:
                try:
                    update_cache(0)
                except Exception as e:
                    self.logger.error(error_message,
                                      e,
                                      exc_info=True)
            self.logger.info("Done reading monitoring data to cache")

        for signal_number, _, _ in cache_updates:
            uwsgi.add_timer(signal_number, self.CACHE_UPDATE_INTERVAL_SECONDS)

    def get_shared_stats(self):
        """Return the shared counters as ``[name, value]`` pairs and reset them to 0.

        After reading, all workers are signalled to flush their local counters, so the values
        returned by the next call reflect that flush.
        """
        result = []
        import uwsgi  # type: ignore
        with self.shared_dict_lock:
            stat_metrics = self.shared_dict["stat_metrics"].copy()
            for metric, count in stat_metrics.items():
                result.append([metric, count])  # stats for previous request
                stat_metrics[metric] = 0
            self.shared_dict["stat_metrics"] = stat_metrics

        # stats for current state will be collected in _put_to_shared_stats() for next request
        uwsgi.signal(self.FLUSH_LOCAL_METRICS_TO_SHARED_MEMORY_SIGNAL)
        return result

    def get_local_stats(self):
        """Return this process's counters as ``[name, value]`` pairs."""
        return [[metric, count] for metric, count in self.stats_holder.items()]

    @contextlib.contextmanager
    def shared_metrics(self):
        """Yield ``shared_dict`` while holding ``shared_dict_lock``."""
        with self.shared_dict_lock:
            yield self.shared_dict
