import logging
import typing as tp
from itertools import groupby

import walle.hosts
from sepelib.core import config
from sepelib.core.exceptions import Error, LogicalError
from walle import projects
from walle.application import app
from walle.expert import failure_log
from walle.expert.types import Failure, FAILURE_LIMITS_MAP, get_limit_name
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.limits import TimedLimit, parse_timed_limits

# Maximum number of object ids / host names listed per project in a LimitExceeded message.
MAX_NUM_OF_HOSTS = 13
# Maximum number of projects listed in a LimitExceeded message.
MAX_NUM_OF_PROJECTS = 3

log = logging.getLogger(__name__)


class LimitExceeded(Error):
    """Raised when a failure limit is exceeded.

    The message names the failure and the limit info. It also lists up to ``MAX_NUM_OF_PROJECTS``
    projects with the most failing objects, each with up to ``MAX_NUM_OF_HOSTS`` object names.
    """

    def __init__(self, decision, failure, limits_info):
        self.decision = decision
        self.failure = failure
        self.limits_info = limits_info

        projects_count, objects_with_failure = self._get_projects_and_objects_with_failure()

        # When the decision carries checks, the failure is reported as a check name,
        # otherwise as a resolution.
        if decision.checks:
            error = "Too many failures has occurred for '{}' check ({}, {} projects in total)\n{}\n"
        else:
            error = "Too many failures has occurred with resolution '{}' ({}, {} projects in total)\n{}\n"
        super().__init__(error, failure, limits_info, projects_count, objects_with_failure)

    def _get_projects_and_objects_with_failure(self) -> tp.Tuple[int, str]:
        """Build the per-project part of the error message.

        :return: ``(projects_count, report)``. ``projects_count`` is the number of distinct projects
                 in ``limits_info.objects``. ``report`` holds one line for each of the
                 ``MAX_NUM_OF_PROJECTS`` projects with the most objects, followed by an ellipsis
                 line when more projects exist.
        """
        # groupby() only merges adjacent items, so the objects must be sorted by project first.
        message_objects = sorted(self.limits_info.objects, key=lambda obj: obj["project"])

        def project_report_line(project: str, objects: tp.List[tp.Dict[str, tp.Any]]) -> str:
            ids = [obj["_id"] for obj in objects]
            shown_ids = ids[:MAX_NUM_OF_HOSTS]
            if self.failure in Failure.ALL_RACK:
                # Rack failures: the object ids are printed as-is.
                obj_str = ", ".join(shown_ids)
            else:
                # Other failures: the ids are matched against Host.inv and printed as host names.
                hosts = walle.hosts.Host.objects(inv__in=shown_ids).only("name")
                obj_str = ", ".join(host.name for host in gevent_idle_iter(hosts))
            # Mark the list as truncated when not every id made it into the message.
            ellipsis = "..." if len(ids) > MAX_NUM_OF_HOSTS else ""
            return '{} hosts with such failure in project "{}": {}{}'.format(len(ids), project, obj_str, ellipsis)

        grouped_objects = [
            (project, list(objs)) for project, objs in groupby(message_objects, key=lambda obj: obj["project"])
        ]
        # Projects with the most failing objects go first; the sort is stable for ties.
        grouped_objects.sort(key=lambda entry: len(entry[1]), reverse=True)
        projects_count = len(grouped_objects)

        report = "\n".join(
            project_report_line(project, objects) for project, objects in grouped_objects[:MAX_NUM_OF_PROJECTS]
        )
        if projects_count > MAX_NUM_OF_PROJECTS:
            report += "\n...\n"
        return projects_count, report


class GlobalLimitExceeded(LimitExceeded):
    """Raised by the checker built in ``global_limits`` when a global limit is reached."""


class ProjectLimitExceeded(LimitExceeded):
    """Raised by the checker built in ``project_limits`` when a project limit is reached."""


class PlotLimitExceeded(LimitExceeded):
    """Raised by the checker built in ``plot_limits`` when a plot-check failure hits its limit."""


def global_limits(settings):
    """Build a checker that applies the global ``automation`` limits to known failures.

    ``settings.failure_log_start_time`` and the ``automation`` config are read each time
    the checker runs, not when it is built.

    :return: a callable taking a decision; it raises ``GlobalLimitExceeded`` when a limit is reached.
    """

    def check_against_global_limits(failures):
        window_start = settings.failure_log_start_time
        automation_conf = config.get_value("automation")
        return check_total_limits(failures, window_start, automation_conf)

    return make_limit_checker(known_failures, check_against_global_limits, GlobalLimitExceeded)


def project_limits(project_id, project_automation):
    """Build a checker that applies the limits of ``project_id`` to known failures.

    The limits and their start time come from ``project_automation.get_project_limits``.
    They are fetched each time the checker runs.

    :return: a callable taking a decision; it raises ``ProjectLimitExceeded`` when a limit is reached.
    """

    def check_against_project_limits(failures):
        project_limits_conf, window_start = project_automation.get_project_limits(project_id)
        return check_total_limits(failures, window_start, project_limits_conf, project_id)

    return make_limit_checker(known_failures, check_against_project_limits, ProjectLimitExceeded)


def plot_limits(project_id, project_automation, automation_plot):
    """Build a checker that applies the limits of ``project_id`` to automation plot check failures.

    Only failures accepted by ``plot_check_failure(automation_plot)`` are checked.
    The project limits are fetched each time the checker runs.

    :return: a callable taking a decision; it raises ``PlotLimitExceeded`` when a limit is reached.
    """
    plot_failure_filter = plot_check_failure(automation_plot)

    def check_plot_failures_against_project_limits(failures):
        project_limits_conf, window_start = project_automation.get_project_limits(project_id)
        return check_total_limits(failures, window_start, project_limits_conf, project_id)

    return make_limit_checker(plot_failure_filter, check_plot_failures_against_project_limits, PlotLimitExceeded)


def make_limit_checker(failure_filter, limit_checker, exception_type):
    """Combine a failure filter and a limit check into a decision checker.

    The returned callable takes a decision. It keeps the entries of ``decision.failures``
    accepted by ``failure_filter``. If any remain, it passes them to ``limit_checker``, which
    must return ``(result, failure)``. A falsy ``result`` causes
    ``exception_type(decision, failure, result.info)`` to be raised. The callable always
    returns ``None``.
    """

    def check_limits(decision):
        matched = [candidate for candidate in decision.failures if failure_filter(candidate)]
        if matched:
            outcome, offending_failure = limit_checker(matched)
            if not outcome:
                raise exception_type(decision, offending_failure, outcome.info)

    return check_limits


def known_failures(failure):
    """Tell whether ``failure`` has an entry in ``FAILURE_LIMITS_MAP``."""
    has_limits_entry = failure in FAILURE_LIMITS_MAP
    return has_limits_entry


def plot_check_failure(automation_plot):
    """Build a predicate that accepts only failures coming from checks of ``automation_plot``.

    A failure with an entry in ``FAILURE_LIMITS_MAP`` is rejected; ``known_failures`` matches
    those instead. Any other failure must be a check of ``automation_plot``, otherwise
    ``LogicalError`` is raised.
    """

    def plot_check_predicate(failure):
        if failure in FAILURE_LIMITS_MAP:
            return False

        plot_has_check = automation_plot.have_check(failure)
        if not plot_has_check:
            # we currently can't have unknown failures except for automation plot checks
            raise LogicalError
        return plot_has_check

    return plot_check_predicate


def check_total_limits(failures, start_time, limits_conf, project_id=None):
    """Check each failure against its total limits, stopping at the first one that is reached.

    The limits for each failure are resolved by ``_get_total_limits`` under the name returned by
    ``get_limit_name``. They are checked with ``failure_log.check_total_action_limits``, whose
    result acts like a boolean and is falsy when a limit is reached.

    :return: a ``(result, failure)`` tuple.
             - A limit is reached: the falsy result and the failure that reached it.
             - No limit is reached: the last result and ``None``.
             - ``failures`` is empty: ``(None, None)``.
    """
    last_result = None
    for failure in failures:
        limit_name = get_limit_name(failure)
        timed_limits = _get_total_limits(limit_name, limits_conf, failure)
        last_result = failure_log.check_total_action_limits(failure, start_time, timed_limits, project_id)
        if not last_result:
            return last_result, failure

    return last_result, None


def _get_total_limits(limit_name, limits_conf, failure) -> list[TimedLimit]:
    """Resolve and parse the timed limits for ``limit_name``.

    The sources are tried in order:
    1. A truthy override from ``app.settings().global_timed_limits_overrides``, wrapped into a
       single-element list.
    2. ``limits_conf[limit_name]``.
    3. ``projects.get_default_automation_limits(failure)``, used when the key is missing from
       ``limits_conf``; an info message is logged.
    """
    override_doc = app.settings().global_timed_limits_overrides.get(limit_name)
    if override_doc:
        return parse_timed_limits([override_doc.to_mongo()])

    try:
        configured_limits = limits_conf[limit_name]
    except KeyError:
        log.info("Limits configuration '%s' for failure %s is missing, using defaults", limit_name, failure)
        configured_limits = projects.get_default_automation_limits(failure)

    return parse_timed_limits(configured_limits)
