"""Finite-state machine for scenario."""

import contextlib
import logging

import mongoengine

import walle.scenario.script  # noqa

# `import walle.scenario.script` above is needed only for its side effect: it loads stage processing handlers.
# noinspection PyUnresolvedReferences
from sepelib.core import config
from sepelib.core.exceptions import Error
from walle import audit_log
from walle.base_fsm import BaseFsm
from walle.locks import ScenarioInterruptableLock, ScenarioAcquiringInterruptableGlobalLock
from walle.models import timestamp
from walle.scenario.constants import (
    ScenarioFsmStatus,
    ScenarioWorkStatus,
    ScriptName,
    ALL_CANCELATION_WORK_STATUSES,
    ALL_FSM_RUNNABLE_SCENARIO_STATUSES,
    WORK_STATUS_LABEL_NAME,
)
from walle.scenario.error_handlers import scenario_root_stage_error_handler
from walle.scenario.marker import MarkerStatus
from walle.scenario.scenario import Scenario
from walle.stats import stats_manager as stats
from walle.util.misc import StopWatch, DummyContextManager
from walle.application import app

log = logging.getLogger(__name__)


NEXT_CHECK_ADVANCE_TIME = 3
"""
An amount of time on which we defer next check every time when attempt to lock a scenario to prevent attempts to lock this
scenario by other FSM.
"""

_CHECK_INTERVAL = 15
"""Maximum task check interval.

Note: mocked in tests.
"""


def _commit_scenario_changes(scenario):
    """Persist ``scenario`` using optimistic locking on its ``revision`` field.

    The in-memory revision is bumped and the child stages' host info is refreshed. The document is then saved with a
    save condition requiring the stored revision to still equal the revision the scenario was loaded with.

    :raises Error: if MongoEngine reports an ``OperationError`` on save (e.g. the save condition did not match,
        presumably because the scenario was modified concurrently). The original error is chained as ``__cause__``.
    """
    expected_revision = scenario.revision
    scenario.revision += 1
    scenario.update_stage_info_hosts_for_all_child_stages()
    try:
        scenario.save(save_condition={"revision": expected_revision})
    except mongoengine.OperationError as err:
        # Chain explicitly so the underlying MongoEngine error stays visible as the cause.
        raise Error("Unable to commit scenario change: it doesn't have an expected state already.") from err


class ScenarioFsm(BaseFsm):
    """Finite-state machine for scenarios.

    Each polling pass selects runnable scenarios from the shards owned by this instance and processes every selected
    scenario in its own greenlet from ``self._pool``.
    """

    _name = "Scenario FSM"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # IDs of scenarios that currently have a processing greenlet; excluded from the selection query.
        self._ids_in_process = set()
        # NOTE(review): a snapshot of the handbrake state taken once at construction time. Nothing in this class reads
        # it (``_run_scenario`` re-checks ``_on_handbrake()`` on every run) — confirm whether external code relies on it.
        self._on_handbrake = _on_handbrake()

    def _state_machine(self):
        """Run one polling pass and schedule the next one ``_CHECK_INTERVAL`` seconds from now.

        Exceptions raised by the pass are logged instead of propagated, so a failing pass does not stop the FSM.
        """
        log.debug("Checking scenario tasks...")

        try:
            self._process_scenarios()
        except Exception:
            log.exception("%s has crashed:", self._name)
            log.error("%s will be revived in %s seconds.", self._name, _CHECK_INTERVAL)

        # The next check is scheduled identically whether the pass succeeded or crashed.
        self._next_check = timestamp() + _CHECK_INTERVAL

    def _process_scenarios(self):
        """Spawn processing greenlets for due, runnable scenarios belonging to this instance's shards.

        All shard locks are held while scenarios are selected and spawned. The pass returns early when the FSM is
        stopping or when the greenlet pool has no free slots.
        """
        total_shards_count = config.get_value("task_processing.shards_num")
        # TODO(rocco66): extract some common code after WALLE-3967
        shards = self._partitioner.get_numeric_shards(total_shards_count)
        if not shards:
            log.warning("There is no shards for scenario fsm")
            return

        # A scenario belongs to a shard when scenario_id mod total_shards_count equals the shard's numeric ID.
        query = mongoengine.Q()
        for shard in shards:
            query |= mongoengine.Q(scenario_id__mod=(total_shards_count, int(shard.id)))

        with contextlib.ExitStack() as stack:
            for shard in shards:
                if self._stop:
                    return
                stack.enter_context(shard.lock)

            query &= mongoengine.Q(
                scenario_id__nin=list(self._ids_in_process),
                next_check_time__lte=timestamp(),
                status__in=ALL_FSM_RUNNABLE_SCENARIO_STATUSES,
            )
            # Materialize the IDs up front: the loop body performs further database calls.
            for scenario in list(Scenario.objects(query).only("scenario_id")):
                if self._stop:
                    return
                if self._pool.full():
                    log.error("%s has reached maximum concurrency limit (%s).", self._name, self._pool.size)
                    return
                # Defer the next check before spawning so the scenario is not re-selected while being processed.
                Scenario.objects(scenario_id=scenario.scenario_id).modify(
                    set__next_check_time=timestamp() + NEXT_CHECK_ADVANCE_TIME
                )
                self._process_scenario_async(scenario.scenario_id)

    def _process_scenario_async(self, scenario_id):
        """Process the scenario in a pool greenlet and track its ID until the greenlet finishes."""
        log.debug("Spawn processing greenlet for scenario with id: #%s.", scenario_id)
        greenlet = self._pool.spawn(_run_scenario, scenario_id)
        self._ids_in_process.add(scenario_id)
        greenlet.link(lambda finished: self._on_scenario_processed(scenario_id, finished))

    def _on_scenario_processed(self, scenario_id, greenlet):
        """Completion callback of a processing greenlet.

        Stops tracking ``scenario_id``. ``greenlet.get(block=False)`` re-raises any exception raised by the processing
        greenlet; such exceptions are logged here rather than propagated. When the greenlet returned a truthy value
        and the scenario's ``next_check_time`` is earlier than the FSM's next check, the next check is moved earlier
        and ``self._main_event`` is set (presumably waking the main loop — see ``BaseFsm``).
        """
        try:
            self._ids_in_process.remove(scenario_id)

            if greenlet.get(block=False):
                try:
                    scenario = Scenario.objects.only("next_check_time").get(scenario_id=scenario_id)
                except mongoengine.DoesNotExist:
                    pass
                else:
                    # ``get`` raises rather than returning None; only an unset next_check_time needs guarding,
                    # since comparing None with a number would raise TypeError.
                    next_check_time = scenario.next_check_time
                    if next_check_time is not None and next_check_time < self._next_check:
                        self._next_check = next_check_time
                        self._main_event.set()
        except Exception:
            log.exception("Processing greenlet for scenario #%s has crashed:", scenario_id)
        else:
            log.debug("Processing greenlet for scenario #%s has finished its work.", scenario_id)


def _on_handbrake():
    """Tell whether the scenario FSM handbrake is currently engaged.

    The handbrake counts as engaged when it is present in the application settings and its ``timeout_time`` is not
    earlier than the current timestamp.
    """
    handbrake = app.settings().scenario_fsm_handbrake
    if handbrake is None:
        return False
    return handbrake.timeout_time >= timestamp()


def _run_scenario(scenario_id):
    """Process one scenario under its per-scenario lock and commit the resulting changes.

    :returns: True once the root stage has been invoked and the scenario changes committed; False when processing is
        skipped because the scenario no longer exists, is already canceled, or the scenario FSM handbrake is engaged.
        The processing greenlet's completion callback uses a truthy result to consider rescheduling the next check.
    :raises Error: propagated from ``_commit_scenario_changes`` when the commit fails.
    """
    with ScenarioInterruptableLock(scenario_id):
        # ``first()`` yields None for a scenario deleted after it was selected (``get()`` would raise instead),
        # which makes the ``is None`` guard below effective.
        scenario = Scenario.objects(scenario_id=scenario_id).first()
        if scenario is None or scenario.status == ScenarioFsmStatus.CANCELED:
            return False

        if _on_handbrake():
            log.info("Scenario %s has it's emergency brake engaged.", scenario.scenario_id)
            return False

        log.debug("Start processing of scenario -> id: %s name: %s", scenario.scenario_id, scenario.name)

        # Only scenarios that are acquiring permission additionally take the global acquiring lock.
        if scenario.get_works_status() == ScenarioWorkStatus.ACQUIRING_PERMISSION:
            global_scenario_lock = ScenarioAcquiringInterruptableGlobalLock
        else:
            global_scenario_lock = DummyContextManager

        with global_scenario_lock():
            main_stopwatch = StopWatch()
            root_stage = scenario.stage_info.deserialize()

            with scenario_root_stage_error_handler(scenario.stage_info, scenario):
                scenario_root_stage_call_method = _get_scenario_root_stage_call_method(scenario)
                scenario_root_stage_call_method(scenario, scenario_id, root_stage)

            log.info("Going to commit stage info")
            _commit_scenario_changes(scenario)
            stats.add_sample(("scenario", "total_processing_time"), main_stopwatch.get())

    log.debug("Scenario: %s, executor has finished", scenario.scenario_id)
    return True


def _call_run(scenario, scenario_id, root_stage):
    """Invoke the root stage's ``run`` handler for ``scenario``.

    The returned marker's message is always copied onto the scenario. When the marker reports success, the scenario
    is marked finished and its hosts are dismissed, both inside the ``audit_log.on_finish_scenario`` context.
    """
    log.info("Calling 'run'")
    stage_info = scenario.stage_info
    run_marker = root_stage.run(stage_info, scenario)
    scenario.message = run_marker.message
    if run_marker.status != MarkerStatus.SUCCESS:
        return

    with audit_log.on_finish_scenario(scenario.issuer, scenario_id):
        scenario.mark_as_finished()
        scenario.dismiss_hosts()


def _call_cleanup(scenario, scenario_id, root_stage):
    """Invoke the root stage's ``cleanup`` handler for ``scenario``.

    When the returned marker reports success, the scenario is marked cancelled and its hosts are dismissed, both
    inside the ``audit_log.on_cancel_scenario`` context. The marker's message is copied onto the scenario afterwards.
    """
    log.info("Calling 'cleanup'")
    stage_info = scenario.stage_info
    cleanup_marker = root_stage.cleanup(stage_info, scenario)

    cleanup_succeeded = cleanup_marker.status == MarkerStatus.SUCCESS
    if cleanup_succeeded:
        with audit_log.on_cancel_scenario(scenario.issuer, scenario_id):
            scenario.mark_as_cancelled()
            scenario.dismiss_hosts()

    # Copied only after the cancellation bookkeeping above has run.
    scenario.message = cleanup_marker.message


def _work_cancelled(scenario):
    """Tell whether the scenario's work-status label holds one of the cancellation work statuses."""
    label_value = scenario.labels.get(WORK_STATUS_LABEL_NAME)
    return label_value in ALL_CANCELATION_WORK_STATUSES


def _cleanup_on_work_cancel_enabled(scenario: Scenario):
    """Tell whether cancelled work should be handled via ``cleanup``; False only for NOC_SOFT scenarios."""
    is_noc_soft = scenario.scenario_type == ScriptName.NOC_SOFT
    return not is_noc_soft


def _scenario_cancelled(scenario):
    """Tell whether the scenario's FSM status is ``CANCELING``."""
    current_status = scenario.status
    return current_status == ScenarioFsmStatus.CANCELING


def _cleanup_on_scenario_cancel_enabled(scenario: Scenario):
    """Tell whether a cancelling scenario should be handled via ``cleanup``; False only for NOC_SOFT scenarios."""
    is_noc_soft = scenario.scenario_type == ScriptName.NOC_SOFT
    return not is_noc_soft


def _get_scenario_root_stage_call_method(scenario):
    """Pick the root-stage handler (``_call_run`` or ``_call_cleanup``) for this processing pass.

    ``_call_cleanup`` is chosen when the work or the scenario itself is being cancelled, unless cleanup is disabled
    for that kind of cancellation; in every other case ``_call_run`` is chosen.
    """
    work_cancelled = _work_cancelled(scenario)
    scenario_cancelled = _scenario_cancelled(scenario)

    if not (work_cancelled or scenario_cancelled):
        return _call_run
    if work_cancelled and not _cleanup_on_work_cancel_enabled(scenario):
        return _call_run
    if scenario_cancelled and not _cleanup_on_scenario_cancel_enabled(scenario):
        return _call_run
    return _call_cleanup
