"""Switch host state to maintenance, assigned or probation."""

import logging

import walle.tasks
from sepelib.core.exceptions import Error
from walle.fsm_stages.common import register_stage, get_current_stage, complete_current_stage
from walle.hosts import Host, HostState, HostOperationState, StateExpire
from walle.stages import Stages
from walle.util.misc import fix_mongo_set_kwargs

log = logging.getLogger(__name__)


def _set_maintenance(host):
    """Move *host* into the MAINTENANCE state.

    The expiration settings come from the current stage's params:
    ``ticket_key`` (falls back to ``host.ticket``), ``timeout_time``
    (falls back to ``None``) and ``timeout_status``. The expiration is
    issued by the owner of the host's task.

    The host also gets its task's CMS task id. Its operation state is set
    from the ``operation_state`` stage param, falling back to
    ``HostOperationState.OPERATION``.
    """
    current = get_current_stage(host)
    owner = host.task.owner

    state_expire = StateExpire(
        ticket=current.get_param("ticket_key", host.ticket),
        time=current.get_param("timeout_time", None),
        status=current.get_param("timeout_status"),
        issuer=owner,
    )
    requested_op_state = current.get_param("operation_state", HostOperationState.OPERATION)

    extra_updates = {
        "set__cms_task_id": host.task.cms_task_id,
        "set__operation_state": requested_op_state,
    }
    _set_state(host, HostState.MAINTENANCE, expire=state_expire, **extra_updates)


def _set_assigned(host):
    """Move *host* into the ASSIGNED state.

    The host's CMS task id is cleared and its operation state is reset to
    ``HostOperationState.OPERATION``.
    """
    extra_updates = {
        "set__cms_task_id": None,
        "set__operation_state": HostOperationState.OPERATION,
    }
    _set_state(host, HostState.ASSIGNED, **extra_updates)


def _set_probation(host):
    """Move *host* into the PROBATION state, recording its task's CMS task id.

    Unlike the assigned transition, this one does not touch the host's
    operation state.
    """
    cms_task_id = host.task.cms_task_id
    _set_state(host, HostState.PROBATION, set__cms_task_id=cms_task_id)


def _set_state(host, state, expire=None, **host_kwargs):
    """Persist a state change for *host*, then complete the current FSM stage.

    :param host: host document whose task is currently executing a stage.
    :param state: target ``HostState`` value.
    :param expire: optional ``StateExpire`` passed through to
        ``Host.set_state_kwargs``.
    :param host_kwargs: extra update kwargs (e.g. ``set__cms_task_id``).
        They are merged over the kwargs produced by ``Host.set_state_kwargs``
        and win on key collisions.
    :raises Error: if ``host.modify`` reports that nothing was updated.
        NOTE(review): presumably this means the host no longer matches
        ``walle.tasks.host_query(host)`` (someone else changed it).
    """
    current = get_current_stage(host)
    task = host.task
    change_reason = current.get_param("reason", None)

    state_update = Host.set_state_kwargs(
        state,
        task.owner,
        task.audit_log_id,
        expire=expire,
        reason=change_reason,
    )
    combined = merge_kwargs(state_update, host_kwargs)

    # The conditional update also bumps the task revision.
    committed = host.modify(
        walle.tasks.host_query(host),
        inc__task__revision=1,
        **fix_mongo_set_kwargs(**combined)
    )
    if not committed:
        raise Error("Unable to commit host state change: it doesn't have an expected state already.")

    complete_current_stage(host)


def merge_kwargs(kwargs, *other_kwargs):
    """Merge mappings left to right into a new dict; later mappings win on conflicts.

    None of the inputs is modified, and the result is always a fresh dict,
    even when only one mapping is given.

    Any hashable keys are accepted. The previous recursive
    ``dict(kwargs, **...)`` form rejected non-string keys in every mapping
    after the first.

    :param kwargs: base mapping.
    :param other_kwargs: mappings applied on top of *kwargs*, in order.
    :returns: the merged dict.
    """
    merged = dict(kwargs)
    for overrides in other_kwargs:
        merged.update(overrides)
    return merged


# Register each stage identifier with the handler that should run for it.
for _stage_id, _stage_handler in (
    (Stages.SET_MAINTENANCE, _set_maintenance),
    (Stages.SET_ASSIGNED, _set_assigned),
    (Stages.SET_PROBATION, _set_probation),
):
    register_stage(_stage_id, _stage_handler)
# Remove the loop variables so the module namespace is unchanged.
del _stage_id, _stage_handler
