import logging
from contextlib import contextmanager

from six import wraps

import walle.fsm_stages
import walle.host_fsm.control
import walle.stages
from sepelib.core import config
from sepelib.core.exceptions import LogicalError
from walle import restrictions, audit_log, constants as walle_constants
from walle.admin_requests.constants import RequestTypes
from walle.clients import deploy
from walle.clients.cms import get_cms_task_type, CmsTaskRejectedError, CmsError, make_maintenance_info
from walle.clients.eine import ProfileMode
from walle.clients.juggler import JugglerDowntimeName, JugglerClient
from walle.clients.utils import strip_api_error
from walle.constants import (
    EINE_PROFILES_WITH_DC_SUPPORT,
    PROVISIONER_LUI,
    PROVISIONER_EINE,
    NetworkTarget,
    EINE_IPXE_FROM_CDROM_PROFILE,
    FLEXY_EINE_PROFILE,
)
from walle.errors import ResourceConflictError, InvalidHostStateError, OperationRestrictedError
from walle.expert.failure_types import FailureType
from walle.expert.rules.utils import repair_hardware_params
from walle.expert.types import CheckGroup
from walle.fsm_stages.constants import EineProfileOperation
from walle.hosts import Task, HostState, get_host_query, TaskType
from walle.locks import HostInterruptableLock
from walle.models import timestamp
from walle.operations_log.constants import Operation
from walle.scenario.cms import get_scenario_info_for_cms
from walle.stages import Stages, StageTerminals, Stage
from walle.statbox.contexts import host_context
from walle.statbox.loggers import fsm_logger
from walle.util import notifications
from walle.util.misc import drop_none, args_as_dict, closing_ctx, dummy_context, parallelize_execution

# Module-level logger shared by all helpers in this file.
log = logging.getLogger(__name__)


# Eine profile tags that get_deploy_stages() always passes to the post-deploy (autohealing) profile stage;
# caller-supplied profile tags are merged on top of this set.
DEFAULT_PROFILE_TAGS = ProfileMode.HIGHLOAD_TEST_ADD_TAGS | ProfileMode.DISK_RW_TEST_ADD_TAGS


class RequestRejectedByCmsError(ResourceConflictError):
    """Raised when the preliminary (dry-run) CMS request for a task has been rejected."""

    def __init__(self, cms_error):
        # The CMS error text is cleaned up with strip_api_error() and handed over as the argument for the
        # "{}" placeholder; presumably the base class performs the actual formatting -- TODO confirm.
        super().__init__("The request has been rejected by CMS: {}", strip_api_error(str(cms_error)))


def get_profile_stages(
    operation,
    profile,
    profile_tags=None,
    profile_modes=None,
    allow_upgrade=False,
    repair_request_severity=None,
    operation_type=None,
    force_update_network_location=False,
    verify_network_location=False,
):
    """Build the stage list that runs an Eine profile on a host.

    The result consists of a ``NETWORK`` composite stage (update network location, optionally verify it,
    then switch VLANs to the service network) followed by a composite stage that optionally powers the host
    off and runs the Eine profile itself.

    :param operation: one of ``EineProfileOperation`` values; selects the name of the composite stage,
        whether the power off is soft and which terminators the Eine profile stage gets.
    :param profile: Eine profile name. Hosts are powered off first unless the profile is listed in
        ``EINE_PROFILES_WITH_DC_SUPPORT``.
    :param profile_tags: tags passed to the Eine profile stage.
    :param profile_modes: when truthy, a ``LOG_COMPLETED_OPERATION`` stage is appended with these modes.
    :param allow_upgrade: for the ``DEPLOY`` operation only, map a failed deploy to ``DISK_RW_AND_REDEPLOY``.
    :param repair_request_severity: passed through to the Eine profile stage.
    :param operation_type: operation type to log instead of the default ``Operation.PROFILE.type``.
    :param force_update_network_location: passed as ``force`` to the ``UPDATE_NETWORK_LOCATION`` stage.
    :param verify_network_location: add a ``VERIFY_NETWORK_LOCATION`` composite stage right after the
        network location update.
    :returns: list of top-level stages.
    :raises LogicalError: if ``operation`` is not a supported ``EineProfileOperation`` value.
    """
    if operation == EineProfileOperation.PREPARE:
        stage_name = Stages.PROFILE
        destructive = True
        terminators = None
    elif operation == EineProfileOperation.RELEASE:
        stage_name = Stages.PROFILE
        destructive = True
        terminators = {StageTerminals.FAIL: StageTerminals.SKIP}
    elif operation == EineProfileOperation.PROFILE:
        stage_name = Stages.PROFILE
        destructive = False
        terminators = None
    elif operation == EineProfileOperation.DEPLOY:
        stage_name = Stages.DEPLOY
        destructive = True
        terminators = None
        if allow_upgrade:
            terminators = {StageTerminals.DEPLOY_FAILED: StageTerminals.DISK_RW_AND_REDEPLOY}
    else:
        raise LogicalError()

    builder = StageBuilder()

    # Eine may not have enough permissions to switch host's port to the deploy VLAN, so help her.
    with builder.nested(Stages.NETWORK) as network_stages:
        # With verification enabled a failed update is skipped so that the verification stage can take over;
        # without it the whole NETWORK stage is completed right away.
        update_fail_terminal = StageTerminals.SKIP if verify_network_location else StageTerminals.COMPLETE_PARENT
        update_network_stage = network_stages.stage(
            Stages.UPDATE_NETWORK_LOCATION,
            terminators={StageTerminals.FAIL: update_fail_terminal},
            force=force_update_network_location,
        )
        if verify_network_location:
            # try to verify that we deal with the implied host
            # NB: the stage is nested into NETWORK (not into the top-level builder), so it is placed right
            # after the update stage it refers to and before VLAN switching.
            with network_stages.nested(Stages.VERIFY_NETWORK_LOCATION) as verify_stages:
                # verify that switch/port is not used with other host
                # if network data was successfully updated on previous stage, go to switch VLAN
                # otherwise power on host and waiting for network updating from racktables

                verify_stages.stage(Stages.VERIFY_SWITCH_PORT, update_network_stage_id=update_network_stage.get_uid())
                verify_stages.stage(Stages.POWER_ON)
                verify_stages.stage(Stages.UPDATE_NETWORK_LOCATION, wait_for_racktables=True)

        # on success vlan switching: proceed with eine profile.
        # on failed vlan switching: proceed with eine profile, ignore failure.
        vlans_stage = network_stages.stage(
            Stages.SWITCH_VLANS,
            network=NetworkTarget.SERVICE,
            terminators={StageTerminals.SWITCH_MISSING: StageTerminals.COMPLETE_PARENT},
        )

    with builder.nested(stage_name) as profile_stages:
        if profile not in EINE_PROFILES_WITH_DC_SUPPORT:
            # do not power on the host, it should not boot from disk.
            profile_stages.stage(Stages.POWER_OFF, soft=not destructive)

        profile_stages.stage(
            Stages.EINE_PROFILE,
            terminators=terminators,
            operation=operation,
            profile=profile,
            profile_tags=profile_tags,
            vlans_stage=vlans_stage.get_uid(),
            repair_request_severity=repair_request_severity,
        )

        if profile_modes:
            profile_stages.stage(
                Stages.LOG_COMPLETED_OPERATION,
                operation=Operation.PROFILE.type if operation_type is None else operation_type,
                params={"modes": profile_modes},
            )

    return builder.get_stages()


def get_deploy_stages(
    deploy_configuration,
    config_forced=False,
    extra_vlans=None,
    with_autohealing=False,
    with_2nd_time_node=False,
    profile_tags=None,
):
    """Returns deploy stages for the specified provisioner.

    For the LUI provisioner a ``DEPLOY`` composite stage is built; for the Eine provisioner the stages come
    from ``get_profile_stages()`` with the ``DEPLOY`` operation. In both cases a ``LOG_COMPLETED_OPERATION``
    stage for the redeploy operation is appended at the end.

    :param deploy_configuration: deploy configuration; the code reads ``provisioner``, ``config``, ``tags``,
        ``ipxe`` and calls ``_asdict()`` on it.
    :param config_forced: stored as a parameter of the LUI ``DEPLOY`` stage.
    :param extra_vlans: passed to the stage that switches VLANs to the deploy network (LUI only).
    :param with_autohealing: for LUI, add the stages that profile the host when ``LUI_INSTALL`` ends with
        ``DEPLOY_FAILED``; for Eine, passed as ``allow_upgrade``.
    :param with_2nd_time_node: for LUI, changes ``LUI_INSTALL`` terminators and the reaction to
        "no error found"; the ``HW_REPAIR`` stage is added only when ``with_autohealing`` is set as well.
    :param profile_tags: extra tags merged with ``DEFAULT_PROFILE_TAGS`` for the autohealing profile stage.
    :raises LogicalError: if the provisioner is neither LUI nor Eine.
    """
    builder = StageBuilder()
    if deploy_configuration.provisioner == PROVISIONER_LUI:
        with builder.nested(
            Stages.DEPLOY, config=deploy_configuration._asdict(), config_forced=config_forced
        ) as deploy_stage:
            deploy_stage.stage(Stages.GENERATE_CUSTOM_DEPLOY_CONFIG)
            deploy_stage.stage(Stages.DROP_EINE_PROFILE)
            deploy_stage.stage(Stages.SWITCH_VLANS, network=NetworkTarget.DEPLOY, extra_vlans=extra_vlans)
            deploy_stage.stage(Stages.ISSUE_CERTIFICATE)
            # run power-off stage for hosts without iPXE too, to detect and report IPMI errors
            deploy_stage.stage(Stages.POWER_OFF)
            deploy_stage.stage(Stages.ASSIGN_LUI_CONFIG)

            if deploy_configuration.ipxe:
                deploy_stage.add_stages(get_power_on_stages(pxe=True))
            else:
                # if everything is ok, this profile runs about 5 min.
                # It only fails if it can't boot so we are safe to wait for this profile to complete.
                # Host may successfully deploy while Wall-E is waiting here. It is not a problem.
                # If host fails to deploy, Wall-E will still find it when he gets there.
                deploy_stage.stage(
                    Stages.EINE_PROFILE, operation=EineProfileOperation.DEPLOY, profile=EINE_IPXE_FROM_CDROM_PROFILE
                )
                deploy_stage.stage(
                    Stages.LOG_COMPLETED_OPERATION,
                    operation=Operation.PROFILE.type,
                    params=dict(operation=EineProfileOperation.DEPLOY),
                )

            # By default LUI_INSTALL uses its own terminators. With healing enabled a successful install
            # completes the whole DEPLOY stage, while a failed one falls through to the stages added below.
            terminators = None
            if with_autohealing or with_2nd_time_node:
                terminators = {
                    StageTerminals.SUCCESS: StageTerminals.COMPLETE_PARENT,
                    StageTerminals.DEPLOY_FAILED: StageTerminals.SKIP,
                }

            deploy_stage.stage(Stages.LUI_INSTALL, terminators=terminators)
            deploy_stage.stage(Stages.LUI_DEACTIVATE)

            if with_autohealing:
                vlans_stage = deploy_stage.stage(
                    Stages.SWITCH_VLANS,
                    network=NetworkTarget.SERVICE,
                    terminators={StageTerminals.SWITCH_MISSING: StageTerminals.SKIP},
                )

                profile_terminators = {
                    StageTerminals.SUCCESS: StageTerminals.RETRY_ACTION,
                    StageTerminals.NO_ERROR_FOUND: (
                        # TODO: retry action on "no err found",
                        # need to check retry count and skip if too many retries
                        StageTerminals.SKIP
                        if with_2nd_time_node
                        else StageTerminals.RETRY_ACTION
                    ),
                }

                # WALLE-4569 Merge host's project tags and profile mode tags.
                # This can be a problem to users who expect full control on tags passing to Eine.
                deploy_stage_profile_tags = DEFAULT_PROFILE_TAGS.copy()
                if profile_tags:
                    deploy_stage_profile_tags |= set(profile_tags)

                deploy_stage.stage(
                    Stages.EINE_PROFILE,
                    operation=EineProfileOperation.PROFILE,
                    profile=FLEXY_EINE_PROFILE,
                    profile_tags=sorted(deploy_stage_profile_tags),
                    vlans_stage=vlans_stage.get_uid(),
                    terminators=profile_terminators,
                )
                deploy_stage.stage(
                    Stages.LOG_COMPLETED_OPERATION,
                    operation=Operation.PROFILE.type,
                    params=dict(operation=EineProfileOperation.PROFILE),
                )

                # NOTE(review): this branch is reachable only together with with_autohealing, so
                # with_2nd_time_node alone never adds the HW_REPAIR stage -- confirm this is intended.
                if with_2nd_time_node:
                    decision_params = repair_hardware_params(
                        failure_type=FailureType.SECOND_TIME_NODE,
                        operation_type=Operation.REPORT_SECOND_TIME_NODE.type,
                        request_type=RequestTypes.SECOND_TIME_NODE.type,
                        redeploy=False,
                    )
                    deploy_stage.stage(
                        Stages.HW_REPAIR,
                        decision_params=decision_params,
                        decision_reason="Host failed to boot from PXE.",
                        terminators={StageTerminals.SUCCESS: StageTerminals.RETRY_ACTION},
                    )

    elif deploy_configuration.provisioner == PROVISIONER_EINE:
        builder.add_stages(
            get_profile_stages(
                EineProfileOperation.DEPLOY,
                profile=deploy_configuration.config,
                profile_tags=deploy_configuration.tags,
                allow_upgrade=with_autohealing,
            )
        )
    else:
        raise LogicalError()

    builder.stage(Stages.LOG_COMPLETED_OPERATION, operation=Operation.REDEPLOY.type)

    return builder.get_stages()


def get_power_on_stages(pxe=None, check_post_code=False):
    """Return the stages that power a host on.

    Without ``check_post_code`` this is a single ``POWER_ON`` stage (with the ``pxe`` parameter when it is
    not None). With ``check_post_code`` this is a ``POWER_ON_COMPOSITE`` stage: power on with POST code
    check, power off, then power on again with ``upgrade_to_profile`` enabled.
    """
    if not check_post_code:
        power_on_params = drop_none({"pxe": pxe})
        return [Stage(name=Stages.POWER_ON, params=power_on_params or None)]

    builder = StageBuilder()
    with builder.nested(Stages.POWER_ON_COMPOSITE) as composite:
        composite.stage(name=Stages.POWER_ON, check_post_code=True, pxe=pxe)
        composite.stage(name=Stages.POWER_OFF)
        composite.stage(name=Stages.POWER_ON, check_post_code=True, upgrade_to_profile=True, pxe=pxe)

    return builder.get_stages()


def prepare_host_adding_stages(task, host, reason, deploy_configuration, dns, check):
    """Fill the task helper ``task`` with the stages used when a host is being added.

    Always adds probation, power on, network location and project VLAN stages. DNS setup is added when
    ``dns`` is truthy, LUI setup when the provisioner is LUI, and network availability monitoring is
    requested when ``check`` is None or truthy.
    """
    task.set_probation(reason)
    task.power_on()
    task.get_network_location(host, full=False)
    task.switch_vlans(NetworkTarget.PROJECT)

    if dns:
        task.setup_dns()

    uses_lui = deploy_configuration.provisioner == walle_constants.PROVISIONER_LUI
    if uses_lui:
        task.lui_setup(deploy_configuration.config)

    # An unset ``check`` (None) means the same as an explicit truthy value: monitor the host.
    monitoring_requested = True if check is None else bool(check)
    if monitoring_requested:
        task.monitor(CheckGroup.NETWORK_AVAILABILITY)


def check_post_code_allowed(host, task_type, with_auto_healing):
    """Decide whether the POST code check may be used for a task on the host.

    The check is allowed when automated healing is in effect (requested explicitly, or the task itself is
    an automated healing task) and the host has neither ``PROFILE`` nor ``AUTOMATED_PROFILE`` restrictions.

    :returns: tuple ``(decision, reasons)`` where ``decision`` is truthy when the check is allowed and
        ``reasons`` is a "; "-joined string of reasons against it, or None when there are none.
    """
    # override any current auto_healing if task was created by automation
    with_auto_healing = with_auto_healing or task_type == TaskType.AUTOMATED_HEALING

    reasons = []
    if not with_auto_healing:
        reasons.append("Automated healing disabled")

    try:
        restrictions.check_restrictions(host, (restrictions.PROFILE, restrictions.AUTOMATED_PROFILE))
    except OperationRestrictedError as error:
        restricted = True
        reasons.append(str(error))
    else:
        restricted = False

    decision = not restricted and with_auto_healing

    # Lazy %-style arguments: the message is only rendered when debug logging is actually enabled.
    log.debug(
        "Host %s: check_post_code is allowed: %s, task_type is: %s, with_auto_healing is: %s, restricted: %s",
        host.name,
        decision,
        task_type,
        with_auto_healing,
        restricted,
    )

    return decision, "; ".join(reasons) or None


def new_task(
    issuer,
    task_type,
    audit_entry,
    stages,
    host,
    next_check=None,
    target_status=None,
    ignore_cms=False,
    keep_downtime=False,
    keep_cms_task=False,
    disable_admin_requests=False,
    monitor_on_completion=False,
    with_auto_healing=None,
    health_status_accuracy=None,
    checks_to_monitor=None,
    monitoring_timeout=None,
    checks_for_use=None,
    use_cloud_post_processor=False,
    profile_after_task=False,
    redeploy_after_task=False,
    cms_task_id=None,
):
    """Create a pending ``Task`` out of the given stages plus the common finalization stages.

    The finalization stages are appended to the passed ``stages`` list itself: health status reset, admin
    requests cancellation (unless disabled), monitoring (when ``monitor_on_completion`` is set) and the
    cloud post processor (when ``use_cloud_post_processor`` is set).

    :raises LogicalError: if monitoring is requested but ``checks_to_monitor`` is None.
    :returns: the new ``Task`` object.
    """
    reset_params = drop_none({"health_status_accuracy": health_status_accuracy})
    stages.append(Stage(name=Stages.RESET_HEALTH_STATUS, params=reset_params or None))

    if not disable_admin_requests:
        stages.append(Stage(name=Stages.CANCEL_ADMIN_REQUESTS))

    if monitor_on_completion:
        if checks_to_monitor is None:
            raise LogicalError()

        # monitoring timeout is not a final time limit, it is an extra time added to the default value.
        monitor_params = drop_none(
            {
                "checks": sorted(set(checks_to_monitor)),
                "monitoring_timeout": monitoring_timeout,
                "checks_for_use": checks_for_use,
            }
        )
        stages.append(Stage(name=Stages.MONITOR, params=monitor_params))

    if use_cloud_post_processor:
        post_processor_params = drop_none(
            {"profile_after_task": profile_after_task, "redeploy_after_task": redeploy_after_task}
        )
        stages.append(Stage(name=Stages.CLOUD_POST_PROCESSOR, params=post_processor_params))

    stages = walle.stages.set_uids(stages)

    if next_check is None:
        next_check = _get_new_task_next_check(task_type)

    # Keep a copy of the decision the host had at the moment the task was created, if there is any health data.
    origin_decision = host.health.clone_decision() if host.health else None

    return Task(
        task_id=Task.next_task_id(),
        type=task_type,
        owner=issuer,
        audit_log_id=audit_entry.id,
        status="pending",
        stages=stages,
        target_status=target_status,
        ignore_cms=ignore_cms,
        keep_downtime=keep_downtime,
        keep_cms_task=keep_cms_task,
        disable_admin_requests=disable_admin_requests,
        enable_auto_healing=with_auto_healing,
        next_check=next_check,
        revision=0,
        cms_task_id=cms_task_id,
        decision=origin_decision,
    )


def _get_new_task_next_check(task_type):
    """Return the time of the first check for a new task.

    In debug mode (``run.debug`` config value) this is the current timestamp. Otherwise the check is
    scheduled 5 seconds into the future, plus 5 more seconds for automated healing tasks.
    """
    scheduled_at = timestamp()
    if config.get_value("run.debug"):
        return scheduled_at

    # TODO: very bad and unreliable temporary solution
    # When we limit the number of concurrently processing tasks we rely on task ID monotonically increasing numbers
    # when selecting N most early enqueued tasks. But there is a chance that when a lot of hosts are enqueued there will
    # be a race condition between task ID generation and committing it into the database. To overcome this race
    # condition we enqueue all tasks with next check scheduled at a few seconds to the future.
    #
    # MongoDB 2.6 has $currentDate operator which can fully solve the problem
    # Is it really a problem?
    scheduled_at = scheduled_at + 5

    # TODO: do something with this hackaround
    # Increase delay for processing of automated tasks. See walle.host_fsm.fsm._cancel_task_if_needed() for details
    # Probably adding a cool-off timeout to `acquire-permission` stage would be appropriate.
    is_automated = task_type == TaskType.AUTOMATED_HEALING
    if is_automated:
        scheduled_at = scheduled_at + 5

    return scheduled_at


def check_state_and_get_query(host, issuer, task_type, ignore_maintenance, allowed_states, allowed_statuses, **kwargs):
    """Validate the host's state and status and build the query used to modify the host.

    :raises InvalidHostStateError: if the host's state or status is not among the allowed ones.
    :returns: the result of ``get_host_query()``; for automated healing tasks the query is additionally
        bound to the host's current ``decision_status``.
    """
    state_is_allowed = host.state in allowed_states
    if not (state_is_allowed and host.status in allowed_statuses):
        raise InvalidHostStateError(host, allowed_states=allowed_states, allowed_statuses=allowed_statuses)

    query_kwargs = dict(kwargs)
    if task_type == TaskType.AUTOMATED_HEALING:
        # protect from race between triage and screening:
        # only process decision if it did not change
        query_kwargs["decision_status"] = host.decision_status

    return get_host_query(
        issuer, ignore_maintenance, allowed_states, allowed_statuses=allowed_statuses, **query_kwargs
    )


def reject_request_if_needed(issuer, task_type, host, cms_action, ignore_cms):
    """Ask CMS in dry-run mode whether it would accept the task and reject the request if it would not.

    Only manual tasks that do not ignore CMS are probed. Hosts without a CMS client are skipped.
    Any CMS error other than an explicit rejection is logged and ignored, so the task is scheduled anyway.

    :raises RequestRejectedByCmsError: if CMS rejects the dry-run request.
    """
    if task_type != TaskType.MANUAL or ignore_cms:
        return

    cms_api_client = host.get_cms()
    if cms_api_client is None:  # YP project
        return

    # Use fake task to not steal time from _TASK_ENQUEUE_TIMEOUT
    fake_task = Task(task_id=Task.next_task_id())

    cms_task_id = fake_task.get_cms_task_id()
    maintenance_info = make_maintenance_info(cms_task_id, cms_action, [host.name], issuer=issuer)
    scenario_info = get_scenario_info_for_cms(host)

    try:
        cms_api_client.add_task(
            id=cms_task_id,
            type=get_cms_task_type(task_type),
            issuer=issuer,
            action=cms_action,
            hosts=[host.name],
            maintenance_info=maintenance_info,
            scenario_info=scenario_info,
            dry_run=True,
        )
    except CmsTaskRejectedError as e:
        # Chain explicitly so the traceback shows the CMS rejection as the direct cause.
        raise RequestRejectedByCmsError(e) from e
    except CmsError:
        log.exception("%s: Preliminary CMS request failed, scheduling the task.", host.human_name())


def _audited(only=None, exclude=None, **static):
    """Decorator factory: record the decorated method's call arguments as audit data on the instance.

    The captured arguments (without ``self``) are filtered by ``exclude`` and, when ``only`` is not None,
    restricted to the names listed in ``only``. ``static`` values are added on top. The result is passed
    to ``self.add_audit()`` before the method itself is called.
    """

    def decorator(method):
        @wraps(method)
        def wrapped_method(self, *args, **kwargs):
            call_args = args_as_dict(method, self, *args, **kwargs)

            dropped = set(exclude or [])
            dropped.add("self")
            allowed = None if only is None else set(only)

            audit_params = {
                name: value
                for name, value in call_args.items()
                if name not in dropped and (allowed is None or name in allowed)
            }
            # static values win over captured call arguments
            audit_params.update(static)

            self.add_audit(**audit_params)
            return method(self, *args, **kwargs)

        return wrapped_method

    return decorator


class StageBuilder:
    """Accumulates an ordered list of stages, including composite stages with child stages."""

    def __init__(self):
        self._stages = []

    def stage(self, name, stages=None, terminators=None, **params):
        """Create a stage, append it to the builder and return it.

        Keyword arguments become stage parameters after None values are dropped; when nothing is left
        the stage gets ``params=None``.
        """
        stage_params = drop_none(params)
        new_stage = Stage(name=name, stages=stages, params=stage_params or None, terminators=terminators)
        self._stages.append(new_stage)

        return new_stage

    def add_stages(self, stages):
        """Append already built stages to the builder."""
        self._stages.extend(stages)

    def get_stages(self):
        """Return a shallow copy of the accumulated stage list."""
        return list(self._stages)

    @contextmanager
    def nested(self, name, **params):
        """Context manager yielding a child builder.

        When the context exits, whatever the child builder has accumulated is added to this builder as
        the child stages of a new stage called ``name``.
        """
        child_builder = StageBuilder()
        try:
            yield child_builder
        finally:
            self.stage(name, stages=child_builder.get_stages(), **params)


class TaskHelper:
    """Encapsulates procedures of task parameter accumulation and task creation.
    Task parameters may be used as a payload for an audit log.
    """

    def __init__(self, issuer, task_type, host, reason=None, monitor_on_completion=True):
        # required parameters
        self._issuer = issuer
        self._task_type = task_type
        self._host = host

        # optional parameter instantiated on call
        self._reason = reason
        self._ignore_cms = None
        self._cms_action = None
        self._keep_downtime = False

        self._checks_to_monitor = None
        self._health_status_accuracy = None
        self._monitoring_timeout = None
        self._monitor_on_completion = monitor_on_completion

        self._use_cloud_post_processor = False
        self._profile_after_task = False
        self._redeploy_after_task = False

        self._target_status = None

        # internal
        self._builder = StageBuilder()
        self._audit_data = {}
        self._audit_entry = None

    def acquire_permission(
        self,
        action=None,
        task_group=None,
        comment=None,
        ignore_cms=False,
        force_new_cms_task=None,
        extra=None,
        decision=None,
    ):
        """Add an ACQUIRE_PERMISSION stage.

        ``action`` and ``ignore_cms`` are also remembered on the helper: they are used by ``probe_cms()``
        and ``ignore_cms`` is passed to the created task. When ``decision`` is given, its first failure,
        failure type and checks are stored as stage parameters.
        """
        self._ignore_cms = ignore_cms
        self._cms_action = action
        self._builder.stage(
            Stages.ACQUIRE_PERMISSION,
            action=action,
            task_group=task_group,
            comment=comment,
            force_new_cms_task=force_new_cms_task,
            extra=extra,
            failure=decision.failures[0] if decision and decision.failures else None,
            failure_type=decision.failure_type if decision else None,
            check_names=decision.checks if decision and decision.checks else None,
        )

    def set_downtime(self, juggler_downtime_name=JugglerDowntimeName.DEFAULT):
        """Add a SET_DOWNTIME stage.

        :param juggler_downtime_name: stored as the ``juggler_downtime_name`` parameter of the stage.
        """
        self._builder.stage(Stages.SET_DOWNTIME, juggler_downtime_name=juggler_downtime_name)

    def keep_downtime(self, value=True):
        """Set the ``keep_downtime`` flag that is passed to the created task."""
        self._keep_downtime = value

    def set_maintenance(self, ticket_key, timeout_time, timeout_status, operation_state, reason=None):
        """Add a SET_MAINTENANCE stage and turn on the ``keep_downtime`` flag of the task."""
        self.keep_downtime()

        self._builder.stage(
            Stages.SET_MAINTENANCE,
            ticket_key=ticket_key,
            timeout_time=timeout_time,
            timeout_status=timeout_status,
            operation_state=operation_state,
            reason=reason,
        )

    # only=() drops every call argument from the audit data; only the static ``state`` value is recorded.
    @_audited(only=(), state=HostState.ASSIGNED)
    def set_assigned(self, reason=None):
        """Add a SET_ASSIGNED stage."""
        self._builder.stage(Stages.SET_ASSIGNED, reason=reason)

    @_audited()
    def assign_bot_project(self, bot_project_id):
        """Add an ASSIGN_BOT_PROJECT stage; nothing is added when ``bot_project_id`` is falsy."""
        if bot_project_id:
            self._builder.stage(Stages.ASSIGN_BOT_PROJECT, bot_project_id=bot_project_id)

    def get_network_location(self, host, full=False):
        """
        :param host: Host to create task for. Need to deduce profile configuration (optional, for full mode)
        :param full: Enable "full mode": run through eine profile and wait for netmap if info is not readily available
        """
        network_update_configuration = host.deduce_profile_configuration(profile_mode=ProfileMode.SWP_UP)
        # NOTE(review): `walle._tasks` is not imported by this module explicitly; this attribute access
        # relies on the submodule having been imported elsewhere -- confirm.
        self._builder.add_stages(walle._tasks.stages.get_network_update_stages(network_update_configuration, full=full))

    def switch_vlans(self, network, extra_vlans=None):
        """Add a SWITCH_VLANS stage; for the parking network a missing switch is skipped."""
        terminators = None
        if network == NetworkTarget.PARKING:
            terminators = {StageTerminals.SWITCH_MISSING: StageTerminals.SKIP}

        self._builder.stage(Stages.SWITCH_VLANS, network=network, terminators=terminators, extra_vlans=extra_vlans)

    def setup_dns(self, clear=None, create=None):
        """
        :param clear: just remove old DNS records without adding new.
        :param create: operation is expected to only create new records and not to delete or update existing.
        :type clear: bool
        """
        self._builder.stage(Stages.SETUP_DNS, clear=clear, create=create)

    def power_off(self, soft=True, skip_errors=False):
        """Add a POWER_OFF stage; with ``skip_errors`` a failure of the stage is skipped."""
        kwargs = {"soft": soft}
        if skip_errors:
            kwargs["terminators"] = {StageTerminals.FAIL: StageTerminals.SKIP}
        self._builder.stage(Stages.POWER_OFF, **kwargs)

    @_audited()
    def optional_power_off(self, reboot=False):
        """Return a context manager for stages that may have to run on a powered-off host.

        With ``reboot`` a power off stage is added right away and ``closing_ctx(self.power_on)`` is
        returned (presumably it calls ``power_on`` when the context exits -- TODO confirm); otherwise
        a dummy context is returned and no stages are added.
        """
        if reboot:
            self.power_off()
            return closing_ctx(self.power_on)
        else:
            return dummy_context()

    def power_on(self, check_post_code=False):
        """Add the stages returned by ``get_power_on_stages()``."""
        self._builder.add_stages(get_power_on_stages(check_post_code=check_post_code))

    def reboot(self, check_post_code=False):
        """Add a REBOOT composite stage: soft power off followed by the power on stages."""
        # NB: this is not a full-blown reboot task, this is just a bare-bones reboot stage
        with self._builder.nested(Stages.REBOOT) as reboot_builder:
            reboot_builder.stage(Stages.POWER_OFF, soft=True)
            reboot_builder.add_stages(get_power_on_stages(check_post_code=check_post_code))

    def profile(self, profile_configuration):
        """Add profile stages; ``profile_configuration`` is unpacked into positional arguments of
        ``get_profile_stages()`` following the operation.
        """
        stages = get_profile_stages(EineProfileOperation.PROFILE, *profile_configuration)
        self._builder.add_stages(stages)

    def redeploy(self, deploy_configuration):
        """Add deploy stages (without autohealing) for the given deploy configuration."""
        stages = get_deploy_stages(deploy_configuration, with_autohealing=False)

        # For LUI the health status accuracy passed to the task is set to deploy.COMPLETED_STATUS_TIME_ACCURACY.
        if deploy_configuration.provisioner == PROVISIONER_LUI:
            self._health_status_accuracy = deploy.COMPLETED_STATUS_TIME_ACCURACY

        self._builder.add_stages(stages)

    def lui_setup(self, deploy_config):
        """Add a LUI_SETUP stage with the given config."""
        self._builder.stage(Stages.LUI_SETUP, config=deploy_config)

    def reset_bmc(self):
        """Add a RESET_BMC stage."""
        self._builder.stage(Stages.RESET_BMC)

    @_audited(exclude=["decision"])
    def repair_hardware(self, decision, decision_params, reason):
        """Add a HW_REPAIR stage; the original decision is stored in the stage data as a dict."""
        self._builder.stage(
            Stages.HW_REPAIR,
            decision_params=decision_params,
            decision_reason=reason,
            data={"orig_decision": decision.to_dict()},
        )

    def provide_diagnostic_host_access(self):
        """Add a PROVIDE_DIAGNOSTIC_HOST_ACCESS stage."""
        self._builder.stage(Stages.PROVIDE_DIAGNOSTIC_HOST_ACCESS)

    def monitor(self, checks, timeout=None):
        """Remember the checks and the timeout that are passed to the created task for monitoring."""
        self._checks_to_monitor = checks
        self._monitoring_timeout = timeout

    def cloud_post_processor(self, need_profile, need_redeploy):
        """Enable the cloud post processor for the task and remember its profile/redeploy flags."""
        self._use_cloud_post_processor = True
        self._profile_after_task = need_profile
        self._redeploy_after_task = need_redeploy

    def log_completed_operation(self, operation, **params):
        """Add a LOG_COMPLETED_OPERATION stage for ``operation.type`` with optional params."""
        self._builder.stage(Stages.LOG_COMPLETED_OPERATION, operation=operation.type, params=params or None)

    def target_status(self, status):
        """Set the target status that is passed to the created task."""
        self._target_status = status

    def add_audit(self, **data):
        """Merge ``data`` into the accumulated audit payload."""
        self._audit_data.update(data)

    def _get_audit(self, reason=None):
        """Return the audit payload: accumulated audit data plus issuer, host identity, reason and
        scenario id, with None values dropped.
        """
        task_args = {
            "issuer": self._issuer,
            "project_id": self._host.project,
            "inv": self._host.inv,
            "name": self._host.name,
            "host_uuid": self._host.uuid,
        }
        return drop_none(dict(self._audit_data, reason=reason, scenario_id=self._host.scenario_id, **task_args))

    @contextmanager
    def audit(self, on_action, reason=None):
        """Take an audit_log action, like `on_delete_host` and fill it with required data.
        Yield an audit_log context manager instance so that outer scope could use it too.
        """
        self._reason = reason or self._reason
        with on_action(**self._get_audit(self._reason)) as audit_entry:
            self._audit_entry = audit_entry
            yield audit_entry

    def probe_cms(self):
        """Run the preliminary CMS check for the task.

        :raises LogicalError: if no CMS action has been set via ``acquire_permission()``.
        """
        if self._cms_action is None:
            # the task have no CMS stage, no need to ask CMS for permission.
            raise LogicalError()
        reject_request_if_needed(self._issuer, self._task_type, self._host, self._cms_action, self._ignore_cms)

    def _get_task_params(self):
        """Return keyword arguments for ``new_task()`` built from the accumulated state."""
        return {
            "issuer": self._issuer,
            "task_type": self._task_type,
            "audit_entry": self._audit_entry,
            "stages": self._builder.get_stages(),
            "host": self._host,
            "target_status": self._target_status,
            "ignore_cms": self._ignore_cms,
            "keep_downtime": self._keep_downtime,
            "monitor_on_completion": self._monitor_on_completion,
            "checks_to_monitor": self._checks_to_monitor,
            "monitoring_timeout": self._monitoring_timeout,
            "health_status_accuracy": self._health_status_accuracy,
            "use_cloud_post_processor": self._use_cloud_post_processor,
            "profile_after_task": self._profile_after_task,
            "redeploy_after_task": self._redeploy_after_task,
        }

    def task(self, audit_entry=None):
        """Create the task object from the accumulated parameters.

        :raises LogicalError: if no audit entry is available (neither passed here nor set by ``audit()``).
        """
        self._audit_entry = audit_entry or self._audit_entry
        if self._audit_entry is None:
            # need to use with tb.audit() or pass audit_entry directly.
            raise LogicalError()

        return new_task(**self._get_task_params())

    def enqueue(self, host_query, status_kwargs):
        """Create the task and set it on the host using ``host_query`` as the modification condition.

        :raises InvalidHostStateError: if the host modification returns a falsy result.
        """
        if not self._host.modify(host_query, set__task=self.task(), **status_kwargs):
            raise InvalidHostStateError(self._host)

        notifications.on_task_enqueued(self._issuer, self._host, self._reason)

    def swap_current_task(self, host_query, status_kwargs):
        """Replace the host's current task with a new one and run cancellation actions for the old one.

        The new task is enqueued only if the host still has the same task id; everything happens under
        the host's interruptable lock. Errors of the cancellation actions are logged and suppressed.
        """
        # copy of the host taken before enqueueing; the cancellation actions below operate on it
        prev_host = self._host.copy()

        with HostInterruptableLock(self._host.uuid, self._host.tier):
            self.enqueue(dict(host_query, task__task_id=self._host.task.task_id), status_kwargs)

            try:
                on_task_cancelled(self._issuer, prev_host, self._reason)
            except Exception:
                log.exception("Failed to execute task cancellation actions")
                # that's sad, but let gc deal with it, we can not do anything here,
                # and neither does the user


class UserTaskHelper(TaskHelper):
    """Task helper that additionally tracks maintenance, admin requests and auto healing options.

    Monitoring on completion is off by default and is switched on by calling ``monitor()``.
    """

    def __init__(
        self,
        issuer,
        task_type,
        host,
        ignore_maintenance=False,
        disable_admin_requests=False,
        with_auto_healing=None,
        reason=None,
    ):
        # unlike the base helper, monitoring is disabled until monitor() is called
        super().__init__(issuer, task_type, host, reason, monitor_on_completion=False)

        # optional parameters
        self._ignore_maintenance = ignore_maintenance
        self._disable_admin_requests = disable_admin_requests
        self._with_auto_healing = with_auto_healing

    def monitor(self, checks, timeout=None):
        """Enable monitoring on completion and remember the checks and the timeout."""
        self._monitor_on_completion = True
        super().monitor(checks, timeout)

    def _get_audit(self, reason=None):
        """Return the base audit payload extended with the user options that are not None."""
        payload = super()._get_audit(reason)

        user_options = {
            "ignore_maintenance": self._ignore_maintenance,
            "disable_admin_requests": self._disable_admin_requests,
            "check": self._monitor_on_completion,
            "with_auto_healing": self._with_auto_healing,
            "ignore_cms": self._ignore_cms,
        }
        payload.update(drop_none(user_options))

        return payload

    def _get_task_params(self):
        """Return the base ``new_task()`` arguments with the user options applied."""
        task_params = super()._get_task_params()
        task_params["disable_admin_requests"] = self._disable_admin_requests
        task_params["monitor_on_completion"] = self._monitor_on_completion
        task_params["with_auto_healing"] = self._with_auto_healing

        return task_params

    def set_probation(self, reason):
        """Add a SET_PROBATION stage."""
        self._builder.stage(Stages.SET_PROBATION, reason=reason)


class CheckDnsTaskHelper(UserTaskHelper):
    """User task helper whose audit payload never contains the ``ignore_cms`` key."""

    def _get_audit(self, reason=None):
        payload = super()._get_audit(reason)
        if "ignore_cms" in payload:
            del payload["ignore_cms"]
        return payload


class ReportTaskHelper(TaskHelper):
    """Task helper with the stages that report failures."""

    @_audited()
    def report_failure(self, checks, reason):
        """Add a REPORT stage; both ``checks`` and ``reason`` go to the audit data."""
        stage_params = {"checks": checks, "reason": reason}
        self._builder.stage(Stages.REPORT, **stage_params)

    @_audited(exclude=["checks"])
    def repair_rack_failure(self, checks, reason):
        """Add a REPORT_RACK stage; ``checks`` are left out of the audit data."""
        stage_params = {"checks": checks, "reason": reason}
        self._builder.stage(Stages.REPORT_RACK, **stage_params)

    @_audited(exclude=["checks"])
    def repair_rack_overheat(self, checks, reason):
        """Add a REPORT_RACK_OVERHEAT stage; ``checks`` are left out of the audit data."""
        stage_params = {"checks": checks, "reason": reason}
        self._builder.stage(Stages.REPORT_RACK_OVERHEAT, **stage_params)


class DeleteHostTaskHelper(UserTaskHelper):
    """User task helper whose audit payload never contains the ``check`` key."""

    def _get_audit(self, reason=None):
        payload = super()._get_audit(reason)
        if "check" in payload:
            del payload["check"]
        return payload


def _network_target_project(host):
    """Return ``NetworkTarget.PROJECT`` if the host's project has a VLAN scheme configured.

    :raises ResourceConflictError: if the project's ``vlan_scheme`` is None.
    """
    vlan_scheme = host.get_project(fields=("vlan_scheme",)).vlan_scheme
    if vlan_scheme is not None:
        return NetworkTarget.PROJECT

    raise ResourceConflictError("Project is not configured for automatic VLAN detecting.")


def on_task_cancelled(issuer, prev_host, reason=None):
    """Called on task cancellation.

    Implies an active lock on the host.

    Logs the cancellation, cancels the host's stages, writes the audit log, sends the notification and
    finally runs the common task finalization. Downtime is kept when the host's state is one of
    ``HostState.ALL_DOWNTIME``.
    """

    cancel_record = fsm_logger(walle_action="fsm_cancel_task", fsm_result="cancel", **host_context(prev_host))
    cancel_record.log()

    walle.fsm_stages.common.cancel_host_stages(prev_host)
    audit_log.cancel_task(prev_host.task, reason)
    notifications.on_task_cancelled(issuer, prev_host, reason)

    on_finished_task(prev_host, keep_downtime=prev_host.state in HostState.ALL_DOWNTIME)


def on_finished_task(prev_host, keep_downtime=False):
    """Called on task completion (successful or unsuccessful).

    Does all required finalization actions: removes the downtime (unless it must be kept or the host is in
    maintenance) and removes the task from CMS (when the host has no ``cms_task_id`` and the task does not
    ignore CMS). The collected actions are handed over to ``parallelize_execution()``.
    """

    fsm_logger(walle_action="fsm_on_finished_task", **host_context(prev_host)).log()

    downtime_must_stay = keep_downtime or prev_host.state == HostState.MAINTENANCE
    finalizers = []

    if not downtime_must_stay and prev_host.on_downtime and prev_host.name is not None:
        finalizers.append(lambda: _clear_downtime_from_host(prev_host))

    if prev_host.cms_task_id is None and not prev_host.task.ignore_cms:
        finalizers.append(lambda: _delete_task_from_cms(prev_host))

    parallelize_execution(*finalizers)


def _clear_downtime_from_host(host):
    """Remove the host's downtimes in Juggler; failures are logged and not propagated."""
    host_id = host.human_id()
    log.warning("%s: Remove downtime in Juggler.", host_id)

    # the client is created outside of the guarded section, so its construction errors do propagate
    juggler_client = JugglerClient()
    try:
        juggler_client.clear_downtimes(host.name)
    except Exception as error:
        log.error("%s: Failed to remove downtime in Juggler: %s", host.human_id(), error)


def _delete_task_from_cms(host):
    """Remove the host's current task from every CMS client of the host, on a best-effort basis.

    Errors are logged and never propagated. Each client is handled independently, so a failure of one CMS
    does not prevent the task from being removed from the remaining ones.
    """
    cms_task_id = host.task.get_cms_task_id()

    try:
        cms_api_clients = list(host.get_cms_clients())
    except Exception as e:
        log.error("%s: Failed to remove %s task from CMS: %s", host.human_id(), cms_task_id, e)
        return

    for cms_client in cms_api_clients:
        if cms_client is None:
            continue

        try:
            cms_client.delete_task(cms_task_id)
        except Exception as e:
            log.error("%s: Failed to remove %s task from CMS: %s", host.human_id(), cms_task_id, e)
