"""Start healing tasks for decisions made in screening module."""
import logging
import typing as tp
from itertools import chain

import mongoengine
from gevent.event import Event
from gevent.pool import Pool

from sepelib.core import config
from sepelib.core.exceptions import Error
from walle.application import app
from walle.expert import constants as expert_constants, juggler, dmc
from walle.expert.decision import Decision
from walle.expert.decisionmakers import load_decision_makers
from walle.expert.types import WalleAction
from walle.hbf_drills import drills_cache
from walle.hosts import (
    Host,
    HostState,
    HostStatus,
    HostDecision,
    DecisionStatus,
    HostMessage,
    get_raw_query_for_dmc_rules,
)
from walle.host_shard import HostShard, HostShardType
from walle.stats import stats_manager as stats
from walle.util import mongo, counter
from walle.util.cloud_tools import get_tier
from walle.util.gevent_tools import gevent_idle_generator
from walle.util.host_health import get_human_reasons
from walle.util.misc import (
    add_interval_job,
    parallelize_processing,
    StopWatch,
    closing_ctx,
    iter_shuffle,
    iter_chunks,
    get_expert_shards_count,
)

log = logging.getLogger(__name__)

# Size of the gevent pool each TriageShardProcessor uses to process hosts concurrently.
PARALLEL_HOST_PROCESSING_POOL_SIZE = 15
# Name of the periodic triage job; start() formats it with the value of get_tier().
TRIAGE_NAME_TEMPLATE = "DMC triage watcher tier-{}"

# Module-level Triage instance: created by start(), stopped by stop().
_TRIAGE = None

# Counter checkers used to decide whether a host's decision is fresh.
# NOTE(review): the string argument is presumably a human-readable label - confirm in walle.util.counter.
_health_decision_counter_checker = counter.CounterChecker("fresh python decision for host")
_go_decision_counter_checker = counter.CounterChecker("fresh go decision for host")


def start(scheduler, partitioner: mongo.MongoPartitionerService):
    """Create the module-level Triage and register its periodic job on the scheduler.

    :param scheduler: scheduler the interval job is added to
    :param partitioner: partitioner service handed to every triage iteration
    """
    global _TRIAGE

    shards_in_parallel = config.get_value("expert_system.parallel_shards", default=1)
    _TRIAGE = Triage()

    def job(concurrency=shards_in_parallel):
        # One triage iteration; its duration is reported as a stats sample.
        timer = StopWatch()
        _TRIAGE.run(partitioner, concurrency)
        stats.add_sample(("dmc", "triage", "iteration_time"), timer.get())

    add_interval_job(
        scheduler,
        job,
        name=TRIAGE_NAME_TEMPLATE.format(get_tier()),
        interval=expert_constants.HEALTH_SCREENING_PERIOD,
    )


def stop():
    """Stop the module-level Triage, if start() has created one."""
    triage = _TRIAGE
    if not triage:
        return
    triage.stop()


def _select_decision(python_decision: Decision, go_decision: tp.Optional[Decision]) -> Decision:
    if not go_decision:
        return python_decision

    switched_rules = app.settings().dmc_rules_switched
    if go_decision.rule_name in switched_rules:
        return go_decision
    if python_decision.rule_name in switched_rules:
        return python_decision

    if go_decision.action in [WalleAction.WAIT, WalleAction.HEALTHY]:
        return python_decision

    # NOTE(rocco66): go_decision is failed
    if python_decision.action in [WalleAction.WAIT, WalleAction.HEALTHY]:
        return go_decision

    # NOTE(rocco66): and python_decision is failed too
    return python_decision


def _check_go_decision(host_decision: tp.Optional[HostDecision]) -> tp.Optional[Decision]:
    if host_decision and _go_decision_counter_checker.check_for_fresh(host_decision.uuid, host_decision):
        return Decision(**host_decision.decision.to_dict())


def _get_go_decisions(hosts: list[Host]) -> dict[str, HostDecision]:
    return {hd.uuid: hd for hd in HostDecision.objects.filter(uuid__in=[h.uuid for h in hosts])}


class TriageShardProcessor:
    """Triage the hosts of a single shard, handing each host's decision over to DMC.

    Hosts are processed concurrently in a gevent pool of
    PARALLEL_HOST_PROCESSING_POOL_SIZE greenlets. After ``stop()`` every
    entry point returns early without doing any work.
    """

    def __init__(self, total_shards_count):
        """
        :param total_shards_count: total number of shards, used to build the per-shard host query
        """
        self.dmc_processors = Pool(size=PARALLEL_HOST_PROCESSING_POOL_SIZE)
        self.stopped_event = Event()
        self._total_shards_count = total_shards_count

    def stop(self):
        """Flag the processor as stopped and kill the host-processing greenlets."""
        self.stopped_event.set()
        self.dmc_processors.kill()

    def process_shard(self, shard: mongo.MongoPartitionerShard):
        """Triage one shard; failures are logged instead of being propagated.

        ``HostShard.processed`` is called only when triage raised no exception.
        """
        if self.stopped_event.is_set():
            return

        try:
            self._triage(shard)
        except Exception as e:
            log.exception("Failed to triage hosts for #%s shard: %s", shard, e)
        else:
            HostShard.processed(shard.id, HostShardType.triage)

    def _triage(self, shard: mongo.MongoPartitionerShard):
        """Spawn decision processing for every eligible host of the shard and report stats.

        Runs under ``shard.lock``. Returns early (without reporting stats) if the
        processor is stopped while hosts are still being dispatched.
        """
        if self.stopped_event.is_set():
            return
        log.info("Triage hosts decisions for #%s shard...", shard)
        main_stopwatch = StopWatch()
        host_count = 0

        with shard.lock:
            log.info("Executing decisions...")
            shard_query = mongo.get_host_mongo_shard_query(shard, self._total_shards_count)
            decision_makers = _fetch_host_decision_makers(shard_query)

            hbf_drills_collection = drills_cache.get()
            # The pool's join is registered with closing_ctx so spawned greenlets are
            # waited for before leaving this block.
            with closing_ctx(self.dmc_processors.join, raise_error=True):
                host_decision_gen = _fetch_hosts_and_health_for_decision_makers(decision_makers)
                for host, reasons, go_host_decision, decision_maker in host_decision_gen:
                    if self.stopped_event.is_set():
                        return
                    if not self._is_decision_fresh(host):
                        continue
                    self.dmc_processors.spawn(
                        self._process_and_save_message,
                        host,
                        reasons,
                        go_host_decision,
                        decision_maker,
                        hbf_drills_collection,
                    )
                    host_count += 1

        log.info("All actionable decisions has been executed.")
        stats.add_sample(("dmc", "triage", "total_processing_time"), main_stopwatch.get())
        stats.add_sample(("dmc", "triage", "host_count"), host_count)
        if host_count:
            # NOTE(review): floor division - if StopWatch.get() returns fractional seconds this
            # sample is rounded down (usually to 0); confirm whether true division was intended.
            stats.add_sample(("dmc", "triage", "host_time_share"), main_stopwatch.get() // host_count)
        log.info("Triage performed for %s hosts of #%s shard...", host_count, shard)

    def _process_and_save_message(self, host, reasons, go_host_decision: HostDecision, decision_maker, hbf_drills):
        """Process one host's decision and store the resulting DMC message on the host.

        :param host: host to process
        :param reasons: health reasons for the host; falsy when its health is missing
        :param go_host_decision: stored HostDecision for the host, or None
        :param decision_maker: decision maker passed through to DMC
        :param hbf_drills: currently unused - the HBF drill exclusion check is disabled
        :type hbf_drills: HbfDrillsCollection
        """
        if self.stopped_event.is_set():
            return
        stopwatch = StopWatch()

        # host.health itself may be missing (``_is_decision_fresh`` allows for that too, but it
        # only checks when the decision counter is enabled), so never dereference it blindly:
        # a host without a health document is treated as "not screened yet".
        health = host.health
        health_decision = health.decision if health is not None else None

        if reasons and health_decision:
            python_decision = Decision(**health_decision.to_dict())
            decision = _select_decision(python_decision, _check_go_decision(go_host_decision))

            human_reasons = get_human_reasons(reasons)
            message = _process_with_dmc(host, decision, decision_maker, human_reasons)
        elif not reasons:
            log.info("Will not process host %s on this iteration: it's health is missing.", host.human_name())
            message = "Will not process host on this iteration: it's health is missing."
        else:
            # decision is missing means host was updated while it was under a task.
            # we should just to wait for the next run of screening.
            log.info("Will not process host %s on this iteration: it has not been screened yet.", host.human_name())
            message = "Will not process host on this iteration: it has not been screened yet."

        # An empty/None message clears the DMC messages instead of storing an empty one.
        host.set_messages(dmc=[HostMessage.info(message)] if message else None)
        stats.add_sample(("dmc", "triage", "host_processing_time"), stopwatch.get())

    @staticmethod
    def _is_decision_fresh(host):
        """Tell whether the host's health decision should be processed on this iteration.

        Always True unless ``expert_system.use_decision_counter`` is enabled; with it
        enabled, the host needs health data and the counter checker must report its
        decision as fresh.
        """
        if not config.get_value("expert_system.use_decision_counter", default=False):
            return True
        if not host.health or not _health_decision_counter_checker.check_for_fresh(host.inv, host.health.decision):
            return False
        return True


class Triage:
    """Runs DMC triage iterations over the shards handed out by the partitioner.

    ``pool``, ``shard_processors`` and ``partitioner`` are only populated by
    ``run()``; until the first iteration they are None, and ``stop()`` copes with that.
    """

    def __init__(self) -> None:
        # Pool the per-shard processing is dispatched on; recreated on every run().
        self.pool: tp.Optional[Pool] = None
        # One TriageShardProcessor per shard of the current iteration.
        self.shard_processors: tp.Optional[tp.List[TriageShardProcessor]] = None
        # The partitioner is supplied by run(), hence None until the first iteration.
        self.partitioner: tp.Optional[mongo.MongoPartitionerService] = None
        self.stopped_event = Event()
        self._total_shards_count = get_expert_shards_count(get_tier())

    def stop(self):
        """Stop triage: set the stop flag, then stop the partitioner, shard processors and pool.

        Safe to call before the first ``run()`` (or while ``run()`` is still setting up),
        when some of these collaborators do not exist yet.
        """
        self.stopped_event.set()
        if self.partitioner is not None:
            self.partitioner.stop()
        for processor in self.shard_processors or ():
            processor.stop()
        if self.pool is not None:
            self.pool.kill()

    def run(self, partitioner: mongo.MongoPartitionerService, concurrency: int):
        """Perform one triage iteration over the shards provided by the partitioner.

        :param partitioner: partitioner service the numeric shards are requested from
        :param concurrency: size of the pool (and ``threads`` value) used for shard processing
        """
        if self.stopped_event.is_set():
            return
        self.partitioner = partitioner
        self.pool = Pool(size=concurrency)
        shards = self.partitioner.get_numeric_shards(self._total_shards_count)
        # stop() may have been called while the shards were being acquired.
        if self.stopped_event.is_set():
            return
        log.info(f"DMC triage iteration started, shards: {', '.join(str(s) for s in shards)}")
        self.shard_processors = [TriageShardProcessor(self._total_shards_count) for _ in shards]
        parallelize_processing(
            self.execute_shard_processing,
            list(zip(shards, self.shard_processors)),
            threads=concurrency,
            pool=self.pool,
        )
        log.info("DMC triage iteration completed")

    def execute_shard_processing(self, shard_proc: tuple[mongo.MongoPartitionerShard, TriageShardProcessor]):
        """Process one ``(shard, processor)`` pair unless triage has been stopped."""
        if self.stopped_event.is_set():
            return
        shard, processor = shard_proc
        processor.process_shard(shard)


@gevent_idle_generator
def _fetch_host_decision_makers(shard_query):
    """Yield ``(host_name, decision_maker)`` pairs for the hosts matching ``shard_query``.

    Hosts are yielded in the order produced by ``iter_shuffle``; decision makers are
    loaded once per project. The output can be turned into a dict with ``dict(...)``.
    """
    candidates = list(_get_host_iterator(shard_query, full=False, raw_rules=get_raw_query_for_dmc_rules()))
    makers_by_project = load_decision_makers({candidate.project for candidate in candidates})

    for candidate in iter_shuffle(candidates):
        yield candidate.name, makers_by_project[candidate.project]


def _get_host_iterator(query, full, raw_rules=None):
    """Chain the two host selections that triage works on.

    Selected are hosts with a FAILURE decision status in any steady status, followed by
    hosts with a HEALTHY decision status that are in DEAD status.

    :param query: additional filter for mongoengine
    :param full: return all fields required for dmc processing (if false, return only a minimal field set)
    :param raw_rules: optional raw query passed through to ``_get_hosts_by_query``
    :return: iter(Hosts)
    """
    raw_rules = raw_rules or {}

    selections = (
        mongoengine.Q(decision_status=DecisionStatus.FAILURE, status__in=HostStatus.ALL_STEADY),
        mongoengine.Q(decision_status=DecisionStatus.HEALTHY, status=HostStatus.DEAD),
    )
    # Built eagerly and in order, exactly as if both calls were spelled out as chain() arguments.
    host_sources = [_get_hosts_by_query(query & selection, full=full, raw_rules=raw_rules) for selection in selections]
    return chain(*host_sources)


@gevent_idle_generator
def _get_hosts_by_query(query, full, raw_rules=None):
    """Query hosts in any assigned state that match ``query``, loading only the needed fields.

    NB: this is an iterator, suitable for `chain` but not reusable, convert it to list before using!

    :param query: mongoengine filter, combined with the assigned-state restriction
    :param full: load every field DMC processing needs when true, otherwise only inv/name/project/tier
    :param raw_rules: optional raw query passed as ``__raw__``
    """
    minimal_fields = ("inv", "name", "project", "tier")
    dmc_fields = (
        "inv",
        "name",
        "project",
        "status",
        "status_author",
        "status_audit_log_id",
        "state",
        "tier",
        "health.reasons",
        "health.decision",
        "decision_status",
        "decision_status_timestamp",
        "checks_min_time",
        "restrictions",
        "location.short_queue_name",
        "location.rack",
        "location.unit",
        "provisioner",
        "config",
        "deploy_tags",
        "deploy_network",
        "deploy_config_policy",
        "platform",
        "location.short_datacenter_name",
        "ips",
        "type",
    )
    fields = dmc_fields if full else minimal_fields

    assigned_only = query & mongoengine.Q(state__in=HostState.ALL_ASSIGNED)
    return Host.objects(assigned_only, __raw__=raw_rules or {}).only(*fields)


@gevent_idle_generator
def _fetch_hosts_and_health_for_decision_makers(hosts_decision_makers):
    """Yield ``(host, health_reasons, go_host_decision, decision_maker)`` for each host.

    ``hosts_decision_makers`` is consumed in chunks: every chunk re-reads its hosts with the
    full field set, loads their stored go decisions and asks juggler for health reasons.
    A chunk whose health reasons cannot be fetched is logged and skipped.
    """
    # suppose chunk size is big enough to get some speed improvements (over fetching every single host),
    # and small enough to not be a slowdown (by having lots of hosts waiting for their triage to arrive)
    refresh_chunk_size = 30

    for chunk in iter_chunks(hosts_decision_makers, refresh_chunk_size):
        makers_by_name = dict(chunk)
        name_filter = mongoengine.Q(name__in=list(makers_by_name))
        fresh_hosts = list(_get_host_iterator(name_filter, full=True))
        go_decisions = _get_go_decisions(fresh_hosts)
        min_times_by_name = {fresh.name: fresh.checks_min_time for fresh in fresh_hosts}
        # keep decision makers only for the hosts that were actually re-read
        makers_by_name = {fresh.name: makers_by_name[fresh.name] for fresh in fresh_hosts}
        try:
            reasons_by_name = juggler.get_health_reasons_for_hosts(makers_by_name, min_times_by_name)
        except Error as e:
            log.error("Failed to get checks for hosts %s for triage: %s", list(makers_by_name), e)
            continue

        for fresh in fresh_hosts:
            yield fresh, reasons_by_name.get(fresh.name), go_decisions.get(fresh.uuid), makers_by_name[fresh.name]


def _get_hbf_drill_inclusion_reason(host, decision, hbf_drills):
    """Return a message explaining why the host is excluded by an HBF drill, or None.

    Hosts with a HEALTHY decision are never looked up in ``hbf_drills``.
    """
    reason = hbf_drills.get_host_inclusion_reason(host) if decision.action != WalleAction.HEALTHY else None
    if reason is None:
        return None

    log.info("Host %s is temporarily excluded from health processing: %s", host.name, reason)
    return "Host {} is temporarily excluded from health processing: {}".format(host.name, reason)


def _process_with_dmc(host, decision, decision_maker, human_reason):
    """Hand the decision to DMC and return the message it produced.

    Returns None when DMC is not allowed for the host. A failure inside DMC is
    logged and turned into an error message instead of being raised.
    """
    if not dmc.dmc_allowed(host):
        return None

    try:
        outcome = dmc.handle_decision(host, decision, human_reason, decision_maker)
    except Exception as e:
        log.exception("Failed to process %s decision for host %s: %s", decision.action, host.human_id(), e)
        return "Failed to process {} decision: {}".format(decision.action, str(e))
    return outcome
