"""Shared code for availability checks."""

import logging
from itertools import chain

from sepelib.core.exceptions import LogicalError
from walle.admin_requests.constants import RequestTypes
from walle.clients.eine import ProfileMode
from walle.constants import DEFAULT_DNS_TTL, HostType
from walle.expert.constants import NETWORK_CHECKS_REACTION_TIMEOUT
from walle.expert.decision import Decision
from walle.models import timestamp
from walle.operations_log.constants import Operation
from walle.util.misc import unique
from .base import AbstractRule, check_common
from .escalation import (
    EscalationRules,
    EscalationPoint,
    action_match,
    limit_reached,
    task_has_not_helped,
    automatic_profile_not_supported,
    escalate_to_profile,
    escalate_to_redeploy,
    escalate_to_deactivate,
    Predicate,
    escalate_to_dangerous_highload_test,
    escalate_to_second_time_node_report,
    operation_repeated,
    dns_automation_off,
    EscalationReason,
)
from .utils import get_check_result, is_disabled_check
from ..types import WalleAction, CheckType, CheckStatus, get_walle_check_type

log = logging.getLogger(__name__)


@Predicate
def _is_final(decision):
    """Predicate over the decision's "final" parameter (looked up with False as the fallback argument)."""
    final_flag = decision.get_param("final", False)
    return final_flag


@Predicate
def _is_extra_highload(decision):
    """Predicate: the decision's "profile_mode" parameter equals the extra highload test mode."""
    requested_mode = decision.get_param("profile_mode")
    return requested_mode == ProfileMode.EXTRA_HIGHLOAD_TEST


@Predicate
def _is_dangerous_highload(decision):
    """Predicate: the decision's "profile_mode" parameter equals the dangerous highload test mode."""
    requested_mode = decision.get_param("profile_mode")
    return requested_mode == ProfileMode.DANGEROUS_HIGHLOAD_TEST


@Predicate
def _is_2nd_time_node(decision):
    """Predicate: the decision is a hardware repair carrying the "second time node" request type.

    The request type is read through ``decision.get_param`` like the other
    predicates in this module do, instead of reaching into ``decision.params``
    directly, so a decision without a params mapping does not blow up here.
    """
    if decision.action != WalleAction.REPAIR_HARDWARE:
        return False

    return decision.get_param("request_type") == RequestTypes.SECOND_TIME_NODE.type


@EscalationReason
def _2tn_disabled_for_bad_projects(host, decision):
    """Return an explanatory message when the host belongs to a sandbox project, otherwise None.

    The ``decision`` argument is accepted but not inspected.
    """
    # all sandbox projects, see WALLE-3624
    sandbox_projects = {"sandbox", "sandbox-mtn", "sandbox-legacy", "sandbox-safe-hbf"}
    if host.project not in sandbox_projects:
        return None

    return "Reports to EXP disabled for project {} per request in WALLE-3624.".format(host.project)


def _escalate_to_final_profile(decision, reason):
    """Escalate the decision to a PROFILE action whose params mark it as final."""
    final_params = {"final": True}
    return decision.escalate(WalleAction.PROFILE, reason, params=final_params)


class CheckTypeDescrBase:
    """Empty defaults for the check description attributes read by availability rules."""

    # Message passed to Decision.healthy() by AvailabilityCheckRuleBase._common_decision.
    healthy_msg = ""
    # Message prefix used by AvailabilityCheckRuleBase._decision_wait_more.
    suspected_msg = ""
    # Human readable name of the checks; not read anywhere in this part of the module.
    checks_human_name = ""
    # Passed as `checks` to Decision.healthy() by AvailabilityCheckRuleBase._common_decision.
    check_type = None


class AvailabilityCheckRuleBase(AbstractRule):
    """Common machinery for rules that decide on availability check results.

    Subclasses provide `check_type_descr`, `escalation_rules` and `make_decision`.
    """

    # Messages and check type used when building healthy / wait decisions.
    check_type_descr = CheckTypeDescrBase

    # Object whose `escalate(host, decision)` is delegated to; None until a subclass sets it.
    escalation_rules = None

    @classmethod
    def escalate(cls, host, decision):
        """Delegate escalation of `decision` for `host` to the configured escalation rules."""
        rules = cls.escalation_rules
        return rules.escalate(host, decision)

    @classmethod
    def make_decision(cls, host, reasons, enabled_checks):
        """Not implemented in the base class."""
        raise NotImplementedError()

    @classmethod
    def get_failed_checks(cls, host, infrastructure, check_results):
        """Split check results into failed and suspected ones.

        A check is treated as suspected rather than failed when there are
        infrastructure failures or dns record changes that may explain it.

        :param host: host the checks belong to (its `dns` and `type` are consulted).
        :param infrastructure: object describing the state of infrastructure checks.
        :param check_results: iterable of `(check_type, result)` pairs.
        :return: `(failed_checks, suspected_checks)` dicts keyed by check type.
        """
        failed, suspected = {}, {}

        for check_type, result in check_results:
            status = result["status"]

            if status in {CheckStatus.PASSED, CheckStatus.FAILED} and cls.delay_for_dns(result, host):
                # dns records were updated recently, the result may be stale.
                suspected[check_type] = dict(result, delay_reason="dns ttl after records update")
            elif (
                status == CheckStatus.FAILED
                and host.type == HostType.SERVER
                and cls.delay_for_infrastructure(result, infrastructure)
            ):
                # failure of a server may be caused by the infrastructure, not by the host.
                suspected[check_type] = dict(result, delay_reason=infrastructure.reasons)
            elif status == CheckStatus.SUSPECTED:
                suspected[check_type] = result
            elif status == CheckStatus.FAILED:
                failed[check_type] = result

        return failed, suspected

    @staticmethod
    def delay_for_infrastructure(result, infrastructure):
        """Tell whether a failed check should be held back because of the infrastructure state.

        The conditions are evaluated lazily, in this order:
          * infrastructure failure: mark check as suspected instead of failure;
          * infrastructure checks are missing;
          * we may have infrastructure failures, but don't know about it yet: give it some time;
          * ensure we have received health update after infrastructure recovered.
        """
        return bool(
            infrastructure.failed
            or infrastructure.missing
            or infrastructure.needs_to_catch_up(result["status_mtime"])
            or infrastructure.needs_to_recover(result["effective_timestamp"])
        )

    @staticmethod
    def delay_for_dns(result, host):
        """Tell whether the check result should wait for dns ttl after a dns records update."""
        dns = host.dns
        if dns is None or dns.update_time is None:
            # no known dns update, nothing to wait for.
            return False

        updated_at = dns.update_time
        if result["status"] == CheckStatus.FAILED and result["status_mtime"] > updated_at:
            # check failed after dns fix
            return False

        # wait for ttl while it has not passed yet.
        ttl_passes_at = updated_at + DEFAULT_DNS_TTL
        return ttl_passes_at > timestamp()

    @classmethod
    def _decision_wait_more(cls, suspected_checks, check_type, failure_type=None):
        """Build a wait decision listing the suspected checks and, when known, why they are delayed.

        Only the suspected checks that are also present in `check_type` are attached to the decision.
        """
        prefix = cls.check_type_descr.suspected_msg
        check_names = ", ".join(walle_check_names(suspected_checks))
        causes = " or ".join(all_delay_reasons(suspected_checks))

        message = (
            "{}: {}, which might be due to {}.".format(prefix, check_names, causes)
            if causes
            else "{}: {}.".format(prefix, check_names)
        )

        suspected_types = set(suspected_checks)
        requested_types = set(check_type)
        return Decision.wait(message, checks=list(suspected_types & requested_types), failure_type=failure_type)

    @classmethod
    def _common_decision(cls, checks_results):
        """Run `check_common` over every result and return the resulting decision.

        Starts from a healthy decision; every WAIT decision produced by
        `check_common` replaces the current one, so the last one is returned.

        :raises LogicalError: if `check_common` produces anything but HEALTHY or WAIT.
        """
        verdict = Decision.healthy(cls.check_type_descr.healthy_msg, checks=cls.check_type_descr.check_type)

        for check_type, result in checks_results:
            candidate = check_common(check_type, result)
            action = candidate.action

            if action == WalleAction.HEALTHY:
                continue
            if action != WalleAction.WAIT:
                raise LogicalError()

            verdict = candidate

        return verdict

    @staticmethod
    def get_infrastructure(reasons):
        """Build the infrastructure state object from the check reasons."""
        infrastructure = _Infrastructure(reasons)
        return infrastructure


class _Infrastructure:
    def __init__(self, reasons):
        self._failure_checks = []
        self._missing_checks = []
        self._ok_times = []

        results = [(check_type, get_check_result(reasons, check_type)) for check_type in CheckType.ALL_INFRASTRUCTURE]
        self._analyze(results)

    def _analyze(self, results):
        for check_type, result in results:
            status = result["status"]
            if status == CheckStatus.PASSED:
                self._ok_times.append(result["status_mtime"])
            elif status in CheckStatus.ALL_MISSING:
                self._missing_checks.append((check_type, status))
            else:
                self._failure_checks.append((check_type, status))

    @property
    def failed(self):
        return len(self._failure_checks) > 0

    @property
    def missing(self):
        return len(self._missing_checks) > 0

    @property
    def reasons(self):
        info = [
            failure_reason(check_type, status)
            for check_type, status in chain(self._missing_checks, self._failure_checks)
        ]
        return "an infrastructure issue: {}".format(",".join(info))

    def needs_to_catch_up(self, check_status_mtime):
        if timestamp() - check_status_mtime <= NETWORK_CHECKS_REACTION_TIMEOUT:
            self._failure_checks.append((CheckType.NETMON, "waiting"))
            return True

        return False

    def needs_to_recover(self, check_effective_timestamp):
        if self._ok_times and check_effective_timestamp <= max(self._ok_times):
            self._failure_checks.append((CheckType.NETMON, "recovering"))
            return True

        return False


def failure_reason(check_type, status):
    """Return the wall-e check type name and the status joined with a single space."""
    walle_name = get_walle_check_type(check_type)
    return " ".join((walle_name, status))


def _enabled_check_types(host, check_types, enabled_checks):
    """Lazily yield the check types that `is_disabled_check` does not report as disabled for the host."""
    yield from (
        candidate for candidate in check_types if not is_disabled_check(host, candidate, enabled_checks)
    )


def get_all_results(check_types, host, reasons, enabled_checks):
    """Return a list of `(check_type, check_result)` pairs for every enabled check type."""
    results = []
    for check_type in _enabled_check_types(host, check_types, enabled_checks):
        results.append((check_type, get_check_result(reasons, check_type)))
    return results


def all_failure_reasons(failed_checks):
    """Return a failure reason string for each entry of the `failed_checks` mapping."""
    descriptions = []
    for check_type, check in failed_checks.items():
        descriptions.append(failure_reason(check_type, check["status"]))
    return descriptions


def all_delay_reasons(suspected_checks):
    """Pass the "delay_reason" of every suspected check through `unique(..., clear=True)`.

    Checks without a "delay_reason" key contribute None to the input.
    """
    checks = suspected_checks.values()
    return unique((check.get("delay_reason") for check in checks), clear=True)


def walle_check_names(checks):
    """Return the list of wall-e check type names for every item yielded by iterating `checks`."""
    return [get_walle_check_type(check) for check in checks]


def get_escalation_rules(final_escalation=escalate_to_second_time_node_report):
    """Build the escalation rules shared by availability checks.

    The points are declared in stages (reboot, profile, redeploy, dangerous
    highload profile, second time node report) and handed over to
    `EscalationRules` in exactly that order.

    :param final_escalation: action used once the limit of dangerous highload
        profiles is reached.
    :return: the assembled `EscalationRules` object.
    """
    # Decisions to reboot the host.
    reboot_stage = [
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=task_has_not_helped(
                Operation.POWER_ON.host_status, "Host powering on hasn't helped", "Host failed to power on"
            ),
            action=escalate_to_profile,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=task_has_not_helped(Operation.REBOOT.host_status, "Reboot hasn't helped", "Host failed to reboot"),
            action=escalate_to_profile,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=task_has_not_helped(
                Operation.PROFILE.host_status, "Profiling hasn't helped", "Host profiling failed"
            ),
            action=escalate_to_redeploy,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=task_has_not_helped(
                Operation.REDEPLOY.host_status, "Redeploying hasn't helped", "Host failed to redeploy"
            ),
            action=escalate_to_dangerous_highload_test,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=task_has_not_helped(
                Operation.CHANGE_MEMORY.host_status, "It seems that processed memory change operation killed the node"
            ),
            action=_escalate_to_final_profile,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=limit_reached("max_host_reboots", Operation.REBOOT),
            action=escalate_to_profile,
        ),
    ]

    # Decisions to profile the host, dangerous highload profile excluded.
    profile_stage = [
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE) & _is_final & ~_is_dangerous_highload,
            reason=automatic_profile_not_supported,
            action=escalate_to_deactivate,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE) & ~_is_final & ~_is_dangerous_highload,
            reason=automatic_profile_not_supported,
            action=escalate_to_redeploy,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE) & ~_is_final & ~_is_dangerous_highload,
            reason=limit_reached("max_host_profiles", Operation.PROFILE),
            action=escalate_to_redeploy,
        ),
    ]

    # Decisions to redeploy the host.
    redeploy_stage = [
        EscalationPoint(
            predicate=action_match(WalleAction.REDEPLOY),
            reason=operation_repeated(Operation.REDEPLOY),
            action=escalate_to_dangerous_highload_test,
        ),
        EscalationPoint(
            predicate=action_match(WalleAction.REDEPLOY),
            reason=task_has_not_helped(
                Operation.REDEPLOY.host_status, "Redeploying hasn't helped", "Host failed to redeploy"
            ),
            action=escalate_to_dangerous_highload_test,
        ),
    ]

    # Non-final dangerous highload profile: hand over to the caller-provided final escalation.
    dangerous_highload_stage = [
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE) & ~_is_final & _is_dangerous_highload,
            reason=limit_reached(
                "max_host_dangerous_highload_profiles",
                Operation.PROFILE,
                params={"modes": ProfileMode.DANGEROUS_HIGHLOAD_TEST},
            ),
            action=final_escalation,
        ),
    ]

    # Second time node reports that must not be sent: deactivate the host instead.
    second_time_node_stage = [
        EscalationPoint(
            predicate=_is_2nd_time_node, reason=_2tn_disabled_for_bad_projects, action=escalate_to_deactivate
        ),
        EscalationPoint(predicate=_is_2nd_time_node, reason=dns_automation_off, action=escalate_to_deactivate),
    ]

    all_points = (
        reboot_stage + profile_stage + redeploy_stage + dangerous_highload_stage + second_time_node_stage
    )
    return EscalationRules(*all_points)
