"""Contains common logic for stage processing."""

import logging
from collections import namedtuple

import walle.expert.dmc
import walle.operations_log.operations as operations_log
import walle.stages
import walle.tasks
from sepelib.core import config, constants
from sepelib.core.exceptions import Error, LogicalError
from walle import audit_log, authorization, constants as walle_constants, projects
from walle.clients import deploy, startrek
from walle.expert.decision import Decision
from walle.expert.types import WalleAction
from walle.hosts import Host, HostStatus, Task, get_host_query, HostState, deploy_configuration
from walle.models import timestamp
from walle.operations_log.constants import Operation
from walle.stages import Stage, Stages, StageTerminals
from walle.util import notifications
from walle.util.limits import parse_timed_limits
from walle.util.tasks import on_task_cancelled, on_finished_task

log = logging.getLogger(__name__)


# Amount of time by which the next check is pushed forward on every attempt to lock a host, so that other FSM
# instances do not try to lock the same host meanwhile (presumably seconds, like the other intervals here).
NEXT_CHECK_ADVANCE_TIME = 20
"""
An amount of time on which we defer next check every time when attempt to lock a host to prevent attempts to lock this
host by other FSM.
"""

# Ten minutes, expressed in seconds.
ADMIN_REQUEST_CHECK_INTERVAL = 10 * constants.MINUTE_SECONDS
"""Interval for checking admin request processing."""

# One minute, expressed in seconds.
MONITORING_PERIOD = constants.MINUTE_SECONDS
"""Host monitoring period."""

# One minute, expressed in seconds.
ERROR_CHECK_PERIOD = constants.MINUTE_SECONDS
"""Stage check period after some internal error."""

# Delay passed as check_after by retry_stage(); also the fallback delay in retry_action() when none is given.
DEFAULT_RETRY_TIMEOUT = 20 * constants.MINUTE_SECONDS
"""Use some reasonable default timeout before trying again. Picked up 20 minutes because twenty's plenty."""

# Status assigned by retry_action() to a parent stage while it is parked waiting for its retry.
PARENT_STAGE_RETRY_STATUS = "retrying"
"""Display this status while waiting for the next retry of the parent stage."""


# Stage registry: stage name -> _StageConfig, filled by register_stage().
_STAGES = {}
# Terminator registry: stage terminal -> terminator callable, filled by _register_terminal().
_TERMINATORS = {}
# Registered configuration of a single stage; see register_stage() for the meaning of each field.
_StageConfig = namedtuple("StageConfig", ["handler", "initial_status", "cancellation_handler", "error_handler"])


def register_stage(name, handler, initial_status=None, cancellation_handler=None, error_handler=None):
    """Add a stage to the FSM stage registry.

    :param name - unique name of the stage.
    :param handler - called as handler(host) by the FSM to process the stage.
    :param initial_status - status assigned to the stage on entering it; may be a callable taking the stage.
    :param cancellation_handler - called as handler(host, stage) on task cancellation to clean things up.
           Attention: it is called only once, all its errors are ignored, and it must not modify the host object,
           because the host is already cancelled in the database by then.
    :param error_handler - called before the default stage handler when the task has an error. It may try to recover
           from the error or terminate the stage, and must return True if the FSM may proceed with the default handler.
    :raises Error - if a stage with this name is already registered.
    """

    new_config = _StageConfig(
        handler=handler,
        initial_status=initial_status,
        cancellation_handler=cancellation_handler,
        error_handler=error_handler,
    )
    registered = _STAGES.setdefault(name, new_config)
    if registered is not new_config:
        raise Error("Unable to register '{}' stage: it already exists.", name)


def generate_stage_handler(status_handlers):
    """Build a stage handler that dispatches on the status of the host's current stage.

    :param status_handlers - mapping of stage status -> handler(host).
    :return: a handler(host) suitable for register_stage(). Calling it raises Error when the current stage's status
             has no entry in ``status_handlers``.
    """

    def status_handler(host):
        current = get_current_stage(host)
        try:
            dispatch = status_handlers[current.status]
        except KeyError:
            raise Error("Got unexpected '{}' status for '{}' stage.", current.status, current.name)
        return dispatch(host)

    return status_handler


def handle_current_stage(host):
    """Run one FSM iteration for the host's current stage.

    If the task carries an error and the stage has an error handler, that handler runs first. A falsy return value
    from it means the default stage handler is skipped for this iteration.
    """
    stage = get_current_stage(host)
    on_error = get_stage_error_handler(stage.name)

    if host.task.error and on_error is not None:
        log.debug("%s: Running error handler for '%s' stage...", host.human_id(), stage.name)
        # Contract (see register_stage): True -> proceed with the default handler; falsy -> stop here.
        # An error handler that cannot recover is expected to raise instead.
        proceed = on_error(host)
        if not proceed:
            return

    get_stage_handler(stage.name)(host)


def get_stage_handler(name):
    """Return the main handler of stage ``name``; raises Error if the stage is not registered."""
    stage_config = _get_stage_config(name)
    return stage_config.handler


def get_stage_error_handler(name):
    """Return the error handler of stage ``name`` (may be None); raises Error if the stage is not registered."""
    stage_config = _get_stage_config(name)
    return stage_config.error_handler


def _get_stage_config(name, only_if_exists=False):
    """Look up the registered configuration of stage ``name``.

    :param only_if_exists - when True, an unknown stage yields None instead of an error.
    :raises Error - if the stage is unknown and ``only_if_exists`` is False.
    """
    found = _STAGES.get(name)
    if found is not None or only_if_exists:
        return found

    raise Error("Stage '{}' is not registered.", name)


def _register_terminal(terminal, terminator):
    """Bind ``terminator`` to ``terminal``; raise Error if a different terminator is already bound to it."""
    already_bound = _TERMINATORS.setdefault(terminal, terminator)
    if already_bound is not terminator:
        message = "Unable to register terminator for exit {}: already have one.".format(terminal)
        raise Error(message)


def _register_terminator(terminal):
    """Decorator factory: register the decorated function as the terminator for ``terminal`` and return it as is."""

    def register(func):
        _register_terminal(terminal, func)
        return func

    return register


def _get_terminator(terminal):
    """Return the terminator registered for ``terminal``, or None if there is none."""
    terminator = _TERMINATORS.get(terminal, None)
    return terminator


def cancel_task(host, reason):
    """Drop the host's automated task and return the host to READY (without downtime) in a single commit.

    :param reason - human-readable explanation; it is logged and passed on to on_task_cancelled().
    """
    log.info("%s: Cancelling automated '%s' task: %s", host.human_id(), host.status, reason)

    issuer = authorization.ISSUER_WALLE
    snapshot = host.copy()

    status_kwargs = Host.set_status_kwargs(
        host.state, HostStatus.READY, issuer, host.task.audit_log_id, downtime=False
    )
    _commit_task_changes(host, unset__task=True, **status_kwargs)

    on_task_cancelled(issuer, snapshot, reason)


def get_limits_for_action(host, limit_name):
    """Return the parsed timed limits ``limit_name`` of the host's project.

    Falls back to the default host limits when the project has no such entry.
    """
    project = host.get_project(fields=("host_limits",))
    try:
        raw_limits = project.host_limits[limit_name]
    except KeyError:
        log.info("Host limits configuration '%s' is missing, using defaults", limit_name)
        raw_limits = projects.get_default_host_limits()[limit_name]

    return parse_timed_limits(raw_limits)


def cancel_task_considering_operation_limits(host):
    """Cancel the host's healing task unless the healing-cancellation limit has been exhausted.

    A successful cancellation is recorded in the operations log as CANCEL_HEALING.

    :return: True if task was cancelled, False if cancellation limit was hit
    """
    limits = get_limits_for_action(host, "max_healing_cancellations")

    if operations_log.check_limits(host, Operation.CANCEL_HEALING, limits):
        host_before_cancel = host.copy()
        cancel_task(host, "The task is not actual anymore - host is healthy.")
        operations_log.on_completed_operation(host_before_cancel, Operation.CANCEL_HEALING.type)
        return True

    log.info(
        "%s: Host is healthy but automated healing task cancellation limit has exceeded for this host.",
        host.human_id(),
    )
    return False


def cancel_host_stages(host, suppress_internal_errors=True):
    """Run cancellation handlers for all stages of the host's task, last stage first and children before parents.

    Nothing happens if the task has not entered any stage yet. Internal errors are logged and swallowed unless
    ``suppress_internal_errors`` is False, in which case they propagate.
    """
    try:
        if host.task.stage_uid is None:
            return
        for stage in reversed(host.task.stages):
            _cancel_stage(host, stage)
    except Exception:
        if not suppress_internal_errors:
            raise
        log.exception("%s: Stage cancellation has crashed:", host.human_id())


def _cancel_stage(host, stage):
    """Cancel ``stage``: first recursively cancel its children (last child first), then run its own handler.

    An unregistered parent stage is skipped silently, while an unregistered leaf stage raises Error. Exceptions raised
    by cancellation handlers are logged and ignored.
    """
    has_children = stage.stages is not None
    if has_children:
        for child in reversed(stage.stages):
            _cancel_stage(host, child)

    stage_config = _get_stage_config(stage.name, only_if_exists=has_children)
    cancel_handler = None if stage_config is None else stage_config.cancellation_handler
    if cancel_handler is None:
        return

    log.debug("%s: Cancelling '%s' stage...", host.human_id(), stage.name)
    try:
        cancel_handler(host, stage)
    except Exception:
        log.exception("%s: '%s' stage cancellation handler has crashed:", host.human_id(), stage.name)


def get_current_stage(host, only_if_exists=False):
    """Return the stage the host's task is currently on.

    :param only_if_exists - return None instead of raising Error when no stage has been entered yet.

    Attention! Be very careful when using the returned object: committing task changes to DB fully updates host object
    with actual data making it to point to new stage objects, so subsequent changes to this object won't be committed
    by subsequent commit function calls.
    """
    stage_uid = host.task.stage_uid
    if stage_uid is None:
        if only_if_exists:
            return None
        raise Error("Can't get current stage: there is no any active stage yet.")

    return walle.stages.get_by_uid(host.task.stages, stage_uid)


def get_parent_stage(host, stage=None):
    """Return the parent of ``stage``, or of the current stage when ``stage`` is not given.

    :raises Error - if ``stage`` is not given and no stage has been entered yet.

    Attention! Be very careful when using the returned object: committing task changes to DB fully updates host object
    with actual data making it to point to new stage objects, so subsequent changes to this object won't be committed
    by subsequent commit function calls.
    """
    if stage is not None:
        stage_uid = stage.uid
    elif host.task.stage_uid is not None:
        stage_uid = host.task.stage_uid
    else:
        raise Error("Can't get parent of the current stage: there is no any active stage yet.")

    return walle.stages.get_parent(host.task.stages, stage_uid)


def get_stage_by_uid(host, stage_uid):
    """Return the stage with ``stage_uid`` from the host task's stage tree (lookup done by walle.stages)."""
    task_stages = host.task.stages
    return walle.stages.get_by_uid(task_stages, stage_uid)


def enter_task(host):
    """Start processing the host's task by switching to its first leaf stage.

    :raises Error - if the task has no stages.
    """
    task_stages = host.task.stages
    if not task_stages:
        raise Error("Logical error: {}: '{}' task contains no stages.", host.human_id(), host.status)

    # Wrap the top-level stages into a synthetic root; switch_to_stage() descends into its first leaf.
    switch_to_stage(host, Stage(name="root", stages=task_stages))


def _hang_on_stage(host, stage, error, status, check_after):
    """Enter ``stage`` with the explicit ``status`` (instead of its initial one) and commit the task.

    The next check is scheduled ``check_after`` seconds from now when ``check_after`` is not None.
    """
    _enter_stage(host, stage, status=status)
    commit_stage_changes(host, check_after=check_after, error=error)


def switch_to_stage(host, stage, error=None, check_after=None, check_now=False, commit_ticket=False):
    """Make ``stage`` the current stage (or its first leaf descendant if it has children) and commit the task.

    :param commit_ticket - also persist ``host.ticket`` in the same commit.
    :return: the new current stage, looked up again after the commit.
    """
    target = stage
    # Parent stages never become current themselves: descend to the first leaf.
    while target.stages:
        target = target.stages[0]

    _enter_stage(host, target)

    extra_fields = ["ticket"] if commit_ticket else None
    commit_stage_changes(
        host, error=error, check_after=check_after, check_now=check_now, extra_fields=extra_fields
    )
    return get_current_stage(host)


def commit_stage_changes(
    host,
    status=None,
    status_message=None,
    error=None,
    extra_fields=None,
    check_after=None,
    check_now=False,
    check_at=None,
):
    """Update the task for the current stage and persist ``task`` (plus ``extra_fields``) to the database.

    The write is conditioned by walle.tasks.host_query() on the task revision from before _change_task_state() bumped
    it; Error is raised if the host no longer matches. Fields whose current value is None are unset, not set to null.
    """
    expected_revision = _change_task_state(
        host,
        stage_status=status,
        status_message=status_message,
        error=error,
        check_after=check_after,
        check_now=check_now,
        check_at=check_at,
    )

    field_names = {"task"} | set(extra_fields or ())

    update_kwargs = {}
    for field_name in field_names:
        field_value = getattr(host, field_name)
        if field_value is None:
            update_kwargs["unset__" + field_name] = True
        else:
            update_kwargs["set__" + field_name] = field_value

    _commit_task_changes(host, revision=expected_revision, **update_kwargs)


def push_host_ticket(host, ticket_key):
    """Make ``ticket_key`` the host's current ticket, remembering earlier tickets in the current stage's data.

    If the host already has a ticket that is not yet in the stage's ``tickets`` list, it is saved as ``host_ticket``
    (so _restore_host_ticket() can put it back) and added to the list. When the host has a ticket, the new ticket is
    passed to startrek.link_tickets() together with the known ones. The change is recorded as a hardware-repair
    audit log entry.

    :return: False if ``ticket_key`` was already pushed on this stage, True otherwise.
    """
    stage = get_current_stage(host)
    known_tickets = stage.get_data("tickets", [])
    if ticket_key in known_tickets:
        return False

    with audit_log.on_hardware_repair(
        authorization.ISSUER_WALLE,
        host.project,
        host.inv,
        host.name,
        host.uuid,
        reason=ticket_key,
        scenario_id=host.scenario_id,
        itdc_ticket=ticket_key,
    ) as log_entry:
        current_ticket = host.ticket
        if current_ticket:
            if current_ticket not in known_tickets:
                stage.set_data("host_ticket", current_ticket)
                known_tickets.append(current_ticket)
            startrek.link_tickets(ticket_key, known_tickets, silent=True)

        stage.set_data("tickets", known_tickets + [ticket_key])
        host.ticket = ticket_key
        log_entry.complete()
        # Returned from inside the `with` on purpose: same behavior if the context manager suppresses an exception.
        return True


def _restore_host_ticket(host, stage):
    if stage.has_data("host_ticket"):
        host.ticket = stage.get_data("host_ticket")
        return True
    elif stage.has_data("tickets"):
        del host.ticket
        return True

    return False


def _get_persistent_error(host, stage):
    """Return the ``stage_error`` of ``stage`` or of its nearest ancestor that has one; None if there is none."""
    current = stage
    while True:
        if current.has_data("stage_error"):
            return current.get_data("stage_error")

        current = walle.stages.get_parent(host.task.stages, current.uid)
        if current is None:
            return None


def _set_persistent_error(stage, error):
    if error is not None:
        stage.set_data("stage_error", error)


def increase_error_count(host, stage, name, limit, error, prefix_error=True, fail_stage=True):
    """Increment the stage counter ``name``; once it reaches ``limit``, optionally fail the current stage.

    :param prefix_error - wrap ``error`` into a "Too many errors..." message before failing the stage.
    :param fail_stage - fail the current stage when the limit is reached.
    :return: True while the counter is below ``limit``, False once it has reached it (whether or not the stage
             was failed).
    """
    incremented = stage.get_data(name, 0) + 1
    error_count = stage.set_data(name, incremented)
    if error_count < limit:
        return True

    if not fail_stage:
        return False

    if prefix_error:
        error = "Too many errors occurred during processing '{}' stage of '{}' task. Last error: {}".format(
            stage.name, host.status, error
        )
    fail_current_stage(host, error)
    return False


def increase_configurable_error_count(host, stage, name, limit_name, error, fail_stage=True, prefix_error=True):
    """Same as increase_error_count(), but the limit is read from config under ``limit_name``.

    :param limit_name - config key passed to config.get_value() to obtain the error limit.
    :param fail_stage - fail the current stage when the limit is reached.
    :param prefix_error - wrap ``error`` into a "Too many errors..." message before failing the stage; forwarded to
           increase_error_count() (trailing keyword so existing positional callers keep working).
    :return: True while the counter is below the limit, False once it has reached it.
    """
    limit = config.get_value(limit_name)
    return increase_error_count(host, stage, name, limit, error, prefix_error=prefix_error, fail_stage=fail_stage)


def complete_current_stage(host):
    """Finish the current stage through its SUCCESS terminal (which the stage's ``terminators`` may remap)."""
    success = StageTerminals.SUCCESS
    terminate_current_stage(success, host)


# TODO: move terminators out of here
@_register_terminator(StageTerminals.COMPLETE_PARENT)
def complete_parent_stage(host, stage=None, message=None, **kwargs):
    """Terminator for COMPLETE_PARENT: leave ``stage`` (default: the current one) and finish its parent successfully.

    The parent is terminated through its SUCCESS terminal, so unless the parent remaps that terminal, processing
    continues after the parent and the remaining siblings of ``stage`` are skipped.

    :raises Error - if the stage has no parent.
    """
    if message:
        log.warning("%s: %s Skipping stage.", host.human_id(), message)

    child = get_current_stage(host) if stage is None else stage
    parent = get_parent_stage(host, child)
    if parent is None:
        raise Error("Unable to complete parent stage: the current stage doesn't have a parent stage.")

    _leave_stage(host, child)
    _terminate_stage(StageTerminals.SUCCESS, host, parent)


# TODO: move terminators out of here
@_register_terminator(StageTerminals.SKIP)
@_register_terminator(StageTerminals.NO_ERROR_FOUND)
def skip_stage(host, stage, message=None, **kwargs):
    """Terminator for SKIP / NO_ERROR_FOUND: log ``message`` (if any), then complete ``stage`` as if it succeeded."""
    if message:
        log.warning("%s: %s Skipping stage.", host.human_id(), message)

    result = complete_stage(host, stage)
    return result


# TODO: move terminators out of here
@_register_terminator(StageTerminals.SUCCESS)
def complete_stage(host, stage, ticket_restored=False, **kwargs):
    """Terminator for SUCCESS: finish ``stage`` and move on.

    Processing continues with the next sibling. If there is none, the parent is completed, and if there is no parent
    the whole task is completed. Ticket changes pushed on each finished stage are rolled back; ``ticket_restored``
    tracks whether ``host.ticket`` changed, so that the next stage switch commits it too.
    """
    next_stage = walle.stages.get_next(host.task.stages, stage.uid)
    ticket_restored = _restore_host_ticket(host, stage) or ticket_restored
    _leave_stage(host, stage)

    if next_stage is not None:
        switch_to_stage(host, next_stage, check_now=True, commit_ticket=ticket_restored)
        return

    parent_stage = walle.stages.get_parent(host.task.stages, stage.uid)
    if parent_stage is None:
        complete_task(host)
    else:
        complete_stage(host, parent_stage, ticket_restored)


def retry_current_stage(host, error=None, persistent_error=None, check_after=None, check_now=False):
    """Restart the current stage from its initial status.

    :param error - error to report for this iteration.
    :param persistent_error - if given, saved on the stage as ``stage_error``; _change_task_state() falls back to it
           whenever a commit is made without an explicit error.
    """
    stage = get_current_stage(host)

    _leave_stage(host, stage)
    _set_persistent_error(stage, persistent_error)

    switch_to_stage(host, stage, error=error, check_after=check_after, check_now=check_now)


def retry_parent_stage(host, error=None, check_after=None):
    """Retry the parent of the current stage via retry_action(), with optional delay and optional error message.

    The error will persist through subsequent retries until all children stages complete successfully.
    ``check_after`` is forwarded to retry_action() as is.
    """
    stage = get_current_stage(host)
    retry_action(host, stage, error=error, check_after=check_after)


def complete_task(host):
    """Finish the host's task and move the host to the task's target status (or the default one for its state).

    Attention: Be very careful using this function: it expects that all task's stages have been processed
    (including reset-health-status stage) and doesn't reset health status again (with bigger min check times
    because of possible monitoring stage).
    """

    log.info("%s: '%s' task has been successfully completed.", host.human_id(), host.status)
    prev_host = host.copy()
    finished_task = prev_host.task

    target_status = host.task.target_status or HostStatus.default(host.state)
    # A status reason is kept only for non-READY target statuses.
    reason = None if target_status == HostStatus.READY else host.status_reason

    # Host in [maintenance, probation] should be on downtime until it leaves that state
    keep_downtime = finished_task.keep_downtime or host.state in HostState.ALL_DOWNTIME

    update_kwargs = Host.set_status_kwargs(
        host.state,
        target_status,
        finished_task.owner,
        finished_task.audit_log_id,
        reason=reason,
        downtime=keep_downtime,
    )
    update_kwargs["unset__task"] = True
    if target_status == HostStatus.READY:
        # Becoming READY also drops the host's ticket.
        update_kwargs["unset__ticket"] = True

    _commit_task_changes(host, **update_kwargs)

    audit_log.complete_task(finished_task)
    notifications.on_task_completed(prev_host)
    on_finished_task(prev_host, keep_downtime=keep_downtime)


def fail_current_stage(host, error):
    """Terminate the current stage through its FAIL terminal, handing ``error`` to the terminator."""
    failure = StageTerminals.FAIL
    terminate_current_stage(failure, host, error)


# TODO: move terminators out of here
@_register_terminator(StageTerminals.FAIL)
def fail_stage(host, stage, error, **kwargs):
    """Terminator for FAIL: currently equivalent to terminating the task by deactivating (marking dead) the host.

    If the current stage belongs to a DEPLOY parent and the host is provisioned by LUI, a best-effort deploy
    deactivation and IPMI reset are attempted first.

    :raises AutomationDisabledError so all callers must take it into account.
    """
    del stage, kwargs

    reason = str(error)
    log.warning("%s: %s Consider host as dead.", host.human_id(), reason)

    parent_stage = get_parent_stage(host)
    # TODO: Implement cascading stage failing (from children to parents) with on_failure handlers in each stage to
    # eliminate such hackarounds
    if (
        parent_stage is not None
        and parent_stage.name == Stages.DEPLOY
        and host.provisioner == walle_constants.PROVISIONER_LUI
    ):
        _abort_lui_deploy(host)

    prev_host = host.copy()
    deactivate(host, reason=reason)

    audit_log.fail_task(prev_host.task, reason)
    notifications.on_dead_host(prev_host, reason, cc_user=prev_host.task.owner)
    on_finished_task(prev_host)


def _abort_lui_deploy(host):
    """Best-effort cleanup for a host failed during a LUI deploy; failures are only logged.

    Deactivates the host in the deployment system, then resets it via IPMI (or powers it on if it is off).
    """
    log.warning("Deactivating %s...", host.human_id())

    try:
        provider = deploy.get_deploy_provider(host.get_eine_box())
        deploy.get_client(provider).deactivate(host.name)
    except Exception as e:
        log.warning("Failed to deactivate %s in the deployment system: %s", host.human_id(), e)

    try:
        ipmi = host.get_ipmi_client()
        if ipmi.is_power_on():
            ipmi.reset()
        else:
            ipmi.power_on()
    except Exception as e:
        log.warning("Failed to reset %s: %s", host.human_id(), e)


# TODO: move terminators out of here
@_register_terminator(StageTerminals.RETRY)
def retry_stage(host, stage, error, **kwargs):
    """Terminator for RETRY: re-enter ``stage`` with ``error`` and check it again after DEFAULT_RETRY_TIMEOUT."""
    _leave_stage(host, stage)
    switch_to_stage(host, stage, check_after=DEFAULT_RETRY_TIMEOUT, error=error)


@_register_terminator(StageTerminals.RETRY_ACTION)
def retry_action(host, stage, error=None, **kwargs):
    """Retry action, part of which is performed by the current stage, by restarting its parent stage after a delay.

    The parent stage is parked in PARENT_STAGE_RETRY_STATUS until its next check. The error is saved on the parent as
    a persistent error, so it persists through subsequent retries until all children stages complete successfully.

    :param check_after - optional keyword argument: delay in seconds before the retry; missing or None means
           DEFAULT_RETRY_TIMEOUT.
    :raises Error - if ``stage`` has no parent stage.
    """
    parent_stage = walle.stages.get_parent(host.task.stages, stage.uid)
    if parent_stage is None:
        raise Error("Can not retry parent stage because current stage does not have a parent.")

    _leave_stage(host, stage)
    _leave_stage(host, parent_stage)
    _set_persistent_error(parent_stage, error)

    check_after = kwargs.get("check_after")
    if check_after is None:
        # retry_parent_stage() passes check_after=None by default; with kwargs.get(..., default) that explicit None
        # disabled the delay entirely (the commit then leaves next_check stale), so treat None as "use the default".
        check_after = DEFAULT_RETRY_TIMEOUT

    _hang_on_stage(host, parent_stage, error=error, status=PARENT_STAGE_RETRY_STATUS, check_after=check_after)


# TODO: move terminators out of here
@_register_terminator(StageTerminals.HIGHLOAD_AND_REDEPLOY)  # backward compatibility, WALLE-3342
@_register_terminator(StageTerminals.DISK_RW_AND_REDEPLOY)
def _upgrade_to_disk_rw_profile_with_redeploy(host, stage, error, **kwargs):
    """Terminator: schedule a DISK_RW_REDEPLOY action for the host via walle.expert.dmc.schedule_action().

    If the project has no profile or no VLAN scheme configured, the stage is failed instead.
    """
    log.warning("%s: Upgrading the task to run highload profile: %s", host.human_id(), error)

    project = host.get_project(fields=("id", "profile", "vlan_scheme"))
    if project.profile is None or project.vlan_scheme is None:
        failure = (
            "Operation  failed: {}. Can't upgrade task to run highload profile: "
            "project is not configured for automated profile.".format(error)
        )
        return fail_stage(host, stage, error=failure)

    reason = "Upgrading the task to run highload profile: {}".format(error)
    ignore_cms = host.task.ignore_cms

    walle.expert.dmc.schedule_action(
        host,
        Decision(WalleAction.DISK_RW_REDEPLOY, reason),
        reason,
        from_monitoring=True,
        issuer=host.task.owner,
        ignore_cms=ignore_cms,
    )


# TODO: move terminators out of here
@_register_terminator(StageTerminals.PROFILE)
def _upgrade_to_profile_task(host, stage, error, **kwargs):
    """Terminator for PROFILE: schedule a PROFILE action for the host via walle.expert.dmc.schedule_action()."""
    del stage, kwargs

    log.warning("%s: Upgrading the task to 'profile': %s", host.human_id(), error)
    reason = "Upgrading the task to 'profile': {}".format(error)
    ignore_cms = host.task.ignore_cms

    walle.expert.dmc.schedule_action(
        host,
        Decision(WalleAction.PROFILE, reason, params={}),
        reason,
        from_monitoring=True,
        issuer=host.task.owner,
        ignore_cms=ignore_cms,
    )


# On SWITCH_MISSING, retry the stage indefinitely until location data eventually arrives from racktables via sync.
_register_terminal(StageTerminals.SWITCH_MISSING, retry_stage)
# On DEPLOY_FAILED, retry the whole action (the parent stage) instead of just the failed stage.
_register_terminal(StageTerminals.DEPLOY_FAILED, retry_action)


# TODO: move terminators out of here
# and convert this particular one into a task terminator
# by creating a dedicated FinishTask stage and
# assigning a `complete_task` function as a terminator for SUCCESS terminal.
@_register_terminator(StageTerminals.DELETE_HOST)
def _delete_host_terminator(host, stage, **kwargs):
    """Terminator for DELETE_HOST: remove the host document from the database and finish its task.

    :raises Error - if no host matching the expected deletion state was removed.
    """
    issuer = audit_log.get_task_issuer(host.task)
    log.info(
        "Host %s deletion has been requested by %s. Deleting...",
        host.human_id(),
        issuer if issuer is not None else "[unknown]",
    )

    # remove cms task id from host, to indicate that we shall not keep the task in cms api.
    del host.cms_task_id

    prev_host = host.copy()
    deletion_query = get_host_query(
        issuer,
        ignore_maintenance=True,
        allowed_states=HostState.ALL,
        allowed_statuses=[Operation.DELETE.host_status],
        inv=host.inv,
    )
    removed = Host.objects(**deletion_query).modify(remove=True)
    if not removed:
        raise Error("Unable to delete host: it's state have changed.")

    # When a task finishes normally, the task is removed so that the FSM (which holds this same `host` object) sees
    # the task as finished: host.modify() updates the object in place. Here the host itself is deleted and no updated
    # object comes back, so the attribute is cleared by hand.
    del host.task

    audit_log.complete_task(prev_host.task)
    notifications.on_deleted_host(prev_host, issuer, reason=None)
    on_finished_task(prev_host)


def deactivate(host, reason):
    """
    Deactivate the host: drop its task and set the DEAD status (without downtime) in a single commit.

    For now deactivation only sets the dead status, but in the future it may include other destructive actions
    (for example VLAN switching), so treat it as an action on par with reboot/redeploy.
    """

    log.warning("Mark %s as dead.", host.human_id())

    dead_status_kwargs = Host.set_status_kwargs(
        host.state,
        HostStatus.DEAD,
        authorization.ISSUER_WALLE,
        host.task.audit_log_id,
        downtime=False,
        reason=reason,
    )
    _commit_task_changes(host, unset__task=True, **dead_status_kwargs)


def terminate_current_stage(terminal, host, *args, **kwargs):
    """Terminate the host's current stage through ``terminal``; extra arguments go to the terminator."""
    return _terminate_stage(terminal, host, get_current_stage(host), *args, **kwargs)


def _terminate_stage(terminal, host, stage, *args, **kwargs):
    """Dispatch ``stage`` to the terminator registered for ``terminal`` and return its result.

    A stage may remap terminals via its ``terminators`` mapping; one level of remapping is applied before the
    registry lookup.

    :raises LogicalError - if no terminator is registered for the resulting terminal.
    """
    overrides = stage.terminators
    if overrides is not None and terminal in overrides:
        terminal = overrides[terminal]

    terminator = _get_terminator(terminal)
    if terminator is None:
        raise LogicalError

    return terminator(host, stage, *args, **kwargs)


def _commit_task_changes(host, revision=None, **kwargs):
    """Apply ``kwargs`` as a host.modify() update, conditioned on walle.tasks.host_query(host, revision=revision).

    :raises Error - if modify() reports that nothing was updated.
    """
    guard_query = walle.tasks.host_query(host, revision=revision)
    updated = host.modify(guard_query, **kwargs)
    if not updated:
        raise Error("Unable to commit host state change: it doesn't have an expected state already.")


def _leave_stage(host, stage):
    del stage.status
    del stage.status_time
    del stage.temp_data

    # Just a stub that will be changed a moment later on entering a next stage
    host.task.status = "unknown"


def _get_initial_status(stage):
    """Return the registered initial status of ``stage``; a callable initial status is called with the stage."""
    initial = _get_stage_config(stage.name).initial_status
    return initial(stage) if callable(initial) else initial


def _enter_stage(host, stage, status=None):
    """Make ``stage`` the task's current stage with ``status`` (default: the stage's registered initial status).

    Only the in-memory host is changed; callers commit the changes.
    """
    stage_status = _get_initial_status(stage) if status is None else status
    _set_current_stage_status(host, stage, stage_status)

    task = host.task
    task.stage_uid = stage.uid
    task.stage_name = stage.name


def _set_current_stage_status(host, stage, stage_status=None):
    """Set the status of ``stage`` (None unsets it), stamp its status time and mirror the status onto the task."""
    if stage_status is not None:
        stage.status = stage_status
    else:
        del stage.status

    stage.status_time = timestamp()

    combined_status = Task.make_status(stage.name, stage_status)
    host.task.status = combined_status

    log.info("%s: '%s' task: Transition to '%s' stage.", host.human_id(), host.status, combined_status)


def _change_task_state(
    host, stage_status=None, status_message=None, error=None, check_after=None, check_now=False, check_at=None
):
    """Update the in-memory task for the current stage and bump the task revision; nothing is written to the DB here.

    :param stage_status - new status of the current stage; None keeps the current status.
    :param status_message - task status message; None clears it.
    :param error - task error; when None, the persistent ``stage_error`` of the current stage or of its nearest
           ancestor is used instead.
    :param check_after - schedule the next check this many seconds from now.
    :param check_now - schedule the next check for now.
    :param check_at - schedule the next check at this timestamp.
           At most one of check_after / check_now / check_at may be given, otherwise LogicalError is raised.
    :return: the task revision from before the bump; commit_stage_changes() passes it to _commit_task_changes().
    """
    task = host.task
    stage = get_current_stage(host)

    if stage_status is not None:
        _set_current_stage_status(host, stage, stage_status)

    # Tracks whether this call schedules a next check; used below to decide whether the task error is reset.
    with_next_check = True

    # Counts how many scheduling options were given (booleans add up as 0/1). This runs after the stage status update
    # above, so a LogicalError here leaves that in-memory change in place.
    if (check_after is not None) + check_now + (check_at is not None) > 1:
        raise LogicalError
    elif check_after is not None:
        task.next_check = timestamp() + check_after
    elif check_now:
        task.next_check = timestamp()
    elif check_at is not None:
        task.next_check = check_at
    else:
        with_next_check = False

    if status_message is None:
        del task.status_message
    else:
        task.status_message = str(status_message)

    # Fall back to an error saved by _set_persistent_error() on this stage or one of its ancestors.
    if error is None:
        error = _get_persistent_error(host, stage)

    if error is None:
        # We need to reset task error after each successful iteration, but not when we just commit some changes.
        # Consider `next_check` as a sign of successful completion of the iteration.
        if with_next_check:
            del task.error
    else:
        task.error = str(error)

    cur_revision = task.revision
    task.revision += 1

    return cur_revision


def get_stage_deploy_configuration(deploy_stage):
    """Build the deploy_configuration for ``deploy_stage``.

    A ``config_override`` in the stage data (positional arguments) wins over the ``config`` stage parameter, which
    may be either a dict of keyword arguments (WALLE-2847) or a list of positional arguments.

    :raises LogicalError - if neither source is present or ``config`` has an unsupported type.
    """
    if deploy_stage.has_data("config_override"):
        return deploy_configuration(*deploy_stage.get_data("config_override"))

    if not deploy_stage.has_param("config"):
        # No deploy config at all, it is an error.
        raise LogicalError

    stage_config = deploy_stage.get_param("config")
    if isinstance(stage_config, dict):  # WALLE-2847
        return deploy_configuration(**stage_config)
    if isinstance(stage_config, list):
        return deploy_configuration(*stage_config)

    raise LogicalError()
