"""Set up Juggler to provide all necessary info for Wall-E: aggregates, transport, etc."""

import contextlib
import copy
import dataclasses
import logging
import re
import typing as tp
from collections import defaultdict, namedtuple

import xxhash
from juggler_sdk import JugglerApi, Check, Child, FlapOptions, NotificationOptions

from sepelib.core import config
from walle.expert.constants import JUGGLER_AGGREGATES_SYNC_PERIOD
from walle.expert.decisionmakers import get_decision_maker
from walle.expert.types import CheckType
from walle.hosts import Host, HostLocation, HostState, HostStatus
from walle.projects import Project
from walle.stats import stats_manager as stats
from walle.util import mongo
from walle.util.gevent_tools import gevent_idle_iter
from walle.util.juggler import get_stand_uid, get_aggregate_name
from walle.util.misc import StopWatch, add_interval_job

# Module-level logger shared by every function in this file.
log = logging.getLogger(__name__)

# A rack location together with the id of the shard it was assigned to.
_Location = namedtuple("Location", ("queue", "rack", "shard_id"))

# One (project, queue, rack) combination read from the hosts collection.
Group = namedtuple("Group", ("project", "queue", "rack"))


@dataclasses.dataclass
class GroupFilter:
    """Selects hosts of the listed projects that sit in one queue/rack pair."""

    queue: str
    rack: str
    prj: list[str]

    def to_juggler_request(self) -> str:
        """Render the filter as an ``&``-joined string of ``key=value`` pairs.

        Pairs come in a fixed order: ``queue``, ``rack``, one ``prj`` pair per
        project id, then the sorted items of the ``juggler.group_filters``
        config value.
        """
        configured_pairs = sorted(config.get_value("juggler.group_filters", {}).items())

        pairs = [("queue", self.queue), ("rack", self.rack)]
        pairs.extend(("prj", project_id) for project_id in self.prj)
        pairs.extend(configured_pairs)

        return "&".join([f"{key}={value}" for key, value in pairs])

    def get_infiniband_hosts(self) -> tp.Iterable[Host]:
        """Yield hosts matching this filter that have ``infiniband_info`` set.

        Only the ``name`` and ``infiniband_info.ports`` fields are requested.
        """
        matching_hosts = Host.objects(
            project__in=self.prj,
            location__short_queue_name=self.queue,
            location__rack=self.rack,
            infiniband_info__exists=True,
        )
        yield from matching_hosts.only("name", "infiniband_info.ports")


def start(scheduler, partitioner, shards_num):
    """Start the partitioner and schedule the periodic Juggler aggregates sync.

    :param scheduler: scheduler handed to ``add_interval_job``.
    :param partitioner: object providing ``start()`` and ``get_numeric_shards()``.
    :param shards_num: forwarded to ``partitioner.get_numeric_shards()`` on every run.
    """
    partitioner.start()

    def _job():
        # Shards are re-read on every run, outside the try block, so that the
        # failure message below can always name them.
        shards = partitioner.get_numeric_shards(shards_num)
        try:
            _reinforce_aggregates(shards)
        except Exception:
            # Exactly one placeholder for one argument: a surplus "%s" makes the
            # logging module fail to format the record and the message is lost.
            # The traceback is appended by log.exception() itself.
            log.exception("Failed to create Juggler checks for #%s shards.", _shards_str(shards))

    add_interval_job(scheduler, _job, name="Juggler aggregates", interval=JUGGLER_AGGREGATES_SYNC_PERIOD)


def _reinforce_aggregates(shards: list[mongo.MongoPartitionerShard]):
    """Upsert every configured Juggler check for ``shards`` and run the client's cleanup.

    Each shard's ``lock`` is entered as a context manager and held while checks
    are upserted and cleaned up. An upsert counter and timing samples are sent
    to stats. Exceptions raised inside the locked section (or by the stats calls
    that follow it) are logged and not re-raised; creating the client and the
    marks happens before that and is not guarded.
    """
    shards_str = _shards_str(shards)
    log.info("Sync Juggler checks for #%s shard...", shards_str)
    stopwatch = StopWatch()
    client = _get_juggler_client()
    marks = _get_marks(shards)
    metric_prefix = ("juggler", "upsert_checks")

    try:
        with contextlib.ExitStack() as held_locks:
            for shard in shards:
                held_locks.enter_context(shard.lock)

            for check in _configured_project_checks(shards, marks):
                child_hosts = "|".join([child.host for child in check.children])
                log.debug("Upserting juggler check: %s:%s/[%s]", check.host, check.service, child_hosts)
                client.upsert_check(check)

                stats.increment_counter(metric_prefix + ("upserted",))
                stats.add_sample(metric_prefix + ("upsert_time",), stopwatch.split())

            stats.add_sample(metric_prefix + ("upsert_total_time",), stopwatch.get())

            client.cleanup(marks.values())

        # Reported after the shard locks have been released.
        stats.add_sample(metric_prefix + ("cleanup_time",), stopwatch.split())
        stats.add_sample(metric_prefix + ("total_time",), stopwatch.get())
    except Exception:
        log.exception("Juggler aggregates sync for shards #%s failed.", shards_str)
    else:
        log.info("Finished syncing Juggler checks for #%s shards.", shards_str)


def _fetch_project_locations_from_mongodb():
    """Return the set of distinct ``Group(project, queue, rack)`` found among hosts.

    Only hosts that have both location fields, are not in the FREE state and
    are not in the INVALID status are considered.
    """
    location_prefix = Host.location.db_field + "."
    queue_field = location_prefix + HostLocation.short_queue_name.db_field
    rack_field = location_prefix + HostLocation.rack.db_field

    host_collection = mongo.MongoDocument.for_model(Host)

    host_filter = {
        queue_field: {"$exists": True},
        rack_field: {"$exists": True},
        # Juggler doesn't want us to create aggregates for projects consisting of hosts they don't monitor
        Host.state.db_field: {"$ne": HostState.FREE},
        Host.status.db_field: {"$ne": HostStatus.INVALID},
    }
    projection = {
        "_id": False,
        "project": True,
        "location.short_queue_name": True,
        "location.rack": True,
    }
    hosts = host_collection.find(host_filter, projection, read_preference=mongo.SECONDARY_LOCAL_DC_PREFERRED)

    groups = set()
    for host in hosts:
        groups.add(Group(host.project, host.location.short_queue_name, host.location.rack))
    return groups


def _get_projects_and_locations(shards):
    """Map each project id to the locations it has hosts in, limited to ``shards``.

    A location's shard id is the xxh64 hash of ``"<queue>-<rack>"`` modulo the
    ``juggler.shards_num`` config value, rendered as a string. Locations that
    fall into other shards are skipped; locations that fail validation are
    logged and skipped.
    """
    owned_shard_ids = {shard.id for shard in shards}
    shards_total = config.get_value("juggler.shards_num")

    locations_by_project = defaultdict(list)
    for group in _fetch_project_locations_from_mongodb():
        project_id, queue, rack = group.project, group.queue, group.rack

        location_key = f"{queue}-{rack}".encode()
        shard_id = str(xxhash.xxh64_intdigest(location_key) % shards_total)
        if shard_id in owned_shard_ids:
            location = _Location(queue, rack, shard_id)

            fail_reason = _validate_location(location)
            if fail_reason is None:
                locations_by_project[project_id].append(location)
            else:
                log.error("Location %s did not pass validation: %s", location, fail_reason)

    return locations_by_project


def _validate_location(location):
    # validate rack name
    rack_re = r"^(?! )[a-zA-Z0-9 ._-]+(?<! )$"
    if not re.match(rack_re, location.rack):
        return "rack name {} does not match regexp '{}'".format(location.rack, rack_re)


def _configured_project_checks(shards: list[mongo.MongoPartitionerShard], marks):
    """Yield a juggler_sdk.Check for every check enabled for projects with hosts in ``shards``.

    Each check aggregates the hosts of one rack. A check may be enabled for only
    part of a rack's hosts, so the aggregate is also narrowed to the projects
    that enable it. ``marks`` maps a shard id to the mark put on its checks.
    Locations for which the generator produces nothing are skipped.
    """
    locations_by_project = _get_projects_and_locations(shards)
    projects_by_check = _get_enabled_checks(list(locations_by_project))

    for check_name, check_projects in projects_by_check.items():
        # create aggregates for projects
        projects_by_location = _get_group_locations(check_projects, locations_by_project)
        make_check = _get_juggler_check_generator(get_stand_uid(), check_name)

        for location, location_projects in projects_by_location.items():
            aggregate_name, group_filter = _get_aggregate_params(location, location_projects)
            check = make_check(aggregate_name, group_filter, marks[location.shard_id])
            if not check:
                continue
            yield check


def _get_juggler_check_generator(stand_uid, check_name):
    """Build a per-location factory of juggler_sdk.Check objects for ``check_name``.

    The arguments shared by all locations are merged once from: the common
    settings (stand uid tag, notifications, service name), the check's config
    (``juggler.checks.<check_name>``, falling back to
    ``juggler.checks._generic_passive``) and the mixins named in that config.
    A merged ``flap`` entry is replaced by a ``flaps_config`` FlapOptions object.
    """
    mixins = config.get_value("juggler.mixins")
    generic_passive = config.get_value("juggler.checks._generic_passive")
    # Shallow copy, so that popping "mixins" below does not touch the config value.
    check_config = config.get_value(f"juggler.checks.{check_name}", default=generic_passive).copy()

    common = {
        "tags": [stand_uid],
        "notifications": _notification_settings(),
        "service": check_name,
    }
    mixin_names = check_config.pop("mixins", [])
    mixin_configs = [mixins[mixin_name] for mixin_name in mixin_names]
    common_check_args = _merge(common, check_config, *mixin_configs)

    if "flap" in common_check_args:
        flap_settings = common_check_args.pop("flap")
        common_check_args["flaps_config"] = FlapOptions(**flap_settings)

    # per-location check generator
    def _juggler_check(aggregate_name, group_filter: GroupFilter, mark) -> tp.Optional[Check]:
        children = _children_groups(group_filter, check_name)
        if children:
            location_args = {
                "mark": mark,
                "host": aggregate_name,
                "children": children,
            }
            return Check(**_merge(location_args, common_check_args))
        # No children for this location: there is nothing to aggregate.
        return None

    return _juggler_check


def _get_group_locations(projects, locations):
    """
    Get list of projects and map of project hosts' locations and return a map of projects per location.
    """

    group_locations = defaultdict(list)
    for project_id in projects:
        for location in locations[project_id]:
            group_locations[location].append(project_id)

    return group_locations


def _get_aggregate_params(location, projects):
    """Return ``(aggregate_name, group_filter)`` for a location and its projects.

    The filter covers the given projects plus those listed in the
    ``juggler.group_force_projects`` config value, deduplicated and sorted.
    """
    # extra_projects are mainly for prestable installation:
    # add these projects to all groups to make rack check more reliable
    # and to make WALLE@PROD groups work as expected (project id might differ between Wall-E prod and prestable)
    extra_projects = config.get_value("juggler.group_force_projects", [])

    filter_projects = sorted(set(projects + extra_projects))
    group_filter = GroupFilter(location.queue, location.rack, filter_projects)

    return get_aggregate_name(location.queue, location.rack), group_filter


def _get_enabled_checks(project_ids):
    """Map every check enabled for at least one of the projects to the ids of
    the projects that enable it.
    """
    # We don't know what checks projects use, but we have to create them in Juggler.
    # We only know what checks we don't need to create in Juggler, these are them.
    all_non_juggler = set(CheckType.ALL) - set(CheckType.ALL_JUGGLER)

    projects_by_check = defaultdict(list)
    for project in gevent_idle_iter(_fetch_projects(project_ids)):
        juggler_checks = get_decision_maker(project).checks_to_configure() - all_non_juggler
        for check in juggler_checks:
            projects_by_check[check].append(project.id)

    return projects_by_check


def _fetch_projects(project_ids):
    """Query the Project documents whose id is one of ``project_ids``."""
    projects_query = Project.objects(id__in=project_ids)
    return projects_query


def _get_juggler_client():
    """Create a JugglerApi client from the ``juggler.client_kwargs`` config value.

    When the configured kwargs carry no (truthy) ``mark``, the stand uid is used,
    i.e. the same uuid that is used as a source for downtimes.

    :return: a configured ``JugglerApi`` instance.
    """
    # Work on a copy: the default mark must not be written back into the value
    # owned by the config (presumably shared between callers - this file already
    # copies a config value before modifying it when building check arguments).
    api_kwargs = dict(config.get_value("juggler.client_kwargs"))
    if not api_kwargs.get("mark"):
        api_kwargs["mark"] = get_stand_uid()  # default to the same uuid as used as a source for downtimes.

    return JugglerApi(**api_kwargs)


def _get_marks(shards):
    """Map each shard id to its mark, ``"<stand uid>_<shard id>"``."""
    stand_uid = get_stand_uid()

    marks = {}
    for shard in shards:
        marks[shard.id] = f"{stand_uid}_{shard.id}"
    return marks


def _shards_str(shards):
    return ",#".join(map(str, shards))


def _notification_settings():
    """Return a one-element list with NotificationOptions built from the
    ``juggler.notifications.push`` config value.
    """
    push_settings = config.get_value("juggler.notifications.push")
    return [NotificationOptions(**push_settings)]


def _children_groups(group_filter: GroupFilter, service):
    """Return the list of juggler_sdk.Child objects for an aggregate of ``service``.

    For every service except IB_LINK this is a single child addressing the
    configured Juggler group narrowed by ``group_filter``. For IB_LINK there is
    one child per infiniband port of every matching host, so the list may be
    empty.
    """
    group_type = config.get_value("juggler.group_type")
    group_name = config.get_value("juggler.group_name")

    if service != CheckType.IB_LINK:
        # TODO(rocco66): why should we use that special `WALLE` juggler group filter here?
        #                why do not materialize all hosts here?
        group_spec = group_name + "@" + group_filter.to_juggler_request()
        return [Child(group_type=group_type, host=group_spec, service=service)]

    return [
        Child(host=host.name, service=service, instance=port)
        for host in group_filter.get_infiniband_hosts()
        for port in host.infiniband_info.ports
    ]


def _merge(*dicts):
    """Build a new dict by folding every given dict into it, left to right.

    Merging starts from an empty dict and each input is applied with
    ``_merge_two``, so when keys conflict later dicts take part last.
    """
    merged = {}
    for mapping in dicts:
        _merge_two(merged, mapping)
    return merged


def _merge_two(destination, source):
    """Take two dicts and merge all keys of the source dict into the destination dict, recursively."""
    for key in source:
        if key in destination:
            if destination[key] == source[key]:
                pass
            elif isinstance(destination[key], dict) and isinstance(source[key], dict):
                _merge_two(destination[key], source[key])
            elif isinstance(destination[key], list) and isinstance(source[key], list):
                destination[key].extend(source[key])
            # don't merge other types, tuple included, just overwrite.
            elif isinstance(source[key], dict) or isinstance(source[key], list):
                destination[key] = copy.copy(source[key])
            else:
                destination[key] = source[key]
        elif isinstance(source[key], dict) or isinstance(source[key], list):
            destination[key] = copy.copy(source[key])
        else:
            destination[key] = source[key]
