"""Changes a corrupted disk."""

import copy
import logging

import walle.admin_requests.request as admin_requests
import walle.tasks
import walle.util.tasks
from sepelib.core.exceptions import LogicalError
from walle import audit_log
from walle import restrictions
from walle.authorization import ISSUER_WALLE
from walle.expert import dmc
from walle.expert.automation import healing_automation
from walle.expert.decision import Decision
from walle.expert.types import WalleAction, CheckType
from walle.fsm_stages import hw_errors
from walle.fsm_stages.common import (
    register_stage,
    commit_stage_changes,
    complete_current_stage,
    complete_parent_stage,
    get_parent_stage,
    get_stage_by_uid,
    get_current_stage,
)
from walle.fsm_stages.constants import StageStatus
from walle.operations_log.constants import Operation
from walle.stages import Stages

log = logging.getLogger(__name__)


class ChangeDiskStageHandler(hw_errors.HwErrorsStageHandler):
    """Stage handler for ``Stages.CHANGE_DISK``: requests replacement of a corrupted disk.

    Builds on ``hw_errors.HwErrorsStageHandler``; the ``handle_*`` methods below are presumably hooks
    invoked by that base class (TODO confirm against hw_errors).

    A disk is identified by its slot when the decision carries a non-None ``slot``, otherwise by its
    ``serial``. The same rule is used both when deciding whether this stage can handle the decision
    and when creating the admin request / logging the operation.
    """

    operation = "disk change"
    check = CheckType.DISK

    def handle_action(self, decision):
        """Handle a decision produced for the host.

        * ``CHANGE_DISK``: the task is upgraded in two cases.
          - If the decision identifies no disk (neither slot nor serial), it becomes a profiling-based task.
          - If the decision wants a redeploy this stage was not configured for, it becomes a redeploying task.
          Otherwise the decision data is stored in the stage's temp data and the admin request is created.
        * ``DEACTIVATE``: hand the host over to dmc deactivation.
        * ``REBOOT`` / ``REDEPLOY``: schedule the corresponding task in place of the current one.

        :raises LogicalError: for any other decision action.
        """
        host = self._host

        if decision.action == WalleAction.CHANGE_DISK:
            # Decision params may be absent; treat that as "no disk identification".
            params = decision.params or {}

            if params.get("slot") is None and params.get("serial") is None:
                _upgrade_task(
                    host, decision, "The task is upgraded to task which fixes disk errors via host profiling."
                )
            elif params["redeploy"] and not self._stage.get_param("redeploy"):
                _upgrade_task(
                    host, decision, "The task is upgraded to task which redeploys the host after changing the disk."
                )
            else:
                self._stage.set_temp_data("decision_params", params)
                self._stage.set_temp_data("decision_reason", decision.reason)
                self._stage.set_temp_data("redeploy", params["redeploy"])

                self.handle_create()
        elif decision.action == WalleAction.DEACTIVATE:
            dmc.handle_host_deactivation(healing_automation(host.project), host, decision.reason)
        elif decision.action == WalleAction.REBOOT:
            walle.tasks.schedule_reboot(
                ISSUER_WALLE,
                walle.tasks.TaskType.AUTOMATED_HEALING,
                host,
                decision=decision,
                reason=decision.reason,
                from_current_task=True,
            )
        elif decision.action == WalleAction.REDEPLOY:
            walle.tasks.schedule_redeploy(
                ISSUER_WALLE,
                walle.tasks.TaskType.AUTOMATED_HEALING,
                host,
                decision=decision,
                reason=decision.reason,
                from_current_task=True,
            )
        else:
            raise LogicalError()

    def handle_create(self):
        """Create the corrupted-disk admin request and start waiting for the DC.

        The method works from the decision params/reason that ``handle_action`` stored in the temp data.
        It stores the resulting request and ticket ids in the stage's temp data. It configures the linked
        ``LOG_COMPLETED_OPERATION`` stage to log a ``CHANGE_DISK`` operation for the disk. Finally it commits
        the stage with status ``HW_ERRORS_WAITING_DC``.

        :raises LogicalError: if the stage referenced by ``log_operation_stage_uid`` is not a
            ``LOG_COMPLETED_OPERATION`` stage.
        """
        host = self._host
        stage = self._stage

        # TODO: We need stages to be more dynamic here
        log_operation_stage = get_stage_by_uid(host, stage.get_param("log_operation_stage_uid"))
        if log_operation_stage.name != Stages.LOG_COMPLETED_OPERATION:
            raise LogicalError()

        decision_params = stage.get_temp_data("decision_params")
        reason = stage.get_temp_data("decision_reason")

        # Decide how the disk is identified once, so the admin request type and the logged
        # operation params always agree (a "slot" key holding None falls back to the serial).
        slot = decision_params.get("slot")
        if slot is not None:
            request_type = admin_requests.RequestTypes.CORRUPTED_DISK_BY_SLOT
            disk_params = {"slot": slot}
        else:
            request_type = admin_requests.RequestTypes.CORRUPTED_DISK_BY_SERIAL
            disk_params = {"serial": decision_params.get("serial")}

        request_id, ticket_id = admin_requests.create_admin_request(host, request_type, reason, **decision_params)
        stage.set_temp_data(hw_errors.REQUEST_ID_STAGE_FIELD_NAME, request_id)
        stage.set_temp_data(hw_errors.TICKET_ID_STAGE_FIELD_NAME, ticket_id)

        log_operation_stage.set_param("operation", Operation.CHANGE_DISK.type)
        log_operation_stage.set_param("params", disk_params)

        commit_stage_changes(host, status=StageStatus.HW_ERRORS_WAITING_DC, check_now=True)

    def handle_completed(self):
        """Finish the stage once the disk has been changed.

        With ``redeploy`` set in the temp data, only this stage is completed, so the following redeploy
        stages run. Otherwise the parent ``HEAL_DISK`` stage is completed, which skips them.

        :raises LogicalError: if redeploy is not needed and the parent stage is missing or is not ``HEAL_DISK``.
        """
        if self._stage.get_temp_data("redeploy"):
            # If redeploy is needed, complete the stage and go to redeploy stages.
            complete_current_stage(self._host)
        else:
            # If redeploy is not needed, complete the parent disk healing stage,
            # skipping the following redeploy stages.
            parent_stage = get_parent_stage(self._host)

            if parent_stage is not None and parent_stage.name == Stages.HEAL_DISK:
                complete_parent_stage(self._host)
            else:
                raise LogicalError()


def _upgrade_task(host, decision, reason):
    """Replace the host's current task with a new disk change task built from *decision*.

    :param host: host whose task is being upgraded.
    :param decision: decision the new task is scheduled from; its restrictions are checked first.
    :param reason: human-readable explanation, logged and used as the cancellation reason.
    :return: the result of ``dmc.handle_restricted_automated_action`` when the decision's restrictions
        forbid the action, otherwise ``None``.
    """
    log.warning("%s: %s", host.human_id(), reason)

    try:
        restrictions.check_restrictions(host, decision.get_restrictions())
    except restrictions.OperationRestrictedError:
        # The new task is not allowed; let dmc deal with the restricted action instead.
        return dmc.handle_restricted_automated_action(
            healing_automation(host.project), host, decision.reason, cancel=reason
        )

    # Snapshot the host before scheduling so the cancellation hook sees the pre-upgrade state.
    host_before_upgrade = host.copy()
    walle.tasks.schedule_disk_change(
        ISSUER_WALLE,
        walle.tasks.TaskType.AUTOMATED_HEALING,
        host,
        decision,
        decision.reason,
        from_current_task=True,
    )
    walle.util.tasks.on_task_cancelled(ISSUER_WALLE, host_before_upgrade, reason)


def restart_task_with_host_power_off(host, scenario):
    """Reschedule the host's current disk change task with ``power_off=True`` in its decision params.

    The original decision is taken from the current stage's data (``orig_decision``). If the data has
    none, it falls back to ``params["data"]["orig_decision"]`` (WALLE-4057). A copy of it gets three
    additions in its params: ``power_off=True`` and any admin request / ticket ids already stored in the
    stage's temp data. The reason is extended with a note about *scenario*, and a new disk change task of
    the same type is scheduled from the current task.

    :param host: host whose current task is restarted.
    :param scenario: scenario requesting the restart; its ``id`` is mentioned in the new reason.
    :raises LogicalError: if no original decision can be found.
    """
    stage = get_current_stage(host)

    # WALLE-4057 Check for `orig_decision` in task stage's `data`, if not present - look for it in `params[data]`.
    orig_decision_kwargs = stage.get_data("orig_decision", default=None)
    if orig_decision_kwargs is None and stage.has_param("data"):
        orig_decision_kwargs = stage.get_param("data").get("orig_decision", None)
    if orig_decision_kwargs is None:
        raise LogicalError()

    # Copy whichever source was used: the params are modified in place below and must not
    # leak back into the stage's data or params.
    orig_decision_kwargs = copy.deepcopy(orig_decision_kwargs)

    orig_decision_params = orig_decision_kwargs.get("params")
    if not orig_decision_params:
        orig_decision_kwargs["params"] = orig_decision_params = {}

    orig_decision_params["power_off"] = True
    request_id = stage.get_temp_data(hw_errors.REQUEST_ID_STAGE_FIELD_NAME, None)
    if request_id:
        orig_decision_params[hw_errors.REQUEST_ID_STAGE_FIELD_NAME] = request_id
    ticket_id = stage.get_temp_data(hw_errors.TICKET_ID_STAGE_FIELD_NAME, None)
    if ticket_id:
        orig_decision_params[hw_errors.TICKET_ID_STAGE_FIELD_NAME] = ticket_id

    decision = Decision(**orig_decision_kwargs)
    reason = stage.get_param("reason", None)
    restarted_by_scenario_reason = "Restarted by scenario {}".format(scenario.id)
    if reason:
        decision.reason = "{} ({})".format(reason, restarted_by_scenario_reason)
    else:
        decision.reason = restarted_by_scenario_reason

    prev_host = host.copy()
    walle.tasks.schedule_disk_change(
        ISSUER_WALLE,
        host.task.type,
        host,
        decision,
        decision.reason,
        from_current_task=True,
        audit_log_type=audit_log.on_power_off_host,
    )
    # NOTE(review): `reason` is the stage's own param and may be None here, while `decision.reason`
    # holds the combined text - confirm which one on_task_cancelled is meant to receive.
    walle.util.tasks.on_task_cancelled(ISSUER_WALLE, prev_host, reason)


# Attention:
# This stage assumes that it runs inside a parent Stages.HEAL_DISK stage and is followed by redeploy stages.
# When no redeploy is needed, it completes the parent stage in order to skip those redeploy stages.
register_stage(
    Stages.CHANGE_DISK,
    ChangeDiskStageHandler.as_handler(),
    initial_status=StageStatus.HW_ERRORS_PENDING,
)
