"""Rules for hw_watcher gpu check."""

import logging

from walle import restrictions
from walle.admin_requests.constants import RequestTypes
from walle.clients.eine import ProfileMode
from walle.expert.decision import Decision
from walle.expert.failure_types import FailureType
from walle.expert.rules.base import CheckRuleInterface
from walle.expert.rules.escalation import (
    EscalationRules,
    EscalationPoint,
    action_match,
    limit_reached,
    escalate_to_report,
    task_has_not_helped,
)
from walle.expert.rules.hw_watcher_rules.util import (
    get_eine_code,
    get_reason_from_hw_watcher,
    operation_match,
    get_hw_watcher_result,
)
from walle.expert.rules.utils import repair_hardware_params
from walle.expert.types import WalleAction, CheckType
from walle.operations_log.constants import Operation
from walle.util.misc import drop_none

# Module-level logger named after this module.
log = logging.getLogger(__name__)


# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably a cap on how many GPU_POWER_UNKNOWN occurrences are tolerated -
# confirm against the code that imports it.
GPU_POWER_UNKNOWN_MAX_COUNT = 10


class GPU_EINE_CODE:
    """GPU problem codes looked up in the eine code of a hw_watcher result.

    Every constant's value is identical to its name. The grouping below mirrors
    how ``CheckGpu.apply`` reacts to the failure type derived from each code.
    """

    # Answered with a hardware repair request.
    GPU_OVERHEAT = "GPU_OVERHEAT"
    GPU_POWER_CAPPING = "GPU_POWER_CAPPING"
    GPU_CAPPING = "GPU_CAPPING"
    GPU_BANDWIDTH_TOO_LOW = "GPU_BANDWIDTH_TOO_LOW"
    GPU_MISSING = "GPU_MISSING"
    GPU_RETIRED_PAGES = "GPU_RETIRED_PAGES"
    GPU_INFOROM_CORRUPTED = "GPU_INFOROM_CORRUPTED"

    # Answered with a host reboot.
    GPU_RETIRED_PAGES_PENDING = "GPU_RETIRED_PAGES_PENDING"
    GPU_P2P_FAILED = "GPU_P2P_FAILED"
    GPU_HANG = "GPU_HANG"

    # Answered with a host profile.
    GPU_POWER_UNKNOWN = "GPU_POWER_UNKNOWN"


def _get_failure_type_request_type(hw_watcher_result):
    """Translate the eine code of a hw_watcher GPU result into a type pair.

    Returns a ``(failure_type, request_type)`` tuple. Both items are ``None``
    when the result carries no (or an empty) eine code, or when none of the
    known GPU codes is found in it. Known codes are probed in a fixed priority
    order and the first one found decides the result.
    """
    eine_code = get_eine_code(hw_watcher_result)
    if not eine_code:
        return None, None

    # The order of this table is the lookup priority: keep it as is.
    # NOTE(review): if get_eine_code() yields a plain string rather than a
    # collection of codes, `in` is a substring test and GPU_RETIRED_PAGES would
    # also match GPU_RETIRED_PAGES_PENDING, which is probed later - confirm.
    prioritized = (
        # The overheat request type is spelled GPU_OVERHEATED, unlike the code.
        (GPU_EINE_CODE.GPU_OVERHEAT, FailureType.GPU_OVERHEAT, RequestTypes.GPU_OVERHEATED),
        (GPU_EINE_CODE.GPU_POWER_CAPPING, FailureType.GPU_POWER_CAPPING, RequestTypes.GPU_POWER_CAPPING),
        (GPU_EINE_CODE.GPU_CAPPING, FailureType.GPU_CAPPING, RequestTypes.GPU_CAPPING),
        (GPU_EINE_CODE.GPU_BANDWIDTH_TOO_LOW, FailureType.GPU_BANDWIDTH_TOO_LOW, RequestTypes.GPU_BANDWIDTH_TOO_LOW),
        (GPU_EINE_CODE.GPU_MISSING, FailureType.GPU_MISSING, RequestTypes.GPU_MISSING),
        (GPU_EINE_CODE.GPU_RETIRED_PAGES, FailureType.GPU_RETIRED_PAGES, RequestTypes.GPU_RETIRED_PAGES),
        (GPU_EINE_CODE.GPU_INFOROM_CORRUPTED, FailureType.GPU_INFOROM_CORRUPTED, RequestTypes.GPU_INFOROM_CORRUPTED),
        (GPU_EINE_CODE.GPU_HANG, FailureType.GPU_HANG, RequestTypes.GPU_HANG),
        (
            GPU_EINE_CODE.GPU_RETIRED_PAGES_PENDING,
            FailureType.GPU_RETIRED_PAGES_PENDING,
            RequestTypes.GPU_RETIRED_PAGES_PENDING,
        ),
        (GPU_EINE_CODE.GPU_P2P_FAILED, FailureType.GPU_P2P_FAILED, RequestTypes.GPU_P2P_FAILED),
        (GPU_EINE_CODE.GPU_POWER_UNKNOWN, FailureType.GPU_POWER_UNKNOWN, RequestTypes.GPU_POWER_UNKNOWN),
    )

    for code, matched_failure, matched_request in prioritized:
        if code in eine_code:
            return matched_failure, matched_request.type

    return None, None


def escalate_gpu_failed_to_repair_hardware(decision, reason) -> Decision:
    """Escalate a GPU decision to a GPU hardware repair.

    Used as the ``action`` of the PROFILE and REBOOT escalation points in
    ``CheckGpu.escalation_rules``.

    :param decision: the decision being escalated; its ``params``, ``checks``
        and ``failure_type`` are carried over to the new decision.
    :param reason: reason for the escalated decision.
    :return: a new REPAIR_HARDWARE decision whose params have ``operation``
        set to the GPU repair operation type and which is restricted by
        ``AUTOMATED_GPU_REPAIR``.
    """
    # Work on a copy: the incoming decision must not be mutated, and the new
    # decision must not share its params dict with the old one. A decision
    # without params is treated as having empty params.
    params = dict(decision.params or {})
    params["operation"] = Operation.REPAIR_GPU.type

    return Decision(
        WalleAction.REPAIR_HARDWARE,
        reason=reason,
        params=params,
        checks=decision.checks,
        failure_type=decision.failure_type,
        restrictions=[restrictions.AUTOMATED_GPU_REPAIR],
    )


class CheckGpu(CheckRuleInterface):
    """Rule that turns a hw_watcher GPU check result into a decision.

    ``apply`` answers the reported failure with one of:

    * PROFILE - for ``GPU_POWER_UNKNOWN``;
    * REBOOT - for ``GPU_RETIRED_PAGES_PENDING``, ``GPU_P2P_FAILED`` and ``GPU_HANG``;
    * REPAIR_HARDWARE - for any other failure that has a request type;
    * a healthy decision - for a problem without a request type.
    """

    check_type = CheckType.GPU

    escalation_rules = EscalationRules(
        # Currently there is no way to tell whether the previous task was
        # created for the same GPU, so creating a new task right after the
        # previous one is not forbidden. The limit below is what keeps us from
        # flooding the ITDC team with repair tasks.
        EscalationPoint(
            predicate=action_match(WalleAction.REPAIR_HARDWARE) & operation_match(Operation.REPAIR_GPU),
            reason=limit_reached("max_host_gpu_failures", Operation.REPAIR_GPU),
            action=escalate_to_report,
        ),
        # Too many profiles: hand the host over to GPU hardware repair.
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE),
            reason=limit_reached("max_host_profiles", Operation.PROFILE, params={"modes": ProfileMode.DEFAULT}),
            action=escalate_gpu_failed_to_repair_hardware,
        ),
        # The profile did not help: hand the host over to GPU hardware repair.
        EscalationPoint(
            predicate=action_match(WalleAction.PROFILE),
            reason=task_has_not_helped("profile didn't help", Operation.PROFILE),
            action=escalate_gpu_failed_to_repair_hardware,
        ),
        # Too many reboots: hand the host over to GPU hardware repair.
        EscalationPoint(
            predicate=action_match(WalleAction.REBOOT),
            reason=limit_reached("max_host_reboots", Operation.REBOOT),
            action=escalate_gpu_failed_to_repair_hardware,
        ),
    )

    def apply(self, host, check_result):
        """Build the decision for a GPU check result.

        :param host: not used by this rule.
        :param check_result: check result; the hw_watcher payload is extracted
            from it and it is attached to the decision as failure check info.
        :return: a PROFILE, REBOOT, REPAIR_HARDWARE or healthy decision.
        """
        watcher_result = get_hw_watcher_result(check_result)

        failure_type, request_type = _get_failure_type_request_type(watcher_result)
        reason = get_reason_from_hw_watcher("GPU", watcher_result["reason"])

        # Unknown GPU power state is answered with a profile.
        if failure_type == FailureType.GPU_POWER_UNKNOWN:
            profile_params = drop_none(
                {
                    "slots": watcher_result.get("slots"),
                    "failure_type": FailureType.GPU_POWER_UNKNOWN.name,
                    "request_type": request_type,
                }
            )
            return self._mk_profile_decision(failure_type, reason, profile_params, check_result)

        # These failures are answered with a reboot.
        reboot_failure_types = (
            FailureType.GPU_RETIRED_PAGES_PENDING,
            FailureType.GPU_P2P_FAILED,
            FailureType.GPU_HANG,
        )
        if failure_type in reboot_failure_types:
            # NOTE(review): unlike the profile branch, the failure type itself
            # (not its `.name`) is stored here - confirm this is intended.
            reboot_params = drop_none(
                {
                    "slots": watcher_result.get("slots"),
                    "failure_type": failure_type,
                    "request_type": request_type,
                }
            )
            return self._mk_reboot_decision(failure_type, reason, reboot_params, check_result)

        if request_type is None:
            # TODO: this branch should not exist, the check is in CRIT status.
            # But there is no BOT admin request type for this yet, it is delayed.
            # https://st.yandex-team.ru/DCA-3885
            unsupported_reason = get_reason_from_hw_watcher("gpu", watcher_result["reason"])
            return Decision.healthy("Unsupported GPU problem, ignoring: {}".format(unsupported_reason))

        # Everything else that has a request type goes to hardware repair.
        return self._mk_repair_decision(failure_type, request_type, watcher_result, reason, check_result)

    def _mk_repair_decision(self, failure_type, request_type, hw_watcher_result, reason, failure_check_info):
        """Return a REPAIR_HARDWARE decision for the GPU described by the hw_watcher result."""
        repair_params = repair_hardware_params(
            request_type=request_type,
            operation_type=Operation.REPAIR_GPU.type,
            eine_code=get_eine_code(hw_watcher_result),
            errors=hw_watcher_result["reason"],
            slot=hw_watcher_result.get("slot"),
            serial=hw_watcher_result.get("serial"),
        )

        return Decision(
            WalleAction.REPAIR_HARDWARE,
            reason=reason,
            params=repair_params,
            checks=[self.check_type],
            failure_type=failure_type,
            restrictions=[restrictions.AUTOMATED_GPU_REPAIR],
            failure_check_info=failure_check_info,
        )

    def _mk_reboot_decision(self, failure_type, reason, oplog_params, failure_check_info):
        """Return a REBOOT decision that carries ``oplog_params`` as its params."""
        reboot_restrictions = [restrictions.AUTOMATED_GPU_REPAIR, restrictions.AUTOMATED_REBOOT]

        return Decision(
            WalleAction.REBOOT,
            reason=reason,
            checks=[self.check_type],
            params=oplog_params,
            failure_type=failure_type,
            restrictions=reboot_restrictions,
            failure_check_info=failure_check_info,
        )

    def _mk_profile_decision(self, failure_type, reason, oplog_params, failure_check_info):
        """Return a PROFILE decision in the default profile mode, extended with ``oplog_params``."""
        # Entries of oplog_params win over the default "profile_mode" key.
        profile_params = dict({"profile_mode": ProfileMode.DEFAULT}, **oplog_params)
        profile_restrictions = [restrictions.AUTOMATED_GPU_REPAIR, restrictions.AUTOMATED_PROFILE]

        return Decision(
            WalleAction.PROFILE,
            reason=reason,
            checks=[self.check_type],
            params=profile_params,
            failure_type=failure_type,
            restrictions=profile_restrictions,
            failure_check_info=failure_check_info,
        )
