"""Collection of classes that represent error reports."""

import datetime
import logging
import re
from collections import defaultdict, Counter

import walle.util.host_health
from sepelib.core import constants
from walle import audit_log
from walle.constants import CMS_REJECT_RE
from walle.expert import juggler
from walle.expert.decisionmakers import load_decision_makers
from walle.expert.types import CheckType, CheckStatus
from walle.failure_reports.base import ReportSection, ReportedHost, GroupingSectionFormatter
from walle.hbf_drills import drills_cache
from walle.hosts import HostStatus, Host, HostState, HostLocation, HostOperationState
from walle.models import timestamp
from walle.stages import Stage, get_by_uid
from walle.util.mongo import SECONDARY_LOCAL_DC_PREFERRED
from walle.util.workdays import to_timestamp
from walle.util.host_health import get_failure_reason, get_failure_reason_deprecated

# Hosts in maintenance whose `state_time` is older than this (8 weeks, roughly two months)
# and that have no state expiration are reported as stale by OldMaintenanceTicketHosts.
MAINTENANCE_STALE_TIME = 8 * constants.WEEK_SECONDS

log = logging.getLogger(__name__)


class ReportContents:
    """A list of report sections plus the hosts that dropped out of every section since the previous run."""

    def __init__(self, formatter):
        self.formatter = formatter
        # sections, in the order they were added
        self._report_sections = []
        # inv -> host from the previous run that no current section reports anymore
        self._lost_hosts = {}

    def add_section(self, report_section):
        """Append a section to the report contents.

        :type report_section: ReportSection
        """
        self._report_sections.append(report_section)

    def text(self):
        """Render all sections using the report formatter."""
        return self.formatter.format(self._report_sections)

    def report_hosts(self):
        """Return every host of this report: problem and solved hosts of each section, then lost hosts.

        Solved hosts have "solved":True field.
        """
        collected = []
        for section in self._report_sections:
            collected.extend(section.problem_hosts.values())
            collected.extend(section.solved_hosts.values())
        collected.extend(self._lost_hosts.values())
        return collected

    def merge_hosts(self, previous_hosts):
        """Merge hosts reported by the previous run into the current report.

        Each section merges `previous_hosts` on its own. Previous hosts that no section knows about
        anymore are marked as solved (with the current timestamp) and kept as lost hosts.

        :param previous_hosts - dict of inv -> host data from previous run
        :type previous_hosts: dict
        """
        reported_invs = set()
        for section in self._report_sections:
            section.merge_hosts(previous_hosts)
            reported_invs.update(section.problem_hosts, section.solved_hosts)

        for inv in previous_hosts.keys() - reported_invs:
            lost_host = previous_hosts[inv]
            lost_host.solved = True
            lost_host.solve_timestamp = timestamp()
            self._lost_hosts[inv] = lost_host

    def empty(self):
        """Return True if no section of the report has problem hosts."""
        for section in self._report_sections:
            if section.has_problem_hosts():
                return False
        return True


class SectionDataCollector:
    """Base class for collectors that build a single report section from host documents.

    Subclasses implement :meth:`gather_data`.
    """

    def gather_data(self, projects):
        """Build and return the ReportSection for the given projects."""
        raise NotImplementedError

    @classmethod
    def mk_section(cls, title):
        """Create an empty report section named after the collector class."""
        formatter = GroupingSectionFormatter(title)
        return ReportSection(name=cls.__name__, formatter=formatter)

    @staticmethod
    def init_reported_host(inv, host_data, reason):
        """Build a ReportedHost from a raw host document.

        Tickets are taken from the host's own ``ticket`` field and from the ``tickets`` data of the
        current task stage. Empty values are dropped, and ``None`` is used when no ticket remains.
        """
        tickets = [host_data.get("ticket")]
        tickets.extend(_get_stage_tickets(host_data))
        tickets = list(filter(None, tickets)) or None

        return ReportedHost(
            inv=inv,
            name=host_data["name"],
            host_uuid=host_data["_id"],
            status=host_data["status"],
            project=host_data["project"],
            tickets=tickets,
            reason=reason,
            report_timestamp=timestamp(),
        )

    @staticmethod
    def hosts_cursor(projects, query, fields):
        """Return a cursor over hosts of `projects` that match `query`.

        Neither `query` nor `fields` is modified. The project filter is added to a copy of the query,
        and the projection is extended with the fields ``init_reported_host`` needs: ``ticket`` and
        the current task stage data.

        :param projects: projects whose hosts are searched (their ``id`` is used).
        :param query: mongo filter (``None`` means no extra filter). A ``project`` key is overridden.
        :param fields: iterable of field names for the projection.
        """
        collection = Host.get_collection()
        # Work on a copy: subclasses pass shared class-level dicts here, and mutating them would leak
        # state between calls and instances.
        query = dict(query or {}, project={"$in": [p.id for p in projects]})

        # Normalize to a tuple so a caller's list is never extended in place.
        fields = tuple(fields)
        fields_set = set(fields)
        if "ticket" not in fields_set:
            fields += ("ticket",)

        # Skip sub-fields when the whole "task" is already projected (overlapping paths are not allowed).
        if not ("task" in fields_set or {"task.stages", "task.stage_uid"} <= fields_set):
            fields += ("task.stages", "task.stage_uid")

        return collection.find(query, fields)


class TopProblematicHosts(SectionDataCollector):
    """Hosts on which Wall-E performed the same kind of automated action too many times during the last week.

    A host is reported when its most frequent automated action type occurred more than
    `n_problems_a_week` times.
    """

    # Automated actions (host inv -> list of action types), shared by all instances. The cache is
    # tagged with the day it was built for, so a long-running process refreshes it once the
    # reporting window moves to the next day instead of serving the first day's data forever.
    __automated_actions_cache = None
    __automated_actions_cache_day = None

    def __init__(self, n_problems_a_week=3):
        self.problematic_if_n_problems = n_problems_a_week
        super().__init__()

    @classmethod
    def find_hosts(cls, projects):
        """Return a cursor over non-maintenance hosts that have no task or a task owned by Wall-E."""
        fields = ("inv", "name", "project", "status", "task.owner")
        query = {
            "state": {"$ne": HostState.MAINTENANCE},
            "$or": [{"task": {"$exists": False}}, {"task.owner": "wall-e"}],
        }

        return cls.hosts_cursor(projects, query, fields)

    def gather_data(self, projects):
        """Build the section with hosts whose most frequent automated action exceeds the threshold."""
        automated_actions = self.__get_automated_actions()
        section = self.mk_section("has tried to repair the following hosts too many times for the last 7 days")

        for host in self.find_hosts(projects):
            # Use .get() rather than indexing: the cache is a defaultdict, and indexing would insert an
            # empty list into the shared cache for every host without actions.
            processed_actions = Counter(automated_actions.get(host["inv"], ())).most_common()

            if processed_actions and processed_actions[0][1] > self.problematic_if_n_problems:
                reason = self.host_failure_reason(processed_actions, section.formatter)
                problem_host = self.init_reported_host(host.pop("inv"), host, reason=reason)
                section.add_problem_host(problem_host)

        return section

    @staticmethod
    def host_failure_reason(actions, formatter):
        """Format (action, count) pairs as "action: N times", ordered by action name."""
        actions = sorted(actions, key=lambda x: x[0])
        reason = ", ".join("{}: {} times".format(action, count) for action, count in actions)

        return formatter.host_failure_reason(reason)

    @classmethod
    def __get_automated_actions(cls):
        """Return the cached host inv -> automated action types mapping, rebuilding it once per day."""
        now = datetime.datetime.now()
        today = now.date()

        if cls.__automated_actions_cache is None or cls.__automated_actions_cache_day != today:
            # fetch all actions from previous week till today 7:00 am.
            # Time interval is fixed because we do not want list of hosts to change between runs.

            actions = defaultdict(list)
            collection = audit_log.LogEntry.get_collection(read_preference=SECONDARY_LOCAL_DC_PREFERRED)
            query = {
                "issuer": "wall-e",
                "type": {"$in": audit_log.AUTOMATION_TYPES},
                "time": {
                    "$gte": to_timestamp(today - datetime.timedelta(days=7)),
                    "$lte": to_timestamp(now.replace(hour=7, minute=0, second=0, microsecond=0)),
                },
            }

            for entry in collection.find(query, ("host_inv", "type")):
                actions[entry["host_inv"]].append(entry["type"])

            cls.__automated_actions_cache = actions
            cls.__automated_actions_cache_day = today

        return cls.__automated_actions_cache


class InfrastructureProblems(SectionDataCollector):
    """Report with hosts that have problems with rack or switch, and are also unreachable."""

    # NB: Wall-E is learning to report rack failures to ITDC, so this report is going to be obsolete in a short time.

    # project id -> decision maker, shared by all instances.
    # NOTE(review): entries are only ever added, never refreshed, so a changed decision maker for an
    # already cached project is not picked up until the process restarts. Confirm this is acceptable.
    __decision_makers_cache = {}

    @classmethod
    def find_hosts(cls, projects):
        """Return a generator over ready hosts with a failed or suspected infrastructure check.

        Both the current and the deprecated failure reason formats are matched. Hosts involved in
        HBF drills are excluded.
        """
        reasons = [
            walle.util.host_health.get_failure_reason(check_type, check_status)
            for check_type in CheckType.ALL_INFRASTRUCTURE
            for check_status in (CheckStatus.FAILED, CheckStatus.SUSPECTED)
        ]

        reasons += [
            walle.util.host_health.get_failure_reason_deprecated(check_type, check_status)
            for check_type in CheckType.ALL_INFRASTRUCTURE
            for check_status in (CheckStatus.FAILED, CheckStatus.SUSPECTED)
        ]

        fields = ("inv", "name", "status", "project", "health.reasons", "ips", "location")
        query = {
            "status": HostStatus.READY,
            "health.reasons": {"$in": reasons},
        }
        # exclude hosts involved in HBF drills
        drills_col = drills_cache.get()
        hosts = (
            host for host in cls.hosts_cursor(projects, query, fields) if not cls._host_is_in_drill(host, drills_col)
        )
        return hosts

    def gather_data(self, projects):
        """Build the section with fully unavailable hosts that have a rack or switch failure."""
        section = self.mk_section("has hosts with rack and/or switch problems")

        # Loop invariants, computed once instead of once per host.
        # A host is "not available at all" when both its ssh and unreachable checks failed, with both
        # reasons in the current format or both in the deprecated format.
        unavailable_reason_sets = (
            {get_failure_reason("ssh", "failed"), get_failure_reason("unreachable", "failed")},
            {
                get_failure_reason_deprecated("ssh", "failed"),
                get_failure_reason_deprecated("unreachable", "failed"),
            },
        )
        rack_failure_reasons = {
            get_failure_reason("rack", "failed"),
            get_failure_reason_deprecated("rack", "failed"),
        }

        hosts = list(self.find_hosts(projects))
        host_health = self.__load_host_health(projects, hosts)
        for host in hosts:
            reasons = set(host["health"]["reasons"])
            # only include hosts that are not available at all
            if not any(required <= reasons for required in unavailable_reason_sets):
                continue
            # keep only rack and switch failures, skip dc and queue failures
            if rack_failure_reasons & reasons or self._is_switch_failure(host_health[host["name"]]):
                reason = self.host_failure_reason(host["health"]["reasons"], section.formatter)
                problem_host = self.init_reported_host(host.pop("inv"), host, reason=reason)
                section.add_problem_host(problem_host)

        return section

    @staticmethod
    def host_failure_reason(failure_reasons, formatter):
        """Format every failure reason, highlighting switch and rack ones, and join them in sorted order."""
        format_reason = formatter.host_failure_reason
        formatted_failure_reasons = []

        for reason in failure_reasons:
            formatted = format_reason(reason)
            if reason.startswith(("switch", "rack")):
                # bold + red highlight
                formatted = "!!(red)**{}**!!".format(formatted)
            formatted_failure_reasons.append(formatted)

        return ", ".join(sorted(formatted_failure_reasons))

    @classmethod
    def __load_decision_makers(cls, projects):
        """Return the decision makers cache after loading entries for projects that are not cached yet."""
        projects_ids = {p.id for p in projects}
        missing_dm_pids = projects_ids - set(cls.__decision_makers_cache)

        if missing_dm_pids:
            cls.__decision_makers_cache.update(load_decision_makers(missing_dm_pids))

        return cls.__decision_makers_cache

    @classmethod
    def __load_host_health(cls, projects, hosts):
        """Query juggler health for `hosts`, passing each host name with its project's decision maker."""
        decision_makers = cls.__load_decision_makers(projects)
        return juggler.get_health_for_hosts({host["name"]: decision_makers[host["project"]] for host in hosts})

    @staticmethod
    def _is_switch_failure(host_health: juggler.HostHealth):
        """Return True if the netmon check metadata reports the switch as failed or suspected."""
        _, current_reasons, _, _ = host_health
        netmon_check = current_reasons[CheckType.NETMON]

        return netmon_check.get("metadata", {}).get("switch", {}).get("status") in {
            CheckStatus.FAILED,
            CheckStatus.SUSPECTED,
        }

    @staticmethod
    def _host_is_in_drill(host_dict, drills_col):
        """Return True if the HBF drills collection gives an inclusion reason for this host.

        The raw document is converted into a Host model (uuid taken from ``_id``, location wrapped in
        HostLocation), because the drills collection is queried with Host instances. The caller's dict
        is not modified.
        """
        host_dict = dict(host_dict)
        host_dict["uuid"] = host_dict.pop("_id")
        host_dict["location"] = HostLocation(**host_dict["location"])
        host = Host(**host_dict)

        drill_reason = drills_col.get_host_inclusion_reason(host)
        return bool(drill_reason)


class HostsByStatusCollector(SectionDataCollector):
    """Base collector for hosts selected by a status query and explained by their status audit log entry.

    Child classes must override ``section_title`` and ``hosts_query``, and may extend ``request_fields``.
    """

    # class variables should be overridden in child classes
    section_title = None

    request_fields = ("inv", "name", "project", "status", "status_audit_log_id")
    hosts_query = None

    def find_hosts(self, projects):
        """Return a list of host documents that match ``hosts_query`` in the given projects."""
        # Pass a copy: the class-level query dict is shared by all instances and must not be mutated.
        return list(self.hosts_cursor(projects, dict(self.hosts_query), self.request_fields))

    def gather_data(self, projects):
        """Build the section, using each host's status audit log entry as the failure reason."""
        section = self.mk_section(self.section_title)
        format_reason = section.formatter.host_failure_reason

        failed_hosts = self.find_hosts(projects)
        audit_log_ids = [host["status_audit_log_id"] for host in failed_hosts if "status_audit_log_id" in host]
        # No ids -> nothing to look up, so skip the audit log round-trip.
        reasons = _get_event_reasons(audit_log_ids, with_reason=True, with_error=True) if audit_log_ids else {}

        for host in failed_hosts:
            # Hosts without a status audit log id, or whose entry was not found, cannot be explained
            # and are left out of the report. The lookup is done with .get() so that a KeyError raised
            # by the formatter is not silently mistaken for a missing entry.
            event_reason = reasons.get(host.get("status_audit_log_id"))
            if event_reason is None:
                continue
            reason = format_reason(event_reason)
            section.add_problem_host(self.init_reported_host(host.pop("inv"), host, reason=reason))

        return section


class DeadHosts(HostsByStatusCollector):
    """Report with dead hosts (hosts that have tasks failed).

    For projects passed as `cms_reported_projects`, hosts whose ``status_reason`` matches
    ``CMS_REJECT_RE`` are left out. NOTE(review): presumably because CMS rejections for those
    projects are reported separately. Confirm.
    """

    request_fields = HostsByStatusCollector.request_fields + ("status_reason",)
    hosts_query = {"status": HostStatus.DEAD}
    section_title = "has dead hosts"

    def __init__(self, cms_reported_projects=None):
        super().__init__()
        self.cms_reported_project_ids = {p.id for p in cms_reported_projects} if cms_reported_projects else set()

    def find_hosts(self, projects):
        """Return dead hosts of `projects`, excluding CMS rejections of CMS-reported projects."""
        cms_re = re.compile(CMS_REJECT_RE)
        # Pass a copy: the class-level query dict is shared by all instances and must not be mutated.
        cursor = self.hosts_cursor(projects, dict(self.hosts_query), self.request_fields)

        hosts = []
        for host in cursor:
            # status_reason may be missing (and possibly null); re.match() needs a string.
            status_reason = host.get("status_reason") or ""
            if host["project"] in self.cms_reported_project_ids and cms_re.match(status_reason):
                continue
            hosts.append(host)

        return hosts


class InvalidHosts(HostsByStatusCollector):
    """Report hosts with status == invalid.

    Invalid hosts in maintenance are reported too, unless their operation state is decommissioned.
    """

    section_title = "has invalid hosts"
    hosts_query = {
        "$or": [
            # invalid hosts whose state is not maintenance
            dict(status=HostStatus.INVALID, state={"$ne": HostState.MAINTENANCE}),
            # invalid hosts in maintenance whose operation state is not decommissioned
            dict(
                status=HostStatus.INVALID,
                state=HostState.MAINTENANCE,
                operation_state={"$ne": HostOperationState.DECOMMISSIONED},
            ),
        ]
    }


class ErrorHosts(SectionDataCollector):
    """Report with hosts whose task has an error recorded in ``task.error``."""

    def find_hosts(self, projects):
        """Return a cursor over hosts of `projects` that have a ``task.error`` field."""
        # The result is somewhat flaky: errors may be temporary,
        # and they do not always persist between retries.
        return self.hosts_cursor(
            projects,
            {"task.error": {"$exists": True}},
            ("inv", "name", "project", "status", "task.error"),
        )

    def gather_data(self, projects):
        """Build the section, using each host's task error as its failure reason."""
        section = self.mk_section("has error hosts")
        describe = section.formatter.host_failure_reason

        for host_doc in self.find_hosts(projects):
            reason = describe(host_doc["task"]["error"])
            inv = host_doc.pop("inv")
            section.add_problem_host(self.init_reported_host(inv, host_doc, reason=reason))

        return section


class UnreachableHosts(SectionDataCollector):
    """Report with ready hosts whose health reasons include a failed ssh or unreachable check."""

    def find_hosts(self, projects):
        """Return a cursor over ready hosts failing ssh or unreachable checks (current or deprecated format)."""
        unreachable_reasons = [
            make_reason(check, "failed")
            for make_reason in (get_failure_reason, get_failure_reason_deprecated)
            for check in ("ssh", "unreachable")
        ]
        return self.hosts_cursor(
            projects,
            {"status": HostStatus.READY, "health.reasons": {"$in": unreachable_reasons}},
            ("inv", "name", "project", "status", "health.reasons"),
        )

    def gather_data(self, projects):
        """Build the section, listing every health reason of each unreachable host."""
        section = self.mk_section("has unreachable hosts")

        for host_doc in self.find_hosts(projects):
            reason = self.host_failure_reason(host_doc["health"]["reasons"], section.formatter)
            inv = host_doc.pop("inv")
            section.add_problem_host(self.init_reported_host(inv, host_doc, reason=reason))

        return section

    @staticmethod
    def host_failure_reason(reasons, formatter):
        """Return the formatted "host is unreachable" message with all given reasons."""
        joined = ", ".join(reasons)
        return formatter.host_failure_reason("The host is unreachable: {}.".format(joined))


class BrokenDnsHosts(SectionDataCollector):
    """Report with ready hosts that have messages from the DNS fixer."""

    def find_hosts(self, projects):
        """Return a cursor over ready hosts of `projects` with ``messages.dns_fixer`` set."""
        return self.hosts_cursor(
            projects,
            {"status": HostStatus.READY, "messages.dns_fixer": {"$exists": True}},
            ("inv", "name", "project", "status", "messages"),
        )

    def gather_data(self, projects):
        """Build the section, using the DNS fixer messages as each host's failure reason."""
        section = self.mk_section("has hosts with DNS problems")

        for host_doc in self.find_hosts(projects):
            dns_messages = [item["message"] for item in host_doc["messages"]["dns_fixer"]]
            reason = self.host_failure_reason(dns_messages, section.formatter)
            inv = host_doc.pop("inv")
            section.add_problem_host(self.init_reported_host(inv, host_doc, reason=reason))

        return section

    @staticmethod
    def host_failure_reason(reasons, formatter):
        """Return the formatted "host has DNS problems" message with all given messages."""
        joined = ", ".join(reasons)
        return formatter.host_failure_reason("The host has DNS problems: {}.".format(joined))


class OldMaintenanceTicketHosts(SectionDataCollector):
    """Report with hosts that have been in maintenance without a state expiration time for longer than
    MAINTENANCE_STALE_TIME while still having a ticket attached."""

    def find_hosts(self, projects):
        """Return a cursor over stale maintenance hosts of `projects` that have a ticket."""
        stale_before = timestamp() - MAINTENANCE_STALE_TIME
        query = {
            "state": HostState.MAINTENANCE,
            "state_time": {"$lt": stale_before},
            "state_expire.time": {"$exists": False},
            "ticket": {"$exists": True},
        }
        # NOTE(review): host documents appear to keep the uuid in "_id" (init_reported_host reads
        # host_data["_id"]), so projecting "uuid" may be a no-op. Confirm.
        return self.hosts_cursor(projects, query, ("inv", "name", "uuid", "project", "status", "ticket", "state_time"))

    def gather_data(self, projects):
        """Build the section with one entry per stale maintenance host."""
        section = self.mk_section("has old maintenance hosts")

        for host_doc in self.find_hosts(projects):
            reason = self.host_failure_reason(host_doc, section.formatter)
            inv = host_doc.pop("inv")
            section.add_problem_host(self.init_reported_host(inv, host_doc, reason=reason))

        return section

    @staticmethod
    def host_failure_reason(host, formatter):
        """Return the formatted message naming the ticket and the local time maintenance started."""
        ticket = host["ticket"]
        since = datetime.datetime.fromtimestamp(host["state_time"]).isoformat()
        reason = "The ticket {} is not being closed for too long (in maintenance since {}).".format(ticket, since)
        return formatter.host_failure_reason(reason)


def _get_event_reasons(entry_ids, with_reason=False, with_error=False):
    """
    Returns a human-friendly reason for some event that has been caused by an action described by the specified
    audit log entry.
    """

    fields = ["issuer", "type", "status"]
    if with_reason:
        fields.append("reason")
    if with_error:
        fields.append("error")

    collection = audit_log.LogEntry.get_collection(read_preference=SECONDARY_LOCAL_DC_PREFERRED)
    entry_list = collection.find({"_id": {"$in": entry_ids}}, fields)

    reasons = {}
    for entry in entry_list:
        issuer = "Wall-E" if entry["issuer"] == "wall-e" else entry["issuer"]
        reason = "{type} by {issuer} ({status})".format(issuer=issuer, type=entry["type"], status=entry["status"])
        if with_error and "error" in entry:
            reason += ": " + entry["error"]
        elif with_reason and "reason" in entry:
            reason += ": " + entry["reason"]
        reasons[entry["_id"]] = reason

    return reasons


def _get_stage_tickets(host):
    if "task" in host and host["task"].get("stage_uid") is not None:
        stages = [Stage(**stage) for stage in host.get("task", {}).get("stages", [])]
        current_stage = get_by_uid(stages, host["task"]["stage_uid"])
        return current_stage.get_data("tickets", [])

    return []
