"""Finite-state machine for long-running host tasks."""

import logging
import random

import mongoengine

# Load stage processing handlers
# noinspection PyUnresolvedReferences
import walle.fsm_stages.handlers  # noqa
from sepelib.core import config
from sepelib.core.exceptions import Error
from walle import projects
from walle.authorization import ISSUER_WALLE
from walle.base_fsm import BaseFsm
from walle.clients import deploy, ipmiproxy, dmc as dmc_client
from walle.expert import decisionmakers
from walle.expert import dmc
from walle.expert.automation import AutomationDisabledError, healing_automation
from walle.expert.types import WalleAction
from walle.fsm_stages import common, ipmi_errors
from walle.fsm_stages.common import (
    get_current_stage,
    switch_to_stage,
    cancel_task,
    fail_current_stage,
    enter_task,
    cancel_task_considering_operation_limits,
)
from walle.hosts import Host, HostStatus, TaskType, NonIpmiHostError, get_raw_query_for_dmc_rules
from walle.locks import HostInterruptableLock
from walle.models import timestamp
from walle.stages import Stages
from walle.statbox.contexts import exception_context, host_context
from walle.statbox.loggers import fsm_logger
from walle.stats import stats_manager as stats
from walle.util import mongo
from walle.util.misc import StopWatch
from walle.util.mongo import lock as mongo_lock

# Module-level logger for the host FSM.
log = logging.getLogger(__name__)

# Delay in seconds before the next FSM pass is run; the same delay is used to revive the FSM
# after a pass crashes.
_CHECK_INTERVAL = 15
"""Maximum task check interval.

Note: mocked in tests.
"""


class HostFsm(BaseFsm):
    """Finite-state machine for long-running host tasks.

    Each pass walks the numeric shards returned by the partitioner (in random order). For every shard it
    repeatedly picks a host whose task check is due, pushes that host's next check forward in the same
    ``modify()`` call and processes the host in a greenlet spawned on ``self._pool``.
    """

    _name = "Host finite-state machine"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Inventory numbers of hosts whose processing greenlet is still running; they are excluded from
        # subsequent host queries until the greenlet finishes.
        self._invs_in_process = set()

    def stop(self):
        """Log the hosts that are still being processed, then stop the FSM."""
        log.info("Invs in process: %s", self._invs_in_process)
        super().stop()

    def _state_machine(self):
        """Run one scheduling pass and schedule the next one in ``_CHECK_INTERVAL`` seconds.

        An ``Exception`` raised by the pass is logged and swallowed; the next check is scheduled either way.
        """
        log.debug("Checking host tasks...")

        try:
            self._process_hosts()
        except Exception:
            log.exception("%s has crashed:", self._name)
            log.error("%s will be revived in %s seconds.", self._name, _CHECK_INTERVAL)

        self._next_check = timestamp() + _CHECK_INTERVAL

    def _process_hosts(self):
        """Spawn processing greenlets for hosts whose task check is due.

        Returns early when the FSM is being stopped or the greenlet pool is full.
        """
        shards = self._partitioner.get_numeric_shards(self._total_shards_count)
        # Presumably shuffled so that passes cut short (stop / full pool) don't always favour the same shards.
        random.shuffle(shards)
        for shard in shards:
            with shard.lock:
                shard_query = mongo.get_host_mongo_shard_query(shard, self._total_shards_count)
                while True:
                    if self._stop:
                        return
                    if self._pool.full():
                        log.error("%s has reached maximum concurrency limit (%s).", self._name, self._pool.size)
                        return

                    dmc_raw_rules = get_raw_query_for_dmc_rules()

                    query = shard_query & mongoengine.Q(
                        # Skip hosts that already have a processing greenlet in this FSM.
                        inv__nin=list(self._invs_in_process),
                        status__in=HostStatus.ALL_TASK,
                        task__next_check__lte=timestamp(),
                    )
                    # Take the most overdue host and push its next check forward in the same modify() call so it
                    # is not selected again right away. ``uuid`` is projected explicitly because it is read below
                    # and passed on to the processing greenlet.
                    host = (
                        Host.objects(query, __raw__=dmc_raw_rules)
                        .only("inv", "uuid", "tier", "task.task_id")
                        .order_by("task.next_check")
                        .modify(set__task__next_check=timestamp() + common.NEXT_CHECK_ADVANCE_TIME)
                    )
                    if not host:
                        break

                    self._process_host_async(host.inv, host.uuid, host.tier, host.task.task_id)

    def _process_host_async(self, inv, uuid, tier, task_id):
        """Spawn a greenlet processing the host and track it until the greenlet finishes."""
        log.debug("Spawn processing greenlet for host #%s.", inv)
        greenlet = self._pool.spawn(_process_host, inv, uuid, tier, task_id)
        # Registered only after spawn() succeeded, so a failed spawn doesn't leave the host marked as in-process;
        # the linked callback unregisters it when the greenlet finishes.
        self._invs_in_process.add(inv)
        greenlet.link(lambda _finished: self._on_host_processed(inv))

    def _on_host_processed(self, inv):
        """Greenlet completion callback: stop tracking the host as being processed."""
        try:
            self._invs_in_process.remove(inv)
        except Exception:
            log.exception("Processing greenlet for host #%s has crashed:", inv)
        else:
            log.debug("Processing greenlet for host #%s has finished its work.", inv)


def _process_host(inv, uuid, tier, task_id):
    """Process one host's task under a non-blocking host lock.

    :param inv: host inventory number, used only in log messages.
    :param uuid: host UUID used both to build the host lock and to load the host.
    :param tier: host tier, needed to build the host lock.
    :param task_id: ID of the task the host is expected to be under.
    :returns: False if processing crashed with an ``Exception`` (it is logged, not raised); True otherwise,
        including when the host lock is busy and processing is skipped.
    """
    log.debug("Processing host #%s's %s task...", inv, task_id)
    timer = StopWatch()

    try:
        host_lock = HostInterruptableLock(uuid, tier)
        if not host_lock.acquire(blocking=False):
            log.debug("Cancel processing of host #%s: unable to lock the host.", inv)
        else:
            try:
                _handle_host(
                    uuid,
                    task_id,
                    one_pass=False,
                    suppress_no_task_error=True,
                    suppress_internal_errors=(Exception,),
                )
            except mongo_lock.LockIsLostError as lost_error:
                # Re-raise a lost lock as the exception the lock error converts itself into.
                raise lost_error.to_exception()
            finally:
                host_lock.release()
    except Exception:
        log.exception("Processing host #%s's %s task has crashed:", inv, task_id)
        return False
    finally:
        # Iteration time is recorded on every path, crashes included.
        stats.add_sample(("fsm", "one_host", "iteration_time"), timer.get())

    return True


def _handle_host(uuid, task_id, one_pass=True, suppress_no_task_error=False, suppress_internal_errors=tuple()):
    """Load the host and advance its task through the FSM.

    Note: Used in tests. Optional arguments are for unit tests only.

    :param uuid: host UUID.
    :param task_id: ID of the task the host is expected to be under.
    :param one_pass: handle the current stage only once. Otherwise up to 10 iterations are run in a row,
        as long as the task's next check is already due.
    :param suppress_no_task_error: return silently if the host is no longer under ``task_id``.
    :param suppress_internal_errors: exception classes that are logged but not re-raised.
    """
    host = _fetch_host(uuid, task_id, suppress_no_task_error)
    if host is None:
        return

    task_name_before = host.status
    stage_name_before = host.task.stage_name
    iterations = 1 if one_pass else 10

    try:
        for _ in range(iterations):
            project = _fetch_project(host.project)
            # Stop when the task has just been cancelled or the project's FSM handbrake is engaged.
            if _cancel_task_if_needed(host, project) or _fsm_on_handbrake(project):
                break

            _handle_current_stage(host)

            # Stop when the task is gone or doesn't need to be checked again right now.
            if host.task is None or host.task.next_check > timestamp():
                break
        else:
            # Every iteration was used and the task is still due: if the stage didn't change, we're spinning.
            if not one_pass and stage_name_before == host.task.stage_name:
                raise Error("Possible infinite FSM loop detected.")
    except Exception as error:
        _handle_stage_handling_exception(host, error, task_name_before, suppress_internal_errors)


def _handle_stage_handling_exception(host, error, task_name, suppress_internal_errors):
    """Log an error raised while processing the host's task and postpone the task's next check.

    :param host: the host whose task was being processed.
    :param error: the exception that was caught.
    :param task_name: task name for the log message; the current stage name is appended when known.
    :param suppress_internal_errors: tuple of exception classes that must not be re-raised.
    :raises: ``error`` itself, unless it is an ``AutomationDisabledError`` or an instance of
        ``suppress_internal_errors``.
    """
    error_string = _exception_to_string(error)
    if host.task is not None and host.task.stage_name is not None:
        task_name += ":" + host.task.stage_name

    # A disabled automation is reported as a warning and is never re-raised.
    internal_error = not isinstance(error, AutomationDisabledError)
    if not internal_error:
        log.warning("%s: Failed to process '%s' task: %s", host.human_id(), task_name, error_string)
    elif isinstance(error, deploy.DeployClientError):
        log.error("%s: LUI/DHCP API error during processing '%s' task: %s", host.human_id(), task_name, error_string)
    else:
        log.error(
            "%s: {} during processing '%s' task: %s".format(type(error).__name__),
            host.human_id(),
            task_name,
            error_string,
            # sepelib ``Error``s are logged without a traceback; anything else gets one. The exception is passed
            # explicitly so this doesn't depend on being called while an exception is being handled.
            exc_info=None if isinstance(error, Error) else error,
        )

    if host.task is not None:
        # Advance the next check so that this error doesn't make the FSM loop on the host.
        _advance_host_next_check_time(host, error_string)

    # Re-raise the error object explicitly: a bare ``raise`` here only works while the caller is still inside
    # its ``except`` block and fails with "No active exception to reraise" otherwise.
    if internal_error and not isinstance(error, suppress_internal_errors):
        raise error


def _advance_host_next_check_time(host, error_string):
    """Postpone the next check of the host's task by ``common.ERROR_CHECK_PERIOD`` and record the error.

    The update only applies while the host is still under the same task and its next check is earlier than
    the new one. Failures of the update are logged and swallowed (best effort).
    """
    # Advance next check time to not get into an infinite loop on this error.
    retry_at = timestamp() + common.ERROR_CHECK_PERIOD
    try:
        same_task_due_earlier = Host.objects(
            inv=host.inv,
            task__task_id=host.task.task_id,
            task__next_check__lt=retry_at,
        )
        same_task_due_earlier.update_one(
            unset__task__status_message=True,
            set__task__error=error_string,
            set__task__next_check=retry_at,
        )
    except Exception as update_error:
        log.error("Unable to advance host next check time: %s.", update_error)


def _exception_to_string(error):
    try:
        error_string = str(error)
    except UnicodeEncodeError:
        error_string = "[failed to format the error to string]"
    return error_string


def _handle_current_stage(host):
    """Run the handler of the host task's current stage and report the outcome to statbox.

    A pending task (no current stage yet) is entered first. A parent stage is never handled itself: the FSM
    switches to its children. Non-IPMI and IPMI host errors from the handler are handled here; any other
    exception is logged to statbox and re-raised.
    """
    if host.task.stage_uid is None:
        log.info("%s: Initiate processing of '%s' task...", host.human_id(), host.status)
        enter_task(host)

    current = get_current_stage(host)
    if current.stages:
        # Descend into the parent stage's children.
        current = switch_to_stage(host, current)

    statbox = fsm_logger(walle_action="fsm_handle_current_stage", **host_context(host))
    log.info("%s: '%s' task: Handling '%s' stage...", host.human_id(), host.status, current.description)
    statbox.log()

    try:
        common.handle_current_stage(host)
    except (NonIpmiHostError, ipmiproxy.IpmiHostMissingError, ipmiproxy.HostHwError) as error:
        statbox.log(fsm_result="error", **exception_context())
        # Dispatch in the same order the error types are listed above (first match wins).
        if isinstance(error, NonIpmiHostError):
            fail_current_stage(host, error)
        elif isinstance(error, ipmiproxy.IpmiHostMissingError):
            ipmi_errors.handle_ipmi_host_missing_error(host, error.ipmi_fqdn, error)
        else:
            ipmi_errors.handle_ipmi_error(host, error)
    except BaseException:
        statbox.log(fsm_result="error", **exception_context())
        raise
    else:
        statbox.log(fsm_result="ok")


def _fetch_host(uuid, task_id, suppress_no_task_error=False):
    """Load the host with the given UUID, provided it is still under the given task.

    :returns: the host, or None if it is not under ``task_id`` anymore and ``suppress_no_task_error`` is set.
    :raises mongoengine.DoesNotExist: if no such host is found and ``suppress_no_task_error`` is not set.
    """
    try:
        host = Host.objects.get(uuid=uuid, task__task_id=task_id)
    except mongoengine.DoesNotExist:
        if not suppress_no_task_error:
            raise
        log.warning("Cancel processing of #%s host: it's not under %s task already.", uuid, task_id)
        return None
    return host


def _fetch_project(project_id):
    """Load project ``project_id`` with only the ``healing_automation`` and ``fsm_handbrake`` fields."""
    wanted_fields = ("healing_automation", "fsm_handbrake")
    return projects.get_by_id(project_id, fields=wanted_fields)


def _fsm_on_handbrake(project):
    """Tell whether the project's FSM emergency brake is currently engaged.

    The handbrake counts as engaged while its ``timeout_time`` hasn't passed. The timestamp is checked here
    instead of waiting for the expired handbrake to be garbage-collected.
    """
    handbrake = project.fsm_handbrake
    if handbrake is None or handbrake.timeout_time < timestamp():
        return False

    log.info("Project %s has it's emergency brake engaged.", project.id)
    return True


def _cancel_task_if_needed(host, project):
    """Cancel Wall-E's own automated healing task if it is no longer needed.

    Only tasks that haven't done anything destructive yet are considered. Such a task is cancelled when:
    healing automation is disabled for the project; the fresh decision says the host is healthy; or the fresh
    decision is driven by higher-priority checks.

    :returns: True if one of the checks cancelled the task, False otherwise.
    """
    # TODO: Do something with this ugly hackaround
    # Cancelling early reduces the number of destructive actions done when automation gets disabled
    # automatically, or when the host was only temporarily unavailable and recovered before CMS allowed
    # the action to be processed.
    if not (_is_healing_task(host) and _can_cancel_task_at_this_stage(host)):
        # We only cancel automated healing.
        return False

    if _cancel_for_disabled_automation(host, project):
        return True

    _, decision = dmc_client.get_decisions_from_handler(host, decision_params=dmc_client.DecisionParams(set()))
    return bool(_cancel_for_healthy_host(host, decision) or _cancel_for_higher_priority_failure(host, decision))


def _is_healing_task(host):
    """Tell whether the host's task is an automated healing task issued by Wall-E itself."""
    task = host.task
    is_automated_healing = task.type == TaskType.AUTOMATED_HEALING
    return is_automated_healing and task.owner == ISSUER_WALLE


def _can_cancel_task_at_this_stage(host):
    """Tell whether the host's task is still early enough to be cancelled.

    True for a pending task. Also true for a task whose current stage is its first stage and is the
    acquire-permission stage, i.e. it hasn't done anything destructive yet.
    """
    task = host.task
    if task.stage_uid is None:
        # Pending task.
        return True

    # Task that has processed only the acquire permission stage.
    return task.stage_name == Stages.ACQUIRE_PERMISSION and task.stage_uid == task.stages[0].uid


def _cancel_for_disabled_automation(host, project):
    """Cancel the host's task if the project's healing automation settings check fails.

    :returns: True if ``check_automation_settings`` raised ``AutomationDisabledError`` and the task was
        cancelled with that error's message; False otherwise. The False is explicit, like in the other
        ``_cancel_for_*`` helpers, instead of an implicit None.
    """
    automation = healing_automation(host.project)
    try:
        automation.check_automation_settings(project)
    except AutomationDisabledError as e:
        cancel_task(host, str(e))
        return True
    return False


def _cancel_for_healthy_host(host, decision):
    """Cancel the host's task if the fresh decision says the host is healthy.

    :returns: the result of ``cancel_task_considering_operation_limits`` for a healthy host, False otherwise.
    """
    host_is_healthy = decision.action == WalleAction.HEALTHY
    return cancel_task_considering_operation_limits(host) if host_is_healthy else False


def _has_several_check_type_priorities() -> bool:
    """Tell whether more than one entry is configured in ``expert_system.check_type_priorities``."""
    priorities = config.get_value("expert_system.check_type_priorities")
    return len(priorities) > 1


def _cancel_for_higher_priority_failure(host: Host, new_decision: dmc.Decision) -> bool:
    """Cancel the host's task if a fresh decision is driven by higher-priority checks.

    This is only done when several check type priorities are configured, the task has a stored decision, and
    the host still allows cancelling its task for a higher-priority failure. The task is cancelled if two
    conditions hold: the fresh decision's checks differ from the stored decision's checks, and
    ``decisionmakers.get_checks_priority`` gives the fresh checks a greater priority value.

    :returns: True if the task was cancelled, False otherwise.
    """
    if not _has_several_check_type_priorities():
        return False
    if not host.task or not host.task.decision:
        return False
    if not host.may_cancel_task_for_higher_priority_failure():
        # NOTE(rocco66): we should not try to process failure with higher priority after this point
        # because that might be a consequence of healing process
        return False
    old_decision = dmc.Decision(**host.task.decision.to_dict())
    if new_decision.checks != old_decision.checks:
        new_checks_priority = decisionmakers.get_checks_priority(new_decision.checks)
        old_checks_priority = decisionmakers.get_checks_priority(old_decision.checks)
        if new_checks_priority > old_checks_priority:
            cancel_task(host, f"Failure with higher priority checks was happened ({new_decision.checks})")
            return True
    return False
