"""Listens to Expert System events: periodically screens host health from Juggler, shard by shard, and stores DMC decisions on hosts."""

import logging
import typing as tp

import mongoengine
from gevent import Greenlet
from gevent.event import Event
from gevent.pool import Pool

from sepelib.core import config
from sepelib.core.exceptions import Error
from walle.expert import constants as expert_constants, dmc, juggler
from walle.expert.decisionmakers import load_decision_makers
from walle.expert.types import WalleAction, CheckStatus
from walle.host_shard import HostShard, HostShardType
from walle.hosts import Host, HostState, HostStatus, DecisionStatus, HostMessage, get_raw_query_for_dmc_rules
from walle.models import timestamp
from walle.stats import stats_manager as stats, IntegerLinearHistogram
from walle.util import mongo
from walle.util.cloud_tools import get_tier
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.misc import (
    add_interval_job,
    parallelize_processing,
    StopWatch,
    fix_mongo_batch_update_dict,
    get_expert_shards_count,
)
from walle.util.mongo.bulk_group import SingleCollectionBulkGroup

# Module-level logger.
log = logging.getLogger(name=__name__)

# Scheduler job name; start() formats it with the current tier.
SCREENING_NAME_TEMPLATE = "DMC screening watcher tier-{}"

# Screening instance created by start(); stays None until start() is called.
_SCREENING = None


def start(scheduler, partitioner):
    """Create the module-level Screening instance and register its periodic job on ``scheduler``.

    Each job run performs one screening iteration. It processes up to ``expert_system.parallel_shards``
    shards concurrently (config value, default 1) and records the iteration duration.
    Any exception raised inside the job is logged and re-raised.
    """
    global _SCREENING

    shards_in_parallel = config.get_value("expert_system.parallel_shards", default=1)
    _SCREENING = Screening()

    def job(concurrency=shards_in_parallel):
        try:
            iteration_timer = StopWatch()
            _SCREENING.run(concurrency, partitioner)
            elapsed = iteration_timer.get()
            stats.add_sample(("dmc", "screening", "iteration_time"), elapsed)
        except Exception:
            log.exception("Screening uncaught error")
            raise

    add_interval_job(
        scheduler,
        job,
        name=SCREENING_NAME_TEMPLATE.format(get_tier()),
        interval=expert_constants.HEALTH_SCREENING_PERIOD,
    )


def stop():
    """Stop the Screening instance created by start(), if start() has been called."""
    screening = _SCREENING
    if not screening:
        return
    screening.stop()


class ScreeningShardProcessor:
    """Screens the hosts of one shard: fetches their Juggler health and stores DMC decisions for them."""

    def __init__(self, total_shards_count) -> None:
        self._stopped_event = Event()
        self._total_shards_count = total_shards_count
        self._host_updater = HostBulkUpdater()
        # NOTE(review): not assigned anywhere in this module; kept for external users.
        self.greenlet: tp.Optional[Greenlet] = None

    def stop(self):
        """Prevent further processing and stop this processor's bulk host updater."""
        self._stopped_event.set()
        self._host_updater.stop()

    def process_shard(self, shard: mongo.MongoPartitionerShard):
        """Screen ``shard`` while holding its lock.

        Any exception raised while screening is logged and not propagated.
        """
        if self._stopped_event.is_set():
            return

        with shard.lock:
            try:
                self._receive_health(shard)
            except Exception as e:
                log.exception("Failed to update host health status from Juggler for #%s shard: %s", shard, e)

    def _receive_health(self, shard: mongo.MongoPartitionerShard):
        """Fetch Juggler health for the shard's hosts and store the resulting health and decisions.

        - An empty shard is marked as processed immediately.
        - If Juggler fails with ``Error``, health is unset for the shard's assigned hosts.
          The shard is then NOT marked as processed.
        - Otherwise health and decisions are stored, the shard is marked as processed,
          and timing stats are recorded.
        """
        if self._stopped_event.is_set():
            return
        log.info("Screening host health statuses for #%s shard...", shard)
        main_stopwatch = StopWatch()

        shard_query = mongo.get_host_mongo_shard_query(shard, self._total_shards_count)
        hostname_to_decision_maker, hostname_to_host = _fetch_shard_hosts(shard_query)

        shard_hosts_num = len(hostname_to_host)
        stats.add_sample(("dmc", "fetcher", "shard_size"), shard_hosts_num)
        if shard_hosts_num == 0:
            log.info("Shard #%s is empty.", shard)
            HostShard.processed(shard.id, HostShardType.screening)
            return

        log.debug("Receiving health for %s hosts in #%s shard...", shard_hosts_num, shard)

        try:
            host_health = juggler.get_health_for_hosts(hostname_to_decision_maker)
        except Error as e:
            log.exception("Failed to update host health statuses for shard #%s from Juggler: %s", shard, e)
            # No fresh health data: drop the stored health of the shard's assigned hosts.
            hosts = Host.objects(
                mongoengine.Q(name__in=list(hostname_to_host.keys()), state__in=HostState.ALL_ASSIGNED) & shard_query
            )
            for host in gevent_idle_iter(hosts):
                host.update(unset__health=True)
            return

        log.debug("Updating host health statuses from Juggler for #%s shard...", shard)
        self._process_health_data(host_health, hostname_to_host, hostname_to_decision_maker)
        HostShard.processed(shard.id, HostShardType.screening)

        elapsed = main_stopwatch.get()
        stats.add_sample(("dmc", "fetcher", "host_count"), shard_hosts_num)
        stats.add_sample(("dmc", "fetcher", "shard_time"), elapsed)
        # True division: per-host time is normally a fraction of the StopWatch unit
        # (presumably seconds — TODO confirm), which floor division would truncate to 0.
        # shard_hosts_num > 0 here because empty shards returned early above.
        stats.add_sample(("dmc", "fetcher", "shard_ratio"), elapsed / shard_hosts_num)

        log.info("Host health statuses have been updated from Juggler for #%s shard.", shard)

    def _process_health_data(
        self, hosts_health: tp.Dict[str, juggler.HostHealth], hostname_to_host, hostname_to_decision_maker
    ):
        """Store health and decisions for every host present in ``hosts_health``.

        Shard hosts missing from ``hosts_health`` have their stored health reset.
        Assumes ``hosts_health`` only contains hostnames from ``hostname_to_host``;
        an unknown name raises KeyError.
        """
        if self._stopped_event.is_set():
            return
        with self._host_updater as host_updater:
            for name, host_health in hosts_health.items():
                host = hostname_to_host[name]
                decision_maker = hostname_to_decision_maker[name]

                _collect_stats(host_health.event_time)
                _process_host(host_updater, host, host_health, decision_maker)

            missing_hosts = set(hostname_to_host) - set(hosts_health)
            _drop_health_status(host_updater, hostname_to_host, missing_hosts)
            stats.increment_counter(("dmc", "fetcher", "missing"), len(missing_hosts))


class Screening:
    """Runs DMC screening iterations over the host shards assigned to this instance."""

    def __init__(self):
        self._screening_pool = Pool()
        # Processors of the most recent iteration; None until run() creates them.
        self._screening_processors: tp.Optional[tp.List[ScreeningShardProcessor]] = None
        self._stopped_event = Event()
        # NOTE(review): not assigned anywhere in this module; kept for external users.
        self.greenlet: tp.Optional[Greenlet] = None
        self._total_shards_count = get_expert_shards_count(get_tier())

    def run(self, concurrency, partitioner):
        """Screen every shard returned by ``partitioner``, processing up to ``concurrency`` shards at once.

        A fresh ScreeningShardProcessor is created for each shard on every call.
        Does nothing once stop() has been called.
        """
        if self._stopped_event.is_set():
            return

        shards = partitioner.get_numeric_shards(self._total_shards_count)
        # Re-check: stop() may have been called while the shards were being fetched.
        if self._stopped_event.is_set():
            return

        log.info("DMC screening iteration started, shards: %s", ", ".join(str(s) for s in shards))
        self._screening_processors = [ScreeningShardProcessor(self._total_shards_count) for _ in shards]
        parallelize_processing(
            self.execute_processing,
            list(zip(shards, self._screening_processors)),
            threads=concurrency,
            pool=self._screening_pool,
        )
        log.info("DMC screening iteration completed")

    def execute_processing(self, shard_proc: tuple[mongo.MongoPartitionerShard, ScreeningShardProcessor]):
        """Process one ``(shard, processor)`` pair unless screening has been stopped."""
        if self._stopped_event.is_set():
            return
        shard, proc = shard_proc
        proc.process_shard(shard)

    def stop(self):
        """Stop the current processors and kill the screening pool.

        Safe to call before the first run(); there are simply no processors to stop yet.
        """
        self._stopped_event.set()
        log.info("Stopping screening processors")
        # _screening_processors is None until run() has created processors.
        for processor in self._screening_processors or ():
            processor.stop()
        log.info("Stopping screening pool")
        self._screening_pool.kill()

        log.info("Screening is stopped")


class HostBulkUpdater:
    """Queues host health/decision updates into bulk operations on the Host collection.

    The bulk group is configured with ``bulk_size_limit=_BULK_SIZE`` and ``parallel_bulks_limit=_PARALLEL_BULKS``.
    Use it as a context manager: leaving the ``with`` block calls ``finish()`` on the bulk group.
    After stop(), store() and reset_health() are no-ops.
    """

    _BULK_SIZE = 500
    _PARALLEL_BULKS = 10

    _bulk_group = None

    def __init__(self):
        self._stopped_event = Event()
        self._bulk_group = SingleCollectionBulkGroup(
            collection=Host.get_collection(),
            stats_key=("dmc", "fetcher", "bulk_host_updater"),
            bulk_size_limit=self._BULK_SIZE,
            parallel_bulks_limit=self._PARALLEL_BULKS,
        )

    def stop(self):
        """Ignore further store()/reset_health() calls and stop the bulk group."""
        self._stopped_event.set()
        self._bulk_group.stop()

    def store(self, hostname, health, decision_status, decision_time, has_task, message):
        """Queue a conditional update of the host's health, decision status and DMC message.

        If ``health`` is None, the host's health is reset instead (see reset_health()).

        The update only matches a host that:
        - is assigned and not INVALID;
        - has task presence equal to ``has_task`` (presumably to skip hosts whose task state
          changed since they were fetched — TODO confirm);
        - either has no health yet, or has a stored decision timestamp that is missing
          or not newer than ``decision_time``.

        When ``message`` is truthy, it replaces the host's ``messages.dmc`` with a single INFO message.
        """
        if self._stopped_event.is_set():
            return

        if health is None:
            return self.reset_health(hostname)

        update = {
            "$set": {
                "health": health.to_mongo(),
                "decision_status": decision_status,
                "decision_status_timestamp": decision_time,
            }
        }

        if message:
            update["$set"]["messages.dmc"] = [{"severity": HostMessage.SEVERITY_INFO, "message": message}]
        update = fix_mongo_batch_update_dict(update)
        with self._bulk_group.current() as bulk:
            bulk.find(
                {
                    "name": hostname,
                    "state": {"$in": HostState.ALL_ASSIGNED},
                    "status": {"$ne": HostStatus.INVALID},
                    "task": {"$exists": has_task},
                    "$or": [
                        {"health": {"$exists": False}},
                        {
                            "$or": [
                                {"decision_status_timestamp": {"$lte": decision_time}},
                                {"decision_status_timestamp": {"$exists": False}},
                            ]
                        },
                    ],
                }
            ).update_one(update)

    def reset_health(self, hostname):
        """Queue removal of the host's ``health`` and ``messages.dmc`` fields.

        Only matches a host that currently has ``health``. Does nothing after stop(),
        consistent with store().
        """
        if self._stopped_event.is_set():
            return

        with self._bulk_group.current() as bulk:
            bulk.find(
                {
                    "name": hostname,
                    "health": {"$exists": True},
                }
            ).update_one({"$unset": {"health": True, "messages.dmc": True}})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always finish the bulk group; exceptions from the with-body still propagate (returns None).
        self._bulk_group.finish()


def _fetch_shard_hosts(shard_query):
    """Load the shard's assigned hosts that match the DMC rules, together with their decision makers.

    Only the fields needed for screening are loaded. Decision makers are loaded once
    for the whole set of projects seen.

    :param shard_query: query restricting hosts to one shard; it is combined with a ``mongoengine.Q``.
    :return: tuple ``(hostname -> decision maker, hostname -> Host)``.
    """
    raw_rules_query = get_raw_query_for_dmc_rules()

    projected_fields = (
        "inv",
        "name",
        "project",
        "status",
        "status_author",
        "status_audit_log_id",
        "state",
        "health.decision",
        "decision_status",
        "decision_status_timestamp",
        "platform",
        "checks_min_time",
        "restrictions",
        "location.short_queue_name",
        "location.rack",
        "location.unit",
        "task.task_id",
    )
    assigned_in_shard = mongoengine.Q(state__in=HostState.ALL_ASSIGNED) & shard_query
    cursor = Host.objects(assigned_in_shard, __raw__=raw_rules_query).only(*projected_fields)

    hosts_by_name = {}
    seen_projects = set()
    for host in gevent_idle_iter(cursor):
        hosts_by_name[host.name] = host
        seen_projects.add(host.project)

    makers_by_project = load_decision_makers(seen_projects)
    makers_by_name = {
        host.name: makers_by_project[host.project] for host in gevent_idle_iter(hosts_by_name.values())
    }

    return makers_by_name, hosts_by_name


def _collect_metrics(decision_status, old_decision_status, host_health, stopwatch):
    """Record per-host screening metrics.

    When the decision status changes to HEALTHY or FAILURE, record the age of the freshest
    Wall-E check with the matching status (PASSED / FAILED), if one exists.
    The host processing time (``stopwatch.split()``) is always recorded.
    """
    status_changed = decision_status != old_decision_status
    if status_changed and decision_status == DecisionStatus.HEALTHY:
        check_status, metric_key = CheckStatus.PASSED, ("dmc", "time_to_health")
    elif status_changed and decision_status == DecisionStatus.FAILURE:
        check_status, metric_key = CheckStatus.FAILED, ("dmc", "time_to_failure")
    else:
        check_status = metric_key = None

    if metric_key is not None:
        check_age = host_health.most_fresh_walle_check_age(check_status)
        if check_age is not None:
            stats.add_sample(metric_key, check_age)

    stats.add_sample(("dmc", "fetcher", "host_processing_time"), stopwatch.split())


def _process_host(host_updater, host, host_health: juggler.HostHealth, decision_maker) -> None:
    """Compute the DMC decision for one host and queue the resulting health/decision update.

    A decision is only made when ``dmc.dmc_allowed(host)`` is true. Without a decision,
    the previous decision status and timestamp are kept.
    """
    stopwatch = StopWatch()

    # DMC is not allowed for hosts that are currently under repair (original note).
    # Task presence is forwarded to store(), which conditions the update on it.
    try:
        has_task = host.task.task_id is not None
    except AttributeError:
        has_task = False

    decision = decision_maker.make_decision(host, host_health.current_reasons) if dmc.dmc_allowed(host) else None
    _set_decision_to_health(host_health.health, decision, old_health=host.health)

    previous_status = host.decision_status
    previous_timestamp = host.decision_status_timestamp
    if decision is not None:
        new_status = _decision_status(decision, previous_status)
        new_timestamp = _decision_timestamp(new_status, previous_status, previous_timestamp)
    else:
        new_status, new_timestamp = previous_status, previous_timestamp

    host_updater.store(
        host.name, host_health.health, new_status, new_timestamp, has_task, _decision_message(new_status)
    )
    _collect_metrics(new_status, previous_status, host_health, stopwatch)


def _decision_message(decision_status):
    """Return the DMC message to store for ``decision_status``, or None if there is none.

    Healthy hosts get an explicit message because DMC couldn't set its own message
    for these hosts: it won't see them.
    """
    return "Host is healthy." if decision_status == DecisionStatus.HEALTHY else None


def _set_decision_to_health(health, decision, old_health=None):
    old_decision = getattr(old_health, "decision", None)
    if decision:
        decision.counter = old_decision.counter + 1 if old_decision and old_decision.counter is not None else 0
        health.decision = decision.to_dict()
        return

    if old_decision:
        health.decision = old_decision


def _decision_status(decision, old_decision_status):
    """Map a decision to a host decision status (FAILURE or HEALTHY).

    A WAIT action keeps a previous FAILURE status and otherwise yields HEALTHY.
    This swallows flaps and network problems: switching to WAIT and back does not change the status.
    Any other action is FAILURE unless it is HEALTHY.
    """
    action = decision.action
    if action == WalleAction.WAIT:
        failed = old_decision_status == DecisionStatus.FAILURE
    else:
        failed = action != WalleAction.HEALTHY

    return DecisionStatus.FAILURE if failed else DecisionStatus.HEALTHY


def _decision_timestamp(new_decision_status, old_decision_status, old_decision_timestamp):
    if new_decision_status != old_decision_status:
        return timestamp()
    else:
        return old_decision_timestamp


def _collect_stats(event_time):
    """Count one received health record and sample the delay between ``event_time`` and now."""
    stats.increment_counter(("dmc", "fetcher", "received"))
    store_delay = timestamp() - event_time
    stats.add_sample(("dmc", "fetcher", "health_store_delay"), store_delay, hist_cls=IntegerLinearHistogram)


def _drop_health_status(host_updater, host_info, hosts):
    """Queue a health reset for each hostname in ``hosts`` whose fetched host record has health.

    :param host_info: mapping hostname -> fetched host object.
    :param hosts: iterable of hostnames (all must be keys of ``host_info``).
    """
    for hostname in gevent_idle_iter(hosts):
        if host_info[hostname].health is None:
            continue
        host_updater.reset_health(hostname)
