"""Project automation objects."""

import dataclasses
import typing as tp
from functools import partial

from sepelib.core import config
from sepelib.core.exceptions import LogicalError
from walle.errors import StateChanged
from walle.expert import rules, juggler
from walle.expert.automation import healing_automation
from walle.expert.automation_plot import AutomationPlot, AUTOMATION_PLOT_BASIC_ID, AUTOMATION_PLOT_FULL_FEATURED_ID
from walle.expert.checks import get_static_checks_to_configure, get_project_checks_to_configure
from walle.expert.decision import Decision
from walle.expert.rules.base import AbstractRule, CheckRuleInterface, apply_and_set_rule_name
from walle.expert.types import WalleAction, CheckType, CheckSets
from walle.projects import Project
from walle.util.gevent_tools import gevent_idle_iter


class AutomationSettingsChanged(StateChanged):
    """Signals that the user changed automation settings in a way that does not let processing continue."""


def load_decision_makers(project_ids):
    """Build a decision maker for each of the given projects.

    All projects go through one shared ``DecisionMakerCache``, so projects that map to the
    same cache key reuse a single decision maker factory.

    :param project_ids: ids of the projects to load.
    :return: dict mapping project id to its decision maker.
    """
    loaded_fields = (
        "automation_plot_id",
        "healing_automation",
        "automation_limits",
        "manually_disabled_checks",
        "tags",
    )
    projects = Project.objects(id__in=project_ids).only(*loaded_fields)
    # TODO(rocco66): forgotten enabled_checks?
    cache = DecisionMakerCache()
    makers = {}
    for project in gevent_idle_iter(projects):
        makers[project.id] = cache.get_decision_maker(project)
    return makers


class AbstractDecisionMaker:
    """Base class turning Juggler check results into automation decisions for one project.

    Subclasses set ``checks`` and implement the methods that raise ``NotImplementedError`` here.
    """

    # All checks this decision maker knows about; overridden by subclasses.
    checks = frozenset()
    # The part of ``checks`` that is taken into account; replaced per instance in ``__init__``.
    enabled_checks = frozenset()

    def __init__(
        self,
        project: Project,
        enabled_checks: tp.Optional[tp.Iterable[str]] = None,
        disabled_checks: tp.Optional[tp.Iterable[str]] = None,
    ):
        """
        :param project: project the decisions are made for.
        :param enabled_checks: checks to take into account; ``None`` enables every check in
            ``self.checks``, otherwise names not present in ``self.checks`` are dropped.
        :param disabled_checks: checks manually excluded from decision making.
        """
        self._project = project
        if enabled_checks is not None:
            self.enabled_checks = frozenset(enabled_checks) & self.checks
        else:
            self.enabled_checks = self.checks
        self._disabled_checks = set(disabled_checks) if disabled_checks else set()

    def make_availability_decision(self, host, reasons: juggler.Reasons) -> Decision:
        """Decide using the availability rule only, over the basic checks that are not manually disabled.

        A healthy result is re-created from its reason alone via ``Decision.healthy``.
        """
        availability_rule = rules.AvailabilityCheckRule()
        usable_checks = CheckSets.BASIC - self._disabled_checks
        verdict = apply_and_set_rule_name(availability_rule, host, reasons, usable_checks)

        if verdict.action == WalleAction.HEALTHY:
            # detach decision from active checks.
            return Decision.healthy(verdict.reason)
        return verdict

    def make_decision(self, host, reasons: juggler.Reasons, fast=False) -> Decision:
        """Return the decision for the host's current health status (implemented by subclasses)."""
        raise NotImplementedError

    def make_alternate_decision(self, host, reasons, checks, checks_for_use=None) -> Decision:
        """Decide over the given checks together with the active (availability) checks.

        Implementations return the decision that should actually trigger an action, e.g. reboot
        an unavailable host before failing it for missing passive checks.
        """
        raise NotImplementedError

    def make_decision_trace(self, host, reasons: juggler.Reasons) -> tp.Iterator[Decision]:
        """Yield the decisions that lead to the final verdict (implemented by subclasses)."""
        raise NotImplementedError

    def all_available_checks(self):
        """Return every check this decision maker can use (implemented by subclasses)."""
        raise NotImplementedError

    def automation(self):
        """Return the healing automation object for this decision maker's project."""
        return healing_automation(self._project.id)

    def _get_plot_id(self):
        """Return the automation plot id this decision maker implements (implemented by subclasses)."""
        raise NotImplementedError

    def checks_to_configure(self):
        """Return the known checks minus the manually disabled ones."""
        return self.checks - self._disabled_checks

    def __eq__(self, other):
        """Compare by exact type, check sets, plot id and project id.

        :raises LogicalError: when compared with something that is not a decision maker.
        """
        if not isinstance(other, AbstractDecisionMaker):
            raise LogicalError

        if type(other) is not type(self):
            return False

        return (
            self.checks == other.checks
            and self.enabled_checks == other.enabled_checks
            and self._disabled_checks == other._disabled_checks
            and self._get_plot_id() == other._get_plot_id()
            and self._project.id == other._project.id
        )


class BasicDecisionMaker(AbstractDecisionMaker):
    """Decision maker that looks at availability checks and nothing else."""

    checks = CheckSets.BASIC

    def make_decision(self, host, reasons: juggler.Reasons, fast=False) -> Decision:
        """Return the availability decision; ``fast`` makes no difference here."""
        return self.make_availability_decision(host, reasons)

    def make_decision_trace(self, host, reasons: juggler.Reasons) -> tp.Iterator[Decision]:
        """Yield exactly one item: the availability decision."""
        availability = self.make_availability_decision(host, reasons)
        yield availability

    def make_alternate_decision(self, host, reasons, checks, checks_for_use=None) -> Decision:
        """Return the availability decision.

        ``checks`` and ``checks_for_use`` are accepted for interface compatibility and ignored,
        since availability checks are the only ones this decision maker knows.
        """
        return self.make_availability_decision(host, reasons)

    def all_available_checks(self):
        """Return the basic check set."""
        return self.checks

    def _get_plot_id(self):
        """Return the id of the basic automation plot."""
        return AUTOMATION_PLOT_BASIC_ID


def get_checks_priority(checks: tp.Optional[tp.Iterable[str]]) -> int:
    """Return the highest configured priority among ``checks``.

    The priority of a check is read from ``expert_system.check_type_priorities.<check>``.
    Checks without their own entry use ``expert_system.check_type_priorities.default``.
    That default is also returned when ``checks`` is ``None`` or empty, including an
    iterator that yields nothing.
    """
    default_priority = config.get_value("expert_system.check_type_priorities.default")
    if not checks:
        return default_priority
    # ``default=`` covers truthy-but-empty iterables (e.g. an exhausted iterator), for which
    # a bare ``max`` would raise ValueError.
    return max(
        (config.get_value(f"expert_system.check_type_priorities.{c}", default_priority) for c in checks),
        default=default_priority,
    )


def _decision_priority_key(decision: Decision) -> int:
    """Sort key that places decisions whose checks have the highest priority first."""
    priority = get_checks_priority(decision.checks)
    return -priority


class ModernDecisionMaker(AbstractDecisionMaker):
    """Make decisions by applying a list of check rules; the set of known checks is given explicitly.

    Instances are normally built through the factories returned by ``for_modern_automation_plot``
    and ``for_automation_plot``, which pre-bind the plot id, the checks and the rules.
    """

    def __init__(
        self,
        automation_plot_id,
        checks,
        check_rules,
        project: Project,
        enabled_checks: tp.Optional[tp.Iterable[str]] = None,
        disabled_checks: tp.Optional[tp.Iterable[str]] = None,
        automation_plot=None,
    ):
        """
        :param automation_plot_id: id of the automation plot this decision maker implements.
        :param checks: all checks known to this decision maker.
        :param check_rules: rules to apply, in this order.
        :param project: project the decisions are made for.
        :param enabled_checks: checks to take into account; ``None`` means all of ``checks``.
        :param disabled_checks: checks manually excluded from decision making.
        :param automation_plot: optional automation plot. Its enabled checks that have ``wait``
            set become mandatory for alternate decisions.
        """
        # Must be assigned before ``super().__init__``: the base class intersects
        # ``enabled_checks`` with ``self.checks``.
        self.checks = checks
        self._rules = check_rules
        self._automation_plot_id = automation_plot_id
        self._automation_plot = automation_plot
        self._mandatory_checks = self._get_mandatory_checks()

        super().__init__(project, enabled_checks, disabled_checks)

    def _get_mandatory_checks(self):
        """Return the names of the plot's enabled checks that have ``wait`` set; empty without a plot."""
        if self._automation_plot:
            return {check.name for check in self._automation_plot.checks if check.enabled and check.wait}
        return frozenset()

    def make_decision(
        self, host, reasons: juggler.Reasons, enabled_checks: tp.Optional[frozenset[str]] = None, fast=False
    ) -> Decision:
        """Make decision according to Juggler checks.

        Return WalleAction.HEALTHY only if all checks (except for infrastructure checks) are passed.

        Every rule is applied to ``enabled_checks`` (``self.enabled_checks`` by default) minus the
        manually disabled checks. Non-healthy decisions are ordered by check priority, highest
        first; ties keep the rule order. All rules are evaluated in both modes.

        In fast mode: return the first non-healthy decision, if any.
        In non-fast mode: return the first non-``wait`` decision, otherwise the last ``wait``
        decision in priority order.
        If no non-healthy decision was produced, return a healthy decision.

        NOTE: unlike the base-class signature, ``enabled_checks`` comes before ``fast``;
        pass ``fast`` by keyword.
        """

        if enabled_checks is None:
            enabled_checks = self.enabled_checks

        all_decisions = (
            apply_and_set_rule_name(rule, host, reasons, enabled_checks - self._disabled_checks) for rule in self._rules
        )
        unhealthy_decisions = (d for d in all_decisions if d.action != WalleAction.HEALTHY)
        # sorted() is stable: decisions with equal priority keep the rule order.
        unhealthy_decisions_by_priority = sorted(unhealthy_decisions, key=_decision_priority_key)

        if fast:
            # fast-track: return the highest-priority failing decision without waiting for non-`wait` ones
            return next(iter(unhealthy_decisions_by_priority), Decision.healthy("Host is healthy."))

        decision = None
        for decision in unhealthy_decisions_by_priority:
            if decision.action != WalleAction.WAIT:
                return decision
        else:
            # Here `decision` is a last `wait` decision we met.
            return decision or Decision.healthy("Host is healthy.")

    def make_decision_trace(self, host, reasons: juggler.Reasons) -> tp.Iterator[Decision]:
        """Make decision trace according to Juggler checks.

        Yields the decision of every rule, in rule order, whatever its action.
        Yields one additional healthy decision if all rule decisions were healthy.
        """
        # This is a generator: each rule is applied only when the consumer requests the next
        # decision, so a consumer that stops early skips the remaining rules.
        all_healthy = True
        for rule in self._rules:
            decision = apply_and_set_rule_name(rule, host, reasons, self.enabled_checks - self._disabled_checks)
            if decision.action != WalleAction.HEALTHY:
                all_healthy = False
            yield decision

        if all_healthy:
            yield Decision.healthy("Host is healthy.")

    def make_alternate_decision(self, host, reasons, checks, checks_for_use=None) -> Decision:
        """Make decisions for given checks AND for active checks.

        Return decision that should actually trigger an action,
        e.g. try reboot host that is not available before failing it for missing passive checks.

        :param checks: set of checks the caller is interested in (must support ``issubset``).
        :param checks_for_use: if given, used as-is instead of ``enabled_checks & checks`` plus the
            mandatory plot checks. Manually disabled checks are removed in both cases.
        """

        active_checks_decision = self.make_availability_decision(host, reasons)

        # Only availability checks requested and nothing mandatory: the availability decision suffices.
        if checks.issubset(CheckSets.BASIC) and not self._mandatory_checks:
            return active_checks_decision

        if checks_for_use:
            enabled_checks = set(checks_for_use)
        else:
            enabled_checks = (self.enabled_checks & set(checks)) | self._mandatory_checks

        enabled_checks -= self._disabled_checks
        enabled_checks_decision = self.make_decision(host, reasons, enabled_checks=enabled_checks)

        if enabled_checks_decision.action == WalleAction.HEALTHY:
            return enabled_checks_decision

        # A non-healthy availability decision takes precedence over the requested checks' decision.
        if active_checks_decision.action == WalleAction.WAIT:
            result = Decision.wait(
                "{} But active checks are not passed yet: {}".format(
                    enabled_checks_decision.reason, active_checks_decision.reason
                ),
                checks=active_checks_decision.checks,
            )
            result.rule_name = active_checks_decision.rule_name
            return result

        if active_checks_decision.action != WalleAction.HEALTHY:
            return active_checks_decision

        return enabled_checks_decision

    def automation(self):
        """Return the healing automation for the project, bound to this decision maker's plot."""
        return healing_automation(self._project.id, plot=self._automation_plot)

    def _get_plot_id(self):
        """Return the automation plot id this decision maker was built for."""
        return self._automation_plot_id

    def all_available_checks(self):
        """Return the known checks plus the static and project-specific checks to configure."""
        return self.checks | get_static_checks_to_configure() | get_project_checks_to_configure(self._project.id)

    def checks_to_configure(self):
        """Return all available checks minus the manually disabled ones."""
        return self.all_available_checks() - self._disabled_checks

    @staticmethod
    def _filter_checks_for_project(project: Project, checks: tp.AbstractSet[str]) -> tp.AbstractSet[str]:
        """Return ``checks`` without the checks the project cannot have.

        Infiniband checks are dropped unless the project has infiniband.
        The TOR link check is dropped unless the project has the TOR link rule.
        ``checks`` itself is never modified in place, so shared constants such as
        ``CheckSets.FULL_FEATURED`` are safe to pass.
        """

        def as_name_set(check_or_checks):
            # A bare string is one check name; ``set()`` on it would split it into characters.
            if isinstance(check_or_checks, str):
                return {check_or_checks}
            return set(check_or_checks)

        if not project.has_infiniband():
            checks = checks - as_name_set(CheckType.ALL_IB)
        if not project.has_tor_link_rule():
            checks = checks - as_name_set(CheckType.TOR_LINK)
        return checks

    @classmethod
    def for_modern_automation_plot(cls, project: Project):
        """Make decisions based on Wall-E check bundle, use hw-watcher and some other custom checks.

        :return: a factory (``functools.partial`` of ``cls``) that takes the remaining
            ``__init__`` arguments, starting with ``project``.
        """
        checks = cls._filter_checks_for_project(project, CheckSets.FULL_FEATURED)
        check_rules = _init_check_rules(project)

        return partial(cls, AUTOMATION_PLOT_FULL_FEATURED_ID, checks, check_rules)

    @classmethod
    def for_automation_plot(cls, project: Project, automation_plot):
        """Make decisions based on Wall-E check bundle and checks from automation plot.

        :return: a factory (``functools.partial`` of ``cls``) that takes the remaining
            ``__init__`` arguments, starting with ``project``.
        """
        checks = cls._filter_checks_for_project(
            project,
            CheckSets.FULL_FEATURED | {check.name for check in automation_plot.checks if check.enabled},
        )
        check_rules = _init_check_rules(project, automation_plot)

        return partial(cls, automation_plot.id, checks, check_rules, automation_plot=automation_plot)


# A check rule given either as a ready ``AbstractRule`` or as a ``CheckRuleInterface``;
# the latter is wrapped into ``rules.SingleCheckRule`` by ``_wrap_check_rule_if_needed``.
_RuleOrInterface = tp.Union[AbstractRule, CheckRuleInterface]


def _wrap_check_rule_if_needed(rule_or_interface: _RuleOrInterface) -> AbstractRule:
    """Return ``rule_or_interface`` as an ``AbstractRule``.

    Ready rules pass through unchanged; anything else is wrapped into ``rules.SingleCheckRule``.
    """
    already_a_rule = isinstance(rule_or_interface, AbstractRule)
    return rule_or_interface if already_a_rule else rules.SingleCheckRule(rule_or_interface)


def _get_all_rules(project: Project, policy: dict) -> tp.Iterator[_RuleOrInterface]:
    """Yield the built-in decision-making rules for ``project`` in a fixed, significant order.

    :param project: its ``has_infiniband()`` decides whether the infiniband rules are included,
        and its ``has_tor_link_rule()`` decides whether the TOR link rule is included.
    :param policy: per-plot policy dict. Its ``disk``, ``memory``, ``fsck`` and ``tained_kernel``
        sections are passed as keyword arguments to the corresponding rule constructors.
    """
    # A missing or empty section yields no keyword arguments, i.e. the rule constructor's defaults.
    disk_policy = policy.get("disk") or {}
    memory_policy = policy.get("memory") or {}
    fsck_policy = policy.get("fsck") or {}
    # NOTE(review): the key is spelled "tained_kernel" (sic); presumably the plot policy
    # configuration uses the same spelling, so renaming it needs a matching config change.
    tained_kernel_policy = policy.get("tained_kernel") or {}

    # Attention: Be very careful with their order.
    # (Callers evaluate rules in this order, and sorting by priority keeps it for ties.)
    yield from [
        rules.CheckMemory(**memory_policy),
        rules.CheckBmcIpmi(),
        rules.CheckBmcIpDns(),
        rules.CheckBmcBattery(),
        rules.CheckBmcVoltage(),
        rules.CheckBmcUnknown(),
        rules.CheckDiskCable(),
        rules.CheckDiskSmartCodesRule(),
        # The same disk policy configures all three disk rules below.
        rules.CheckDiskBadBlocks(**disk_policy),
        rules.CheckDiskPerformance(**disk_policy),
        rules.CheckDisk(**disk_policy),
        rules.CheckSsdPerfLow(),
        rules.CheckLink(),
        rules.CheckCpuCaches(),
        rules.RackOverheatRule(),
        rules.CheckCpuThermalHWW(),
        rules.CheckCpu(),
        rules.CheckCpuCapping(),
        rules.CheckGpu(),
    ]
    # Project-dependent rules; the same conditions filter the check set in ModernDecisionMaker.
    if project.has_infiniband():
        yield rules.CheckInfiniband()
        yield rules.IbLinkRule()
    if project.has_tor_link_rule():
        yield rules.TorLinkRule()
    yield from [
        rules.CheckTaintedKernel(**tained_kernel_policy),
        rules.CheckReboots(),
        rules.FsckRule(**fsck_policy),
        rules.RackRule(),
        rules.AvailabilityCheckRule(),
        rules.MissingHwChecksRule(),
        rules.MissingPassiveChecksRule(),
    ]


def _init_check_rules(project: Project, automation_plot=None):
    """Build the complete, ordered rule list for ``project``.

    The built-in rules come first. When a plot is given, they are configured by the
    ``automation.plot_policy.<plot id>`` config value. One rule per enabled plot check follows.
    Each plot rule is configured with the name of the previously enabled plot check
    (``None`` for the first one).
    """
    if not automation_plot:
        policy, plot_checks = {}, ()
    else:
        policy = config.get_value("automation.plot_policy.{}".format(automation_plot.id), {})
        plot_checks = automation_plot.checks

    ordered_rules = [_wrap_check_rule_if_needed(rule) for rule in _get_all_rules(project, policy)]

    previous_name = None
    for plot_check in (c for c in plot_checks if c.enabled):
        ordered_rules.append(rules.configure_rule(plot_check, previous_name))
        previous_name = plot_check.name

    return ordered_rules


class DecisionMakerCache:
    """Memoizes decision maker factories per ``_DecisionMakerCacheKey``.

    Building a factory may load the automation plot from the database and build the check
    rules, so it is done once per key. A fresh decision maker instance is still created on
    every call.
    """

    def __init__(self):
        # cache key -> decision maker class or partially applied decision maker constructor
        self._cache = {}

    def get_decision_maker(
        self,
        project,
        enabled_checks: tp.Optional[tp.Iterable[str]] = None,
    ) -> AbstractDecisionMaker:
        """Return a decision maker for ``project``, honouring its manually disabled checks."""
        cache_key = _DecisionMakerCacheKey.from_project(project)
        factory = self._cache.get(cache_key)
        if not factory:
            factory = self._build_factory(project, cache_key.automation_plot_id)
            self._cache[cache_key] = factory
        return factory(project, enabled_checks=enabled_checks, disabled_checks=project.manually_disabled_checks)

    @staticmethod
    def _build_factory(project, plot_id):
        """Return the decision maker class or factory matching ``plot_id``."""
        if plot_id == AUTOMATION_PLOT_BASIC_ID:
            return BasicDecisionMaker
        if plot_id == AUTOMATION_PLOT_FULL_FEATURED_ID:
            # don't fetch it from database, we know it does not contain any checks.
            return ModernDecisionMaker.for_modern_automation_plot(project)
        automation_plot = AutomationPlot.objects.get(id=plot_id)
        return ModernDecisionMaker.for_automation_plot(project, automation_plot)


def get_decision_maker(*args, **kwargs):
    """Build one decision maker through a throwaway ``DecisionMakerCache`` (nothing is reused between calls)."""
    one_shot_cache = DecisionMakerCache()
    return one_shot_cache.get_decision_maker(*args, **kwargs)


@dataclasses.dataclass(frozen=True, eq=True)
class _DecisionMakerCacheKey:
    """Hashable key for projects that may share one cached decision maker factory.

    It must cover every project property the cached factory depends on:
    - the automation plot id;
    - the project capabilities that select which checks and rules are built, namely
      ``has_infiniband()`` and ``has_tor_link_rule()``.
    """

    automation_plot_id: str
    has_infiniband: bool
    # Defaults to False so existing two-argument constructions keep working.
    has_tor_link_rule: bool = False

    @staticmethod
    def from_project(project: "Project") -> "_DecisionMakerCacheKey":
        """Build the key for ``project``; a project without a plot id maps to the basic plot."""
        return _DecisionMakerCacheKey(
            project.automation_plot_id or AUTOMATION_PLOT_BASIC_ID,
            project.has_infiniband(),
            project.has_tor_link_rule(),
        )
