"""This stage wraps a scheduler stage and implements the NOC-maintenance workflow."""
import logging
import typing as tp

from sepelib.core.exceptions import LogicalError
from walle import audit_log
from walle.authorization import ISSUER_WALLE
from walle.clients import juggler as juggler_client
from walle.hosts import Host, HostState, HostStatus
from walle.models import timestamp
from walle.scenario import maintenance
from walle.scenario.constants import (
    SOFT_NOC_WORK_TIMEOUT,
    NOC_SOFT_WORK_TIMEOUT_JUGGLER_SERVICE_NAME,
    ScenarioWorkStatus,
    ALL_TERMINATION_WORK_STATUSES,
    WORK_STATUS_LABEL_NAME,
    HostScenarioStatus,
    StageName,
    ScenarioFsmStatus,
)
from walle.scenario.marker import Marker
from walle.scenario.mixins import Stage
from walle.scenario.scenario import Scenario
from walle.scenario.stage.scheduler_stage import HostSchedulerStage
from walle.scenario.stage_info import StageAction, StageInfo, StageRegistry
from walle.scenario.stages import ScenarioRootStage

log = logging.getLogger(__name__)


@StageRegistry.register(StageName.NocMaintenanceStage)
class NocMaintenanceStage(ScenarioRootStage):
    """Root stage of the NOC-maintenance workflow; wraps exactly one HostSchedulerStage.

    While the work status label is not READY, every iteration runs the wrapped scheduler. The label is switched
    to READY once the scheduler is in its CHECK action and no hosts remain in a scheduled status. After that the
    stage only waits for NOC works to end and sends a CRIT juggler event when they outlive ``work_timeout``.

    The stage succeeds when:
    - the work status label reaches one of ALL_TERMINATION_WORK_STATUSES;
    - the scenario is being canceled;
    - ``execution_timeout`` expires while some acquired hosts are still not in manual maintenance.
    """

    # ``stage_info.data`` keys holding absolute deadlines computed from ``timestamp()``.
    STAGE_END_DEADLINE = "stage_end_deadline"
    WORK_END_DEADLINE = "work_end_deadline"

    def __init__(
        self,
        children,
        execution_timeout: tp.Optional[int] = None,
        work_timeout: tp.Optional[int] = SOFT_NOC_WORK_TIMEOUT,
        **params
    ):
        """
        :param children: the HostSchedulerStage to wrap, either bare or as a one-element list.
        :param execution_timeout: how long hosts may take to become ready for works, in the same units as
            ``timestamp()``; ``None`` disables this timeout.
        :param work_timeout: how long NOC works may run after the scenario became READY before a CRIT juggler
            event is sent; ``None`` disables the check.
        :raises LogicalError: if ``children`` is not exactly one HostSchedulerStage.
        """
        children = _ensure_only_child_is_a_scheduler(children)
        self.execution_timeout = execution_timeout
        self.work_timeout = work_timeout
        super().__init__(children, execution_timeout=execution_timeout, work_timeout=work_timeout, **params)

    def run(self, stage_info: StageInfo, scenario: Scenario):
        """Advance the NOC-maintenance workflow by one iteration and return a Marker."""
        if self._is_scenario_ready(scenario):
            # Hosts are already handed over for works: only watch the work duration and its completion.
            self._work_prepare(stage_info)
            if self._work_timed_out(stage_info):
                # Sent on every iteration past the deadline; the OK event for the same service
                # is sent from _handle_scenario_cancel.
                juggler_client.send_event(
                    NOC_SOFT_WORK_TIMEOUT_JUGGLER_SERVICE_NAME,
                    juggler_client.JugglerCheckStatus.CRIT,
                    _noc_timeout_message(
                        # The deadline was stored as start + work_timeout; recover the start time.
                        stage_info.data[self.WORK_END_DEADLINE] - self.work_timeout,
                        self.work_timeout,
                        scenario,
                    ),
                )
            if scenario.work_completed_by_workmate():
                scenario.set_works_status_label(ScenarioWorkStatus.FINISHING)

            return Marker.in_progress(message="Waiting for NOC works to end")

        self._prepare(stage_info)

        # Check if NOC maintenance or scenario is cancelled.
        for cancel_check_func in (_work_terminated, _scenario_canceled):
            need_cancel, message = cancel_check_func(scenario)
            if need_cancel:
                self._handle_scenario_cancel(scenario)
                return Marker.success(message=message)

        # The scenario cannot be READY here (that case returned above), so only the deadline matters.
        if self._timed_out(stage_info):
            fqdns = _collect_guilty_hosts(scenario)
            if fqdns:
                reason = _cancel_message_from_timeout(self.execution_timeout, fqdns)
                _reject_works_if_not_ready_yet(scenario, reason)
                scenario.clean_hosts_stage_info()
                return Marker.success(message=reason)

        # Capture the active scheduler action as it was before this iteration runs the children.
        scheduler_action = self._get_scheduler_action(stage_info)
        super().run(stage_info, scenario)

        self._allow_works_if_hosts_ready(scheduler_action, scenario)

        return Marker.in_progress(message="Waiting for child stages")

    @staticmethod
    def _handle_scenario_cancel(scenario):
        """Drop per-host stage info and send an OK juggler event for the NOC work-timeout service."""
        scenario.clean_hosts_stage_info()
        juggler_client.send_event(
            NOC_SOFT_WORK_TIMEOUT_JUGGLER_SERVICE_NAME,
            juggler_client.JugglerCheckStatus.OK,
            (
                f"NOC maintenance stage for scenario "
                f"https://wall-e.yandex-team.ru/scenarios/{scenario.scenario_id} terminated successfully"
            ),
        )

    def _timed_out(self, stage_info: StageInfo):
        """Return True if an execution timeout is configured and its deadline has been reached."""
        return self.execution_timeout is not None and timestamp() >= stage_info.data[self.STAGE_END_DEADLINE]

    def _prepare(self, stage_info: StageInfo):
        """Store the execution deadline once, if an execution timeout is configured."""
        if self.execution_timeout is not None and self.STAGE_END_DEADLINE not in stage_info.data:
            stage_info.data[self.STAGE_END_DEADLINE] = timestamp() + self.execution_timeout

    def _work_prepare(self, stage_info: StageInfo):
        """Store the NOC work deadline once, if a work timeout is configured."""
        if self.work_timeout is not None and self.WORK_END_DEADLINE not in stage_info.data:
            stage_info.data[self.WORK_END_DEADLINE] = timestamp() + self.work_timeout

    def _work_timed_out(self, stage_info: StageInfo):
        """Return True if a work timeout is configured and its deadline has been reached."""
        return self.work_timeout is not None and timestamp() >= stage_info.data[self.WORK_END_DEADLINE]

    @staticmethod
    def _is_scenario_ready(scenario: Scenario):
        """Return True if the scenario work status label is READY (the label is indexed, so it must be set)."""
        return scenario.labels[WORK_STATUS_LABEL_NAME] == ScenarioWorkStatus.READY

    def _allow_works_if_hosts_ready(self, scheduler_action, scenario: Scenario):
        """Switch the work status to READY once the scheduler is checking and no hosts remain scheduled."""
        if self._is_scenario_ready(scenario):
            return

        if scheduler_action != StageAction.CHECK:
            # NOTE: leaky abstraction. We need to differentiate between "haven't tried yet"
            # and "tried at least once and there is nothing to acquire".
            # Maybe add an extra host status, like 'scheduled' vs 'queued',
            # where hosts from the current group become 'queued'.
            return

        if _currently_processing_hosts(scenario.hosts):
            return

        _allow_works_if_not_cancelled(scenario, "All available hosts are ready for works.")

    def _get_scheduler_action(self, stage_info):
        """Return the action type of the currently active child (scheduler) stage."""
        _, active_stage_info = self.iteration_strategy.get_active_stage(stage_info, self.children)
        return active_stage_info.action_type


@StageRegistry.register(StageName.FinishNocMaintenanceStage)
class FinishNocMaintenanceStage(Stage):
    """Move the NOC work status label to its terminal value when the scenario is done.

    Label transitions:
    - CANCELING becomes CANCELED;
    - FINISHING becomes FINISHED;
    - REJECTED, CANCELED and FINISHED are left untouched;
    - any other value is logged as an error and forced to FINISHED.

    The stage always succeeds.
    """

    def run(self, stage_info, scenario: Scenario):
        current_status = scenario.labels[WORK_STATUS_LABEL_NAME]

        if current_status == ScenarioWorkStatus.CANCELING:
            message = f"NOC maintenance scenario {scenario.scenario_id} has been canceled"
            log.info(message)
            scenario.labels[WORK_STATUS_LABEL_NAME] = ScenarioWorkStatus.CANCELED
            return Marker.success(message=message)

        if current_status == ScenarioWorkStatus.FINISHING:
            message = f"NOC maintenance scenario {scenario.scenario_id} has been finished"
            log.info(message)
            scenario.labels[WORK_STATUS_LABEL_NAME] = ScenarioWorkStatus.FINISHED
            return Marker.success(message=message)

        already_terminal = {
            ScenarioWorkStatus.REJECTED,
            ScenarioWorkStatus.CANCELED,
            ScenarioWorkStatus.FINISHED,
        }
        if current_status in already_terminal:
            # Nothing to change; no message to report.
            return Marker.success(message=None)

        message = f"Finishing NOC maintenance scenario {scenario.scenario_id} from incorrect status '{current_status}'"
        log.error(message)
        scenario.labels[WORK_STATUS_LABEL_NAME] = ScenarioWorkStatus.FINISHED
        return Marker.success(message=message)


def _ensure_only_child_is_a_scheduler(child):
    """Normalize ``child`` into a one-element list holding a HostSchedulerStage.

    Arbitrarily nested one-element lists are unwrapped.
    Raises LogicalError if any list level does not have exactly one element,
    or if the innermost value is not a HostSchedulerStage.
    """
    candidate = child
    while isinstance(candidate, list):
        if len(candidate) != 1:
            raise LogicalError
        (candidate,) = candidate

    if not isinstance(candidate, HostSchedulerStage):
        raise LogicalError
    return [candidate]


def _work_terminated(scenario: Scenario) -> tuple[bool, tp.Optional[str]]:
    """Check whether the NOC works reached one of the termination work statuses.

    :return: ``(True, message)`` when the work status label is set and belongs to ALL_TERMINATION_WORK_STATUSES,
        ``(False, None)`` otherwise (including when the label is absent).
    """
    terminal_statuses = set(ALL_TERMINATION_WORK_STATUSES)
    if WORK_STATUS_LABEL_NAME in scenario.labels and scenario.labels[WORK_STATUS_LABEL_NAME] in terminal_statuses:
        return True, _get_message_from_label_status(scenario)
    return False, None


def _scenario_canceled(scenario: Scenario) -> tuple[bool, tp.Optional[str]]:
    """Check whether the scenario itself is being canceled.

    :return: ``(True, "Scenario was canceled")`` when the scenario FSM status is CANCELING,
        ``(False, None)`` otherwise.
    """
    if scenario.status == ScenarioFsmStatus.CANCELING:
        return True, "Scenario was canceled"
    return False, None


def _get_message_from_label_status(scenario: Scenario):
    """Describe the current work status label, e.g. "NOC maintenance has been <status>"."""
    current_status = scenario.labels[WORK_STATUS_LABEL_NAME]
    return f"NOC maintenance has been {current_status}"


def _collect_guilty_hosts(scenario: Scenario) -> list[str]:
    """Return names of acquired, still-scheduled hosts that are not in MAINTENANCE state with MANUAL status.

    Makes one ``Host.get_by_inv`` call per acquired and scheduled host.
    """
    guilty_names = []
    for info in scenario.hosts.values():
        if not info.is_acquired or info.status not in HostScenarioStatus.ALL_SCHEDULED:
            continue
        host = Host.get_by_inv(info.inv)
        in_manual_maintenance = host.state == HostState.MAINTENANCE and host.status == HostStatus.MANUAL
        if not in_manual_maintenance:
            guilty_names.append(host.name)
    return guilty_names


def _cancel_message_from_timeout(execution_time, fqdns):
    return "Not all available hosts are ready for works after required timeout: {}, guilty hosts: {}".format(
        execution_time, ", ".join(fqdns)
    )


def _currently_processing_hosts(host_infos):
    """Return the host infos (values of ``host_infos``) whose status is in HostScenarioStatus.ALL_SCHEDULED."""
    scheduled_statuses = HostScenarioStatus.ALL_SCHEDULED
    return [info for info in host_infos.values() if info.status in scheduled_statuses]


def _allow_works_if_not_cancelled(scenario: Scenario, reason):
    """Request the STARTED -> READY work status transition via _set_maintenance_status.

    What happens when the current status is not STARTED is decided by
    ``maintenance.check_status_transition``; not visible here.
    """
    _set_maintenance_status(
        scenario,
        status=ScenarioWorkStatus.READY,
        allowed_statuses=[ScenarioWorkStatus.STARTED],
        reason=reason,
    )


def _reject_works_if_not_ready_yet(scenario: Scenario, reason):
    """Request the STARTED -> REJECTED transition unless the works are already READY or REJECTED."""
    current_status = scenario.labels[WORK_STATUS_LABEL_NAME]
    if current_status not in {ScenarioWorkStatus.READY, ScenarioWorkStatus.REJECTED}:
        _set_maintenance_status(scenario, ScenarioWorkStatus.REJECTED, [ScenarioWorkStatus.STARTED], reason)


def _set_maintenance_status(scenario: Scenario, status, allowed_statuses, reason):
    """Validate a work status transition, record it in the audit log, then set the label.

    :param status: target work status.
    :param allowed_statuses: statuses the transition is allowed from; checked by
        ``maintenance.check_status_transition``.
    :param reason: human-readable reason stored with the audit log entry.
    """
    maintenance.check_status_transition(scenario, status, allowed_statuses)

    audit_entry = audit_log.on_set_maintenance_status(ISSUER_WALLE, scenario.scenario_id, status, reason=reason)
    audit_entry.complete()

    scenario.set_works_status_label(status)


def _noc_timeout_message(start_time: int, timeout: int, scenario: Scenario) -> str:
    message = """NOC maintenance stage is not finished before timeout {} from start time {}
Scenario: https://wall-e.yandex-team.ru/scenarios/{}""".format(
        timeout, start_time, scenario.scenario_id
    )
    if scenario.ticket_key is not None:
        message = message + "\nTicket: https://st.yandex-team.ru/{}".format(scenario.ticket_key)
    return message
