"""Contains all logic for acquiring permission for task processing."""

import logging
import re
import typing as tp
from collections import namedtuple

from sepelib.core import config, constants
from sepelib.core.exceptions import Error, LogicalError
from walle import audit_log
from walle.clients.cms import (
    CmsError,
    CmsTaskRejectedError,
    CmsTaskStatus,
    make_maintenance_info,
    get_cms_task_type,
    CmsTaskAction,
)
from walle.clients.tvm import TvmApiError
from walle.clients.utils import strip_api_error
from walle.failure_reports.cms_reports import CmsReportParams
from walle.fsm_stages.common import (
    commit_stage_changes,
    complete_current_stage,
    fail_current_stage,
    generate_stage_handler,
    get_current_stage,
    register_stage,
)
from walle.hosts import Host, HostStatus, HostState, TaskType
from walle.models import timestamp
from walle.projects import is_temporary_unreachable_enabled
from walle.scenario.cms import get_scenario_info_for_cms
from walle.stages import Stages
from walle.statbox.contexts import host_context
from walle.statbox.loggers import cms_logger
from walle.util import misc
from walle.util.limits import nice_period
from walle.util.workdays import HOST_MAINTENANCE_HOURS, WORKING_DAYS, from_timestamp, next_working_hour_timestamp

log = logging.getLogger(__name__)

# Delay before the stage is re-checked after it was held back (Wall-E task limits, CMS errors,
# CMS not having approved yet).
_LIMITS_CHECK_PERIOD = constants.MINUTE_SECONDS
# Fallback CMS acquire timeout, used when no entry of
# "failure_reports.report_params.cms_report_params" matches the CMS name.
_CMS_WAIT_TIMEOUT = constants.WEEK_SECONDS

# Sub-statuses of the ACQUIRE_PERMISSION stage. The stage starts at _STATUS_WALLE (see the
# register_stage call at the bottom of the module) and moves wall-e -> calendar -> cms.
_STATUS_WALLE = "wall-e"
_STATUS_CMS = "cms"
_STATUS_CALENDAR = "calendar"

# Per-CMS answer collected by _acquire_cms_api: the CMS task status, a human-readable reason and
# the acquire timeout configured for that CMS.
TaskStatus = namedtuple("TaskStatus", ["status", "reason", "timeout"])


class _TaskLimitsExceededError(Error):
    """Raised by ``_check_limits`` when too many tasks of the same kind are already being processed."""


def _acquire_walle(host):
    """Acquire permission from Wall-E itself, i.e. check the task processing limits.

    If the limits are exceeded, the stage keeps its status, the reason is stored as the stage
    status message and the check is repeated after ``_LIMITS_CHECK_PERIOD``. Otherwise the
    decision is logged to the CMS statbox logger and the stage moves on to the calendar check.
    """
    try:
        _check_limits(host)
    except _TaskLimitsExceededError as e:
        # TODO: use more optimal solution
        # A simple protection from:
        # * a lot of tasks with slow CMS
        # * a lot of tasks from one user that block the whole Wall-E queue
        # * too permissive CMS
        log.info("Delay processing of '%s' task by '%s' for %s: %s", host.status, host.task.owner, host.human_id(), e)
        # Store the message as text, like every other status_message in this module, rather than
        # the exception object itself.
        return commit_stage_changes(host, status_message=str(e), check_after=_LIMITS_CHECK_PERIOD)

    cms_logger().log(
        cms_name=host.get_project(fields=["cms"]).cms,
        walle_action="walle_acquire_permission",
        ignore_cms=host.task.ignore_cms,
        **host_context(host)
    )

    commit_stage_changes(host, status=_STATUS_CALENDAR, check_now=True)


def _acquire_cms(host: Host):
    """Acquire permission for the current stage action from the host's CMS clients.

    Makes sure the task has a CMS task id, picks the CMS clients for the stage action (optionally
    narrowed by ``get_filter_condition``) and delegates to ``_acquire_cms_api``. CMS and TVM errors
    are recorded on the stage, and the check is retried after ``_LIMITS_CHECK_PERIOD``.

    Raises:
        LogicalError: if the host has no name or the task ignores CMS (the calendar step completes
            the stage directly in that case), or if no CMS client is available for an action that
            has no filter condition (the project must have correct CMS settings).
    """
    if host.name is None or host.task.ignore_cms:
        raise LogicalError

    context_logger = cms_logger(walle_action="cms_acquire_permission", **host_context(host))
    cms_task_id = _set_cms_task_id(host)

    log.info("Getting cms api client for host %s...", host.human_id())
    stage = get_current_stage(host)
    filter_condition = get_filter_condition(stage.get_param("action"))
    cms_api_clients = host.get_cms_clients(context_logger, filter_condition=filter_condition)
    log.info("Got cms api client for host %s", host.human_id())

    if not cms_api_clients:
        if not filter_condition:
            raise LogicalError  # project must have correct cms settings
        # NOTE(review): presumably the filter condition excluded every CMS for this action, so
        # there is nobody to ask and the stage is completed as is.
        complete_current_stage(host)
        return

    try:
        return _acquire_cms_api(host, cms_api_clients, cms_task_id)
    except (CmsError, TvmApiError) as e:
        log.info("%s: error acquiring permission from cms: %s", host.human_id(), str(e))
        return commit_stage_changes(host, error=str(e), check_after=_LIMITS_CHECK_PERIOD)


def _acquire_calendar(host):
    """Hold the task until working hours when the current stage asks for it.

    With the ``workdays`` stage parameter set, the stage is postponed (``check_at``) to the next
    working hour if the current moment falls outside of it. Once allowed to proceed, the stage is
    completed right away for tasks that ignore CMS; otherwise it is switched to the CMS sub-status.
    """
    if get_current_stage(host).get_param("workdays", False):
        # First-shot simplest implementation; per the original note, working days are Mon thru Fri
        # and working hours are 11:00 thru 17:55 (see WORKING_DAYS / HOST_MAINTENANCE_HOURS).
        current_ts = timestamp()
        resume_ts = next_working_hour_timestamp(current_ts, WORKING_DAYS, *HOST_MAINTENANCE_HOURS)

        if resume_ts > current_ts:
            resume_at = from_timestamp(resume_ts).isoformat()
            log.info(
                "%s: delay processing of '%s' task by '%s' until '%s': workdays rule applied.",
                host.human_id(),
                host.status,
                host.task.owner,
                resume_at,
            )
            delay_reason = "Delay processing of '{}' task by '{}' until '{}': workdays rule applied.".format(
                host.status, host.task.owner, resume_at
            )
            return commit_stage_changes(host, status_message=delay_reason, check_at=resume_ts)

        log.info("%s: it's workday, proceeding.", host.human_id())

    if host.task.ignore_cms:
        complete_current_stage(host)
    else:
        commit_stage_changes(host, status=_STATUS_CMS, check_now=True)


def _fetch_cms_task(cms, cms_task_id, host):
    """Fetch CMS task ``cms_task_id`` from ``cms`` without letting errors escape.

    Returns whatever ``cms.get_task`` returns; the caller treats ``None`` as "task does not exist
    yet" and creates it. Any error is turned into a synthetic ``IN_PROCESS`` task for this host
    whose message describes the failure, so the stage keeps waiting instead of crashing.
    """
    try:
        return cms.get_task(cms_task_id)
    except CmsError as e:
        return {"status": CmsTaskStatus.IN_PROCESS, "hosts": [host.name], "message": str(e)}
    except Exception as e:
        # Deliberate best-effort: an unexpected client failure must not break the stage either.
        return {
            "status": CmsTaskStatus.IN_PROCESS,
            "hosts": [host.name],
            "message": "Got an unknown exception when trying to get the task from the cms: {}".format(str(e)),
        }


def _acquire_cms_api(host, cms_clients, cms_task_id):
    """Acquire permission for the current stage action from every given CMS.

    For each CMS the task ``cms_task_id`` is fetched, and created if it does not exist. Then:

    * a rejection from any CMS fails the stage immediately;
    * if every CMS answered OK, the ISS ban flag is set and the stage is completed;
    * if a CMS that has not approved yet exceeded its acquire timeout, an error is recorded;
    * otherwise the stage is re-checked after ``_LIMITS_CHECK_PERIOD``.

    Raises:
        Error: if an existing CMS task does not contain this host.
        LogicalError: if not every CMS produced a status.
    """
    stage = get_current_stage(host)
    cms_action = stage.get_param("action")
    cms_task_type = get_cms_task_type(host.task.type)
    task_group = stage.get_param("task_group", None)
    cms_comment = stage.get_param("comment", None)
    cms_extra = stage.get_param("extra", None)
    location = _get_location(host)
    failure = stage.get_param("failure", None)
    check_names = stage.get_param("check_names", None)
    failure_type = stage.get_param("failure_type", None)
    maintenance_info = make_maintenance_info(
        cms_task_id, cms_action, [host.name], node_set_id=task_group, comment=cms_comment, issuer=host.task.owner
    )
    scenario_info = get_scenario_info_for_cms(host)

    task_statuses = []
    log.info("Getting task statuses for host %s...", host.human_id())
    for cms in cms_clients:
        log.info("Getting task for host %s from cms...", host.human_id())
        cms_task = _fetch_cms_task(cms, cms_task_id, host)
        log.info("Got task for host %s from cms", host.human_id())

        if cms_task is None:
            log.info("Adding task from cms for host %s", host.human_id())
            try:
                cms_task = cms.add_task(
                    cms_task_id,
                    cms_task_type,
                    host.task.owner,
                    cms_action,
                    [host.name],
                    extra=cms_extra,
                    comment=cms_comment,
                    location=location,
                    task_group=task_group,
                    failure=failure,
                    check_names=check_names,
                    failure_type=failure_type,
                    maintenance_info=maintenance_info,
                    scenario_info=scenario_info,
                )
            except CmsTaskRejectedError as e:
                cms_task = {"status": CmsTaskStatus.REJECTED, "message": str(e)}
        elif host.name not in cms_task["hosts"]:
            raise Error(
                "Error in communication with CMS: CMS has created a task for invalid hosts: {} instead of {}.",
                misc.format_long_list_for_logging(cms_task["hosts"]),
                [host.name],
            )

        task_status = cms_task["status"]
        # `or ""` also covers an explicit `"message": None` in the CMS response.
        reason = strip_api_error(cms_task.get("message") or "").rstrip(".") or "[no message provided]"

        if task_status == CmsTaskStatus.REJECTED:
            log.info("Failing current stage for host %s", host.human_id())
            fail_current_stage(host, "CMS rejected the request: {}.".format(reason))
            return

        timeout = _get_cms_acquire_timeout(cms)
        task_statuses.append(TaskStatus(status=task_status, reason=reason, timeout=timeout))

    log.info("Got task statuses for host %s", host.human_id())
    if not task_statuses or len(task_statuses) != len(cms_clients):
        raise LogicalError

    if all(task_status.status == CmsTaskStatus.OK for task_status in task_statuses):
        log.info("Current stage for host %s is completed.", host.human_id())
        host.set_iss_ban_flag()
        complete_current_stage(host)
        return

    log.info("Checking timed out stages for host %s...", host.human_id())
    for task_status in task_statuses:
        # Only a CMS that has not approved yet can time the stage out. An OK answer from a CMS with
        # a short acquire timeout must not fail a stage that is still waiting for another CMS.
        if cms_task_not_ok(task_status) and stage.timed_out(task_status.timeout):
            commit_stage_changes(
                host,
                check_after=_LIMITS_CHECK_PERIOD,
                error="CMS has not allowed to perform {} action during given timeout ({}). "
                "Reason: {}".format(cms_action, nice_period(task_status.timeout), task_status.reason),
            )
            return

    log.info("Checked timed out stages for host %s", host.human_id())

    full_reason = "; ".join(task_status.reason for task_status in task_statuses if cms_task_not_ok(task_status))
    commit_stage_changes(
        host,
        status_message="CMS hasn't allowed to process the host yet: {}.".format(full_reason),
        check_after=_LIMITS_CHECK_PERIOD,
    )

    log.info("Finished acquiring cms api for host %s", host.human_id())


def _set_cms_task_id(host):
    """Return the CMS task id of the host's current task, assigning one on first use.

    A fresh id is taken from the task when the stage asks for it (``force_new_cms_task``) or when
    the host has no CMS task id of its own; otherwise the host's id is reused. The chosen id is
    committed and recorded in the task's audit log payload.
    """
    if host.task.cms_task_id is not None:
        return host.task.cms_task_id

    wants_new_id = get_current_stage(host).get_param("force_new_cms_task", False)
    chosen_id = host.task.get_cms_task_id() if wants_new_id or not host.cms_task_id else host.cms_task_id

    host.task.cms_task_id = chosen_id
    commit_stage_changes(host)
    audit_log.update_payload(host.task.audit_log_id, {"cms_task_id": chosen_id})

    return host.task.cms_task_id


def _get_cms_acquire_timeout(cms):
    """Return the acquire timeout configured for ``cms``.

    The first entry of ``failure_reports.report_params.cms_report_params`` whose ``url_match``
    regex matches at the start of the CMS name wins; ``_CMS_WAIT_TIMEOUT`` is used otherwise.
    """
    raw_entries = config.get_value("failure_reports.report_params.cms_report_params")
    parsed_entries = (CmsReportParams(**raw_entry) for raw_entry in raw_entries)
    matched_timeouts = (entry.acquire_timeout for entry in parsed_entries if re.match(entry.url_match, cms.name))
    return next(matched_timeouts, _CMS_WAIT_TIMEOUT)


def _check_limits(host):
    """Raise ``_TaskLimitsExceededError`` if this host's task is beyond its processing limit.

    Counts hosts in a task status whose task has the same type and a task id not greater than this
    one (presumably this task and those created before it). Manual tasks are narrowed to the same
    owner; automated ones to the same project, skipping hosts in states ignored for limits. The
    count is compared with the matching ``task_processing.*`` config value.

    Raises:
        _TaskLimitsExceededError: if the count exceeds the configured limit.
        LogicalError: for task types that have no limit.
    """
    task = host.task
    # NOTE: the trailing "__" in task__type__ escapes the operator-like field name "type".
    query = dict(status__in=HostStatus.ALL_TASK, task__type__=task.type, task__task_id__lte=task.task_id)

    if task.type == TaskType.MANUAL:
        query.update(task__owner=task.owner)
        limit_name = "max_processing_tasks_per_user"
        limit_error = _TaskLimitsExceededError("Too many tasks are enqueued by {}.", task.owner)
    elif task.type in (TaskType.AUTOMATED_HEALING, TaskType.AUTOMATED_ACTION):
        # Automated tasks share the per-project narrowing; only the limit and message differ.
        query.update(project=host.project, state__nin=HostState.ALL_IGNORED_LIMITS_COUNTING)
        if task.type == TaskType.AUTOMATED_HEALING:
            limit_name = "max_processing_healing_tasks_per_project"
            limit_error = _TaskLimitsExceededError(
                "Too many host healing tasks are enqueued for {} project.", host.project
            )
        else:
            limit_name = "max_processing_automated_tasks_per_project"
            limit_error = _TaskLimitsExceededError(
                "Too many automated tasks are enqueued for {} project.", host.project
            )
    else:
        raise LogicalError()

    limit = config.get_value("task_processing." + limit_name)
    if Host.objects(**query).count() > limit:
        raise limit_error


def _get_location(host):
    """Return the host's switch and port as the ``location`` dict sent with a CMS task."""
    host_location = host.location
    return dict(switch=host_location.switch, port=host_location.port)


def get_filter_condition(cms_action: str) -> tp.Optional[tp.Callable]:
    """Return the CMS client filter for ``cms_action``, or ``None`` when no filtering applies.

    Only the temporary-unreachable action is filtered, using
    ``walle.projects.is_temporary_unreachable_enabled``.
    """
    is_unreachable_action = cms_action == CmsTaskAction.TEMPORARY_UNREACHABLE
    return is_temporary_unreachable_enabled if is_unreachable_action else None


def cms_task_not_ok(task_status):
    """Tell whether a ``TaskStatus`` entry has any status other than ``CmsTaskStatus.OK``."""
    status = task_status.status
    return status != CmsTaskStatus.OK


# Register the ACQUIRE_PERMISSION stage at import time: each sub-status is dispatched to its
# handler, and processing starts with the Wall-E limits check (_STATUS_WALLE).
register_stage(
    Stages.ACQUIRE_PERMISSION,
    generate_stage_handler(
        {
            _STATUS_WALLE: _acquire_walle,
            _STATUS_CMS: _acquire_cms,
            _STATUS_CALENDAR: _acquire_calendar,
        }
    ),
    initial_status=_STATUS_WALLE,
)
