"""Host task management."""

import logging

from walle import restrictions, audit_log
from walle._tasks.task_creator import create_new_task
from walle.clients.cms import CmsTaskAction
from walle.errors import RequestValidationError, InvalidHostStateError
from walle.hosts import Host
from walle.locks import HostInterruptableLock
from walle.util import notifications
from walle.util.tasks import check_state_and_get_query, reject_request_if_needed, on_task_cancelled

log = logging.getLogger(__name__)


def schedule_task(host, task_args, get_stages_func):
    """Schedule a new task on ``host`` described by ``task_args``.

    The conditional-update query is taken from the host's current task when
    ``task_args.from_current_task`` is set; otherwise it comes from
    ``check_state_and_get_query``. The task is created inside an audit log
    entry. It then either replaces the current task
    (``task_args.force_new_task``) or is set on the host and enqueued.

    :param host: host document the task is scheduled on.
    :param task_args: task arguments. As a side effect, ``task_args.cms_action``
        may be overwritten with ``CmsTaskAction.REDEPLOY``.
    :param get_stages_func: callable receiving ``task_args`` and returning the
        stage builder handed to ``create_new_task``.
    """
    provider = TaskProvider(host, task_args)
    audit_params = task_args.get_task_params()

    if not task_args.from_current_task:
        query = check_state_and_get_query(
            host,
            task_args.issuer,
            task_args.task_type,
            task_args.ignore_maintenance,
            allowed_states=task_args.allowed_states,
            allowed_statuses=task_args.allowed_statuses,
        )
    else:
        query = provider.host_query()

    provider.validate_args()
    with audit_log.create(**audit_params) as audit_entry:
        provider.log()

        # A cloud post-processor that redeploys/profiles afterwards needs a
        # REDEPLOY CMS action.
        if task_args.use_cloud_post_processor:
            if task_args.redeploy_after_task or task_args.profile_after_task:
                task_args.cms_action = CmsTaskAction.REDEPLOY

        # CMS is not checked for prepare-host tasks: such hosts may have an
        # empty hostname.
        if task_args.type not in {audit_log.TYPE_PREPARE_HOST}:
            provider.probe_cms()

        stage_builder = get_stages_func(task_args)
        reused_task_id = host.get_task_id() if task_args.keep_task_id else None
        new_task = create_new_task(
            host,
            task_args,
            audit_entry,
            stage_builder,
            task_id=reused_task_id,
        )

        provider.check_restrictions(query)

        if not task_args.force_new_task:
            provider.set_task(new_task, audit_entry, query)
            provider.enqueue()
        else:
            provider.swap_current_task(new_task, audit_entry, query, host)


def schedule_task_from_scenario(host, task_args, get_stages_func):
    """Schedule a task on ``host`` on behalf of a scenario.

    Unlike the other scheduling entry points in this module, this one neither
    adjusts ``task_args.cms_action`` nor probes CMS before creating the task.

    :param host: host document the task is scheduled on.
    :param task_args: task arguments describing the operation.
    :param get_stages_func: callable receiving ``task_args`` and returning the
        stage builder handed to ``create_new_task``.
    """
    provider = TaskProvider(host, task_args)
    audit_params = task_args.get_task_params()

    query = check_state_and_get_query(
        host,
        task_args.issuer,
        task_args.task_type,
        task_args.ignore_maintenance,
        allowed_states=task_args.allowed_states,
        allowed_statuses=task_args.allowed_statuses,
    )

    provider.validate_args()
    with audit_log.create(**audit_params) as audit_entry:
        provider.log()
        new_task = create_new_task(host, task_args, audit_entry, get_stages_func(task_args))

        provider.check_restrictions(query)

        if not task_args.force_new_task:
            provider.set_task(new_task, audit_entry, query)
            provider.enqueue()
        else:
            provider.swap_current_task(new_task, audit_entry, query, host)


def schedule_task_from_api(host, task_args, get_stages_func):
    """Schedule a task on ``host`` in response to an API request.

    The host state/status are validated via ``check_state_and_get_query``.
    CMS is always probed, after the cloud-post-processor REDEPLOY adjustment.

    :param host: host document the task is scheduled on.
    :param task_args: task arguments. As a side effect, ``task_args.cms_action``
        may be overwritten with ``CmsTaskAction.REDEPLOY``.
    :param get_stages_func: callable receiving ``task_args`` and returning the
        stage builder handed to ``create_new_task``.
    """
    provider = TaskProvider(host, task_args)
    audit_params = task_args.get_task_params()

    query = check_state_and_get_query(
        host,
        task_args.issuer,
        task_args.task_type,
        task_args.ignore_maintenance,
        allowed_states=task_args.allowed_states,
        allowed_statuses=task_args.allowed_statuses,
    )

    provider.validate_args()
    with audit_log.create(**audit_params) as audit_entry:
        provider.log()

        # A cloud post-processor that redeploys/profiles afterwards needs a
        # REDEPLOY CMS action.
        if task_args.use_cloud_post_processor:
            if task_args.redeploy_after_task or task_args.profile_after_task:
                task_args.cms_action = CmsTaskAction.REDEPLOY

        provider.probe_cms()

        stage_builder = get_stages_func(task_args)
        new_task = create_new_task(host, task_args, audit_entry, stage_builder)
        provider.check_restrictions(query)

        if not task_args.force_new_task:
            provider.set_task(new_task, audit_entry, query)
            provider.enqueue()
        else:
            provider.swap_current_task(new_task, audit_entry, query, host)


class TaskProvider:
    """Helper bundling the per-step operations of scheduling a task on a host.

    :param host: host document the task is being scheduled on.
    :param task_args: task arguments object whose attributes drive validation,
        CMS probing, restriction checks and the host status update.
    """

    def __init__(self, host, task_args):
        self.host = host
        self.task_args = task_args

    def validate_args(self):
        """Reject inconsistent argument combinations.

        :raises RequestValidationError: if auto healing is requested without
            post-task host monitoring.
        """
        if self.task_args.with_auto_healing and not self.task_args.monitor_on_completion:
            raise RequestValidationError(
                "Auto healing can be enabled only when host checking after task completion is enabled."
            )

    def log(self):
        """Log that the task is being scheduled on the host."""
        # Lazy %-args: the logging framework formats the message only if the
        # record is actually emitted.
        log.info("Scheduling host %s %s...", self.host.human_id(), self.task_args.operation_type)

    def probe_cms(self):
        """Ask ``reject_request_if_needed`` to vet ``task_args.cms_action``.

        Does nothing when no CMS action is set.
        """
        if self.task_args.cms_action:
            reject_request_if_needed(
                self.task_args.issuer,
                self.task_args.task_type,
                self.host,
                self.task_args.cms_action,
                self.task_args.ignore_cms,
            )

    def check_restrictions(self, query):
        """Check operation restrictions and narrow ``query`` accordingly.

        When ``task_args.operation_restrictions`` is set, the result of
        ``restrictions.check_restrictions`` is added to ``query`` (mutated in
        place) as a ``restrictions__nin`` filter.

        :param query: mutable dict of host-document filter conditions.
        """
        if self.task_args.operation_restrictions:
            restriction = restrictions.check_restrictions(self.host, self.task_args.operation_restrictions)
            query.update(restrictions__nin=restriction)

    def set_task(self, task, audit_entry, query):
        """Conditionally set ``task`` and the operation status on the host.

        The update is applied via ``Host.modify`` only if ``query`` still
        matches the host document.

        :raises InvalidHostStateError: if the conditional modify reports that
            nothing was updated.
        """
        if not self.host.modify(
            query,
            set__task=task,
            **Host.set_status_kwargs(
                None,
                self.task_args.operation_host_status,
                self.task_args.issuer,
                audit_entry.id,
                reason=self.task_args.reason,
                ticket_key=self.task_args.ticket_key,
                unset_ticket=self.task_args.unset_ticket,
            )
        ):
            raise InvalidHostStateError(
                self.host,
                allowed_states=self.task_args.allowed_states,
                allowed_statuses=self.task_args.allowed_statuses,
            )

    def enqueue(self):
        """Send the "task enqueued" notification for the host."""
        notifications.on_task_enqueued(self.task_args.issuer, self.host, self.task_args.reason)

    def host_query(self, revision=None, **kwargs):
        """Build a filter matching the host's current state, status and task.

        :param revision: task revision to match; defaults to the host's current
            task revision.
        :param kwargs: extra filter conditions. ``state``/``status`` default to
            the host's current values when not given.
        :returns: dict of filter conditions including ``task__task_id`` and
            ``task__revision``.
        """
        kwargs.setdefault("state", self.host.state)
        kwargs.setdefault("status", self.host.status)
        return dict(
            kwargs,
            task__task_id=self.host.task.task_id,
            task__revision=self.host.task.revision if revision is None else revision,
        )

    def swap_current_task(self, task, audit_entry, query, host):
        """Replace the host's current task with ``task`` under a host lock.

        :param host: host used for the lock and for the pre-swap snapshot
            (presumably the same object as ``self.host`` — confirm at callers).
        """
        # Snapshot before set_task mutates the host, so cancellation actions
        # see the task being replaced.
        prev_host = host.copy()

        with HostInterruptableLock(host.uuid, host.tier):
            self.set_task(task, audit_entry, query)
            self.enqueue()

            try:
                on_task_cancelled(self.task_args.issuer, prev_host, self.task_args.reason)
            except Exception:
                # Best-effort: the new task is already set, so a failure here
                # is only logged; neither this code nor the user can undo it,
                # and leftovers are left for gc to clean up.
                log.exception("Failed to execute task cancellation actions")
