"""Host status management."""

import itertools
import logging
from math import ceil

from mongoengine import Q

import walle.tasks
import walle.util.tasks
from sepelib.core import constants
from sepelib.core.exceptions import LogicalError
from walle import audit_log, authorization
from walle.clients import bot
from walle.clients.startrek import StartrekClientError
from walle.clients.startrek import get_tickets_by_query
from walle.errors import InvalidHostStateError, InvalidHostConfiguration
from walle.expert.types import CheckType, CheckStatus
from walle.host_network import HostNetwork
from walle.hosts import Host, HostState, HostStatus, get_host_query, TaskType, HealthStatus
from walle.locks import HostInterruptableLock
from walle.models import timestamp
from walle.projects import Project
from walle.util import notifications
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.host_health import get_failure_reason, get_failure_reason_deprecated

log = logging.getLogger(__name__)

MIN_STATE_TIMEOUT = 30 * constants.MINUTE_SECONDS
"""Minimal state timeout. State timeouts for shorter periods make no sense and can be harmful."""

# Upper bound for a state timeout, in seconds (four weeks).
# NOTE(review): not referenced by the functions in this module; presumably enforced by callers
# that validate requested timeouts -- confirm.
MAX_STATE_TIMEOUT = 4 * constants.WEEK_SECONDS
"""Maximum state timeout. We want our cluster to work, not hang in maintenance forever."""

# Used by _get_closed_st_tickets() as the page size when splitting ticket keys across Startrek queries.
MAX_STARTREK_QUERY_TICKETS_COUNT = 1000
"""Maximum tickets count to include in Startrek query."""

# Check types whose failures do not keep a host in probation. _get_hosts_with_ignored_checks()
# combines every entry here with every entry of IGNORED_CHECK_STATUSES to build the set of
# ignorable failure reasons.
IGNORED_CHECKS = [
    CheckType.NETMON,
    CheckType.WALLE_RACK,
    CheckType.WALLE_RACK_OVERHEAT,
    CheckType.TOR_LINK,
    CheckType.INFINIBAND,
]
"""Checks that should be ignored when leaving probation."""

# Statuses that are paired with each of IGNORED_CHECKS when building ignorable failure reasons.
IGNORED_CHECK_STATUSES = [
    CheckStatus.FAILED,
    CheckStatus.SUSPECTED,
    CheckStatus.MISSING,
    CheckStatus.VOID,
    CheckStatus.STALED,
]
"""Possible check statuses for IGNORED_CHECKS"""

# Additional standalone ignorable reasons: the META check being MISSING, produced by both the
# current and the deprecated reason builders.
IGNORED_REASONS = {
    get_failure_reason(CheckType.WALLE_MAPPING[CheckType.META], CheckStatus.MISSING),
    get_failure_reason_deprecated(CheckType.WALLE_MAPPING[CheckType.META], CheckStatus.MISSING),
}
"""Reasons that should be ignored when leaving probation."""

# Startrek query template; `{ticket_keys}` is substituted (in both places) with a comma-separated
# list of ticket keys. The `Resolution: notEmpty()` clause is what this module treats as "closed".
STARTREK_CLOSED_TICKETS_QUERY_TEMPLATE = "(Key: {ticket_keys} OR Aliases: {ticket_keys}) AND Resolution: notEmpty()"


def _get_hosts_with_ignored_checks():
    """Yield ready probation hosts whose health is not OK only because of ignorable reasons.

    The ignorable set is IGNORED_REASONS plus the failure reason (in both the current and the
    deprecated format) for every combination of IGNORED_CHECKS and IGNORED_CHECK_STATUSES.
    Hosts without health data or without any failure reasons are skipped.
    """
    check_status_pairs = list(itertools.product(IGNORED_CHECKS, IGNORED_CHECK_STATUSES))

    ignorable_reasons = set(IGNORED_REASONS)
    for build_reason in (get_failure_reason, get_failure_reason_deprecated):
        ignorable_reasons.update(
            build_reason(CheckType.WALLE_MAPPING[check], status) for check, status in check_status_pairs
        )

    candidates = Host.objects(
        state=HostState.PROBATION,
        status=HostStatus.READY,
        health__status__ne=HealthStatus.STATUS_OK,
    )

    for host in gevent_idle_iter(candidates):
        health = host.health
        if not health or not health.reasons:
            continue

        # Every reported reason must be ignorable, otherwise the host stays in probation.
        if set(health.reasons) <= ignorable_reasons:
            yield host


def _get_all_ready_hosts():
    """Yield probation hosts that are in ready status and whose health status is OK."""
    healthy_probation_hosts = Host.objects(
        state=HostState.PROBATION,
        status=HostStatus.READY,
        health__status=HealthStatus.STATUS_OK,
    )

    for host in gevent_idle_iter(healthy_probation_hosts):
        yield host


def _gc_complete_probation():
    """Try to move out of probation every host that is healthy or fails only ignorable checks."""
    # Healthy hosts are processed first, then the ones whose failures may be ignored.
    for host_source in (_get_all_ready_hosts, _get_hosts_with_ignored_checks):
        for host in host_source():
            _assign_prepared_host(host)


def _assign_prepared_host(host):
    """Schedule switching a prepared probation host to the assigned state with ready status.

    A host whose health status is OK is additionally verified with `is_host_healthy`; when that
    check fails the host is left in probation and only a log record is written.
    """
    # Imported at call time rather than at module level, as in the original code.
    from walle.expert.dmc import is_host_healthy

    reported_ok = host.health.status == HealthStatus.STATUS_OK
    if reported_ok and not is_host_healthy(host):
        log.info("Host %s is not healthy, keeping it in probation.", host.name)
        return

    walle.tasks.schedule_setting_assigned_state(
        authorization.ISSUER_WALLE,
        TaskType.AUTOMATED_ACTION,
        host,
        status=HostStatus.READY,
        ignore_maintenance=True,
        monitor_on_completion=False,
        reason="Host {} has successfully finished preparing.".format(host.name),
    )


def _gc_maintenance_timeout():
    """Revert the temporary status of maintenance hosts whose state timeout has expired.

    Only hosts in a steady status with a non-empty `state_expire.time` that is not later than
    the current timestamp are processed.
    """
    in_steady_maintenance = Q(state=HostState.MAINTENANCE, status__in=HostStatus.ALL_STEADY)
    timeout_expired = Q(state_expire__time__lte=timestamp(), state_expire__time__ne=None)

    for host in gevent_idle_iter(Host.objects(in_steady_maintenance & timeout_expired)):
        _try_to_revert_temporary_status(host)


def _gc_maintenance_ticket_closed():
    """Revert the temporary status of maintenance hosts whose attached ticket has been closed.

    Only hosts in a steady status, without a scenario, and with a `state_expire` record that has
    no timeout (`state_expire.time` is None) are considered: if a timeout has been set for the
    maintenance state, the host is handled by the timeout logic instead.

    Failures to query Startrek are logged and the pass is skipped; nothing is raised.
    """
    host_query = Q(state=HostState.MAINTENANCE, status__in=HostStatus.ALL_STEADY, scenario_id=None) & Q(
        state_expire__exists=True, state_expire__time=None
    )

    try:
        closed_tickets = _get_closed_st_tickets(host_query)
    except StartrekClientError as e:
        log.error("Failed to get closed tickets from Startrek: %s.", str(e))
        return
    except KeyError:
        # Raised when a returned ticket lacks an expected field.
        log.error("Startrek schema changed unexpectedly.")
        return

    if not closed_tickets:
        # An `$in` filter with an empty list cannot match any host, so skip the database query.
        return

    closed_ticket_hosts = Host.objects(
        host_query & (Q(ticket__in=closed_tickets) | Q(state_expire__ticket__in=closed_tickets))
    )

    for host in gevent_idle_iter(closed_ticket_hosts):
        _try_to_revert_temporary_status(host)


def _get_closed_st_tickets(host_query):
    """Return keys and aliases of closed Startrek tickets attached to the hosts matching `host_query`.

    For every matching host the ticket from `state_expire` is preferred; when it is not set, the
    host's own `ticket` is used instead. Hosts with no usable ticket key are skipped. The
    collected keys are checked against Startrek in pages of at most
    MAX_STARTREK_QUERY_TICKETS_COUNT keys using STARTREK_CLOSED_TICKETS_QUERY_TEMPLATE.

    :param host_query: mongoengine query restricting the hosts to inspect; it is not modified.
    :return: list with the keys and aliases of the tickets returned by Startrek (unordered),
        or an empty list when no host has a ticket attached.
    :raises KeyError: if a returned ticket has no "key" field.
    :raises StartrekClientError: presumably propagated from `get_tickets_by_query` -- the caller
        in this module handles it.
    """
    host_query = host_query & (
        Q(ticket__exists=True, ticket__nin=[None, ""]) | Q(state_expire__ticket__exists=True)
    )
    hosts = Host.objects(host_query).only("ticket", "state_expire__ticket")

    ticket_keys = set()
    for host in hosts:
        # A host may be matched by its own ticket while `state_expire` carries none; fall back to
        # the host ticket instead of collecting None, which would break sorting and joining below.
        expire_ticket = host.state_expire.ticket if host.state_expire else None
        ticket_key = expire_ticket or host.ticket
        if ticket_key:
            ticket_keys.add(ticket_key)

    if not ticket_keys:
        return []

    ticket_keys = sorted(ticket_keys)
    page_size = MAX_STARTREK_QUERY_TICKETS_COUNT

    tickets = []
    for page_start in range(0, len(ticket_keys), page_size):
        keys_str = ",".join(ticket_keys[page_start : page_start + page_size])
        tickets += get_tickets_by_query(STARTREK_CLOSED_TICKETS_QUERY_TEMPLATE.format(ticket_keys=keys_str))

    # Hosts may reference a ticket by an alias, so aliases are returned alongside the keys.
    aliases = itertools.chain.from_iterable(ticket["aliases"] for ticket in tickets if "aliases" in ticket)
    keys = {ticket["key"] for ticket in tickets} | set(aliases)

    return list(keys)


def _try_to_revert_temporary_status(host):
    """Revert the host's temporary status without ever raising.

    InvalidHostStateError is silently ignored; any other exception is logged with its traceback.
    """
    try:
        _revert_temporary_status(host)
    except InvalidHostStateError:
        # Deliberately ignored, no log record is written for it.
        return
    except Exception:
        log.exception("%s: Failed to change host from temporary to target state.", host.human_name())


def _revert_temporary_status(host):
    """Schedule switching the host from maintenance to the assigned state with its fallback status.

    The fallback status and the reason come from `_get_fallback_status`. The issuer it returns is
    not used here: the task is always scheduled on behalf of `authorization.ISSUER_WALLE`.
    """
    target_status, _issuer, reason = _get_fallback_status(host)
    log.info(
        "Switching host %s from maintenance state to assigned with '%s' status...", host.human_id(), target_status
    )

    # WALLE-2707 Power on server if project's healing automation is enabled.
    project = Project.objects.get(id=host.project)
    healing_enabled = project.healing_automation.enabled

    walle.tasks.schedule_setting_assigned_state(
        authorization.ISSUER_WALLE,
        TaskType.AUTOMATED_ACTION,
        host,
        status=target_status,
        power_on=healing_enabled,
        ignore_maintenance=True,
        monitor_on_completion=False,
        reason=reason,
    )


def _fallback_status_from_expire(host, state_expire):
    """Build the fallback status, issuer and reason from the host's `state_expire` record.

    :return: tuple `(status, issuer, reason)`, where status and issuer are taken from
        `state_expire` and the reason tells whether the state ended by timeout or by ticket closure.
    """
    issuer = state_expire.issuer
    issuer_name = authorization.get_issuer_name(issuer)

    if state_expire.time is not None:
        # Host is leaving maintenance state by timeout
        reason = "The '{}' state requested by {} has timed out.".format(host.state, issuer_name)
    else:
        # Host is leaving maintenance state by ticket; the host's own ticket is the fallback
        # when the expire record carries none.
        ticket = state_expire.ticket or host.ticket
        reason = "The '{}' state requested by {} is expired because the ticket {} has been closed.".format(
            host.state, issuer_name, ticket
        )

    return state_expire.status, issuer, reason


def _get_fallback_status(host):
    """Determine which status the host gets when its temporary maintenance state ends.

    :return: tuple `(status, issuer, reason)`.
    :raises LogicalError: if the host has no `state_expire` record and its status is not the
        default maintenance status.
    """
    if host.state_expire:
        return _fallback_status_from_expire(host, host.state_expire)

    default_maintenance_status = HostStatus.default(HostState.MAINTENANCE)
    if host.status != default_maintenance_status:
        raise LogicalError

    # No expire record: fall back to ready on behalf of whoever set the current status.
    reason = "The '{}' status has timed out.".format(default_maintenance_status)
    return HostStatus.READY, host.status_author, reason


def cancel_task(issuer, host, reason, ignore_maintenance=False, lock_class=HostInterruptableLock):
    """Cancel the host's current task by forcing the default status of the host's state.

    An audit entry is created with `audit_log.on_task_cancelled`, referencing the audit log id of
    the task being cancelled, and is then passed to `force_status`, which performs the switch.

    :param issuer: who requests the cancellation.
    :param host: host whose task is cancelled; `host.task` is dereferenced, so it must be set.
    :param reason: reason stored in the audit entry and forwarded to `force_status`.
    :param ignore_maintenance: forwarded to `force_status`.
    :param lock_class: forwarded to `force_status`.
    """
    target_status = HostStatus.default(host.state)
    cancelled_task_log_id = host.task.audit_log_id

    entry = audit_log.on_task_cancelled(
        issuer,
        host.project,
        host.inv,
        host.name,
        host.uuid,
        target_status,
        cancelled_task_log_id,
        reason=reason,
        scenario_id=host.scenario_id,
    )

    # The call deliberately goes through the `walle.host_status` module attribute, as in the
    # original code.
    walle.host_status.force_status(
        issuer,
        host,
        target_status,
        audit_entry=entry,
        reason=reason,
        ignore_maintenance=ignore_maintenance,
        lock_class=lock_class,
    )


def force_status(
    issuer,
    host,
    status,
    audit_entry=None,
    forbidden_statuses=None,
    only_from_current_status_id=False,
    ignore_maintenance=False,
    ticket_key=None,
    reason=None,
    lock_class=HostInterruptableLock,
):
    """Forces the specified status for the host. If host is processing some task, it will be cancelled.

    Attention: Fails with InvalidHostStateError if host changed its state (not status!). This behaviour must be saved
               because some callers rely on it.

    :param issuer: who requests the change; also used when building the host query.
    :param host: host object as seen by the caller; its state, status, inv and name are used to
        build the update query.
    :param status: target status. Must be the default status for the host's state, DEAD or INVALID.
    :param audit_entry: existing audit log entry to use; when None, a new one is created with
        `audit_log.on_force_host_status`.
    :param forbidden_statuses: current host statuses from which the switch is not allowed.
    :param only_from_current_status_id: when true, the update only matches while the host still
        has the same `status_audit_log_id` as the passed object.
    :param ignore_maintenance: forwarded to `get_host_query`.
    :param ticket_key: ticket to attach to the host; see `_get_status_update_kwargs`.
    :param reason: reason recorded with the change and passed to notifications.
    :param lock_class: lock taken (as a context manager) around the update that unsets the task.
    :raises InvalidHostStateError: if the host's state/status or the target status is not allowed,
        or if no host document matched the update.
    :raises InvalidHostConfiguration: if the host is currently INVALID and BOT returns no info for
        its inventory number or reports a different name.
    """

    allowed_states = HostState.ALL
    allowed_steady_statuses = HostStatus.ALL_STEADY
    allowed_statuses = HostStatus.ALL
    # The only statuses that may be forced.
    allowed_forced_statuses = [HostStatus.default(host.state), HostStatus.DEAD, HostStatus.INVALID]

    if forbidden_statuses:
        # Narrow the set of *current* statuses the host may be switched from.
        allowed_steady_statuses = list(set(allowed_steady_statuses) - set(forbidden_statuses))
        allowed_statuses = list(set(allowed_statuses) - set(forbidden_statuses))

    updated_host = None
    if audit_entry is None:
        audit_entry = audit_log.on_force_host_status(
            issuer, host.project, host.inv, host.name, host.uuid, status, reason=reason, scenario_id=host.scenario_id
        )

    # The audit entry is used as a context manager around the whole operation
    # (presumably it records success or failure on exit -- verify in audit_log).
    with audit_entry:
        if (
            host.state not in allowed_states
            or host.status not in allowed_statuses
            or status not in allowed_forced_statuses
        ):
            raise InvalidHostStateError(host, allowed_states=allowed_states, allowed_statuses=allowed_statuses)

        if host.status == HostStatus.INVALID:
            # Leaving the INVALID status: make sure BOT still knows the host under the same name.
            result = bot.get_host_info(host.inv)
            # `and` binds tighter than `or`: this reads as
            # `result is None or ("name" in result and host.name != result['name'])`.
            if result is None or "name" in result and host.name != result['name']:
                raise InvalidHostConfiguration(
                    "The #{} host has changed its inventory number or name.".format(host.inv)
                )

        host_query = _get_host_query(host, issuer, ignore_maintenance, only_from_current_status_id)
        update_kwargs = _get_status_update_kwargs(host, status, issuer, audit_entry, ticket_key, reason)
        update_network_kwargs = _get_network_status_update_kwargs(status)

        # First attempt: the host was seen in a steady status, so try to update it without taking
        # the lock and without touching the task field. The query re-checks that the status is
        # still steady.
        if host.status in allowed_steady_statuses:
            updated_host = Host.objects(status__in=allowed_steady_statuses, **host_query).modify(**update_kwargs)

        # Second attempt: the first one was skipped or matched nothing. Take the lock and update
        # from any allowed status, dropping the host's task.
        if updated_host is None:
            with lock_class(host.uuid, host.tier):
                updated_host = Host.objects(status__in=allowed_statuses, **host_query).modify(
                    unset__task=True, **update_kwargs
                )
                # NOTE(review): network fields are only cleared on this locked path and regardless
                # of whether the host update above matched; the lock-free path never clears them
                # -- confirm this is intended.
                if update_network_kwargs:
                    HostNetwork.objects(uuid=host.uuid).modify(**update_network_kwargs)

                # NOTE(review): `updated_host.status` is presumably the status from before the
                # update (modify() returning the pre-update document by default), i.e. this checks
                # whether the host was processing a task -- confirm against mongoengine.
                if updated_host is not None and updated_host.status in HostStatus.ALL_TASK:
                    walle.util.tasks.on_task_cancelled(issuer, updated_host, reason)

        if updated_host is None:
            # Nothing matched: the host changed concurrently (e.g. its state) or is in a forbidden status.
            raise InvalidHostStateError(host, allowed_states=allowed_states, allowed_statuses=allowed_statuses)

    # Notify only when Wall-E itself marked the host as invalid.
    if status == HostStatus.INVALID and issuer == authorization.ISSUER_WALLE:
        notifications.on_invalid_host(host, reason)


def _get_host_query(host, issuer, ignore_maintenance, only_from_current_status_id):
    """Build the keyword query used to atomically update exactly this host.

    With `only_from_current_status_id` the query additionally requires the host to still have the
    same `status_audit_log_id` (or still have none) as the passed host object.
    """
    # We must query by both host inventory number and name because in other case we might get in race especially in
    # daemon which syncs the database with BOT and switches hosts to invalid status on errors using this function.
    query = get_host_query(issuer, ignore_maintenance, host.state, inv=host.inv, name=host.name)

    if not only_from_current_status_id:
        return query

    current_log_id = host.status_audit_log_id
    if current_log_id is not None:
        query.update(status_audit_log_id=current_log_id)
    else:
        query.update(status_audit_log_id__exists=False)

    return query


def _get_status_update_kwargs(host, status, issuer, audit_entry, ticket_key, reason):
    """Build the `modify()` keyword arguments that set the new status on the host.

    Ticket handling: a non-empty `ticket_key` is stored on the host; when `ticket_key` is None and
    the new status is READY the host's ticket is removed; otherwise the ticket is left untouched.
    """
    kwargs = Host.set_status_kwargs(
        host.state,
        status,
        issuer,
        audit_entry.id,
        reason=reason,
        confirmed=False,
        downtime=False,
    )

    if ticket_key:
        kwargs["set__ticket"] = ticket_key
    elif ticket_key is None and status == HostStatus.READY:
        kwargs["unset__ticket"] = True

    return kwargs


def _get_network_status_update_kwargs(status):
    """Return `modify()` keyword arguments for the host's network record for the given new status.

    For the INVALID status the active MAC fields are unset; for any other status an empty dict
    is returned.
    """
    if status == HostStatus.INVALID:
        return {
            "unset__active_mac": True,
            "unset__active_mac_source": True,
            "unset__active_mac_time": True,
        }

    return {}
