import collections
import itertools
from functools import partial

from walle.clients import abc
from walle.hosts import Host
from walle.scenario.common import get_inv_to_tags_map
from walle.scenario.constants import (
    HostScenarioStatus,
    SchedulerName,
    FixedMaintenanceApproversLogins,
    MaintenanceGroupProjectTags,
)
from walle.scenario.data_storage.types import MaintenanceApproversGroup
from walle.scenario.scenario import ScenarioHostState
from walle.scenario.utils import BaseRegistry


class SchedulerRegistry(BaseRegistry):
    """Registry of host scheduler classes, keyed by ``SchedulerName``.

    Scheduler classes add themselves with ``@SchedulerRegistry.register(<name>)``.
    """

    # Declared on this subclass so it has its own mapping rather than sharing one
    # with other BaseRegistry subclasses — presumably BaseRegistry.register stores
    # entries here; verify against BaseRegistry.
    ITEMS = {}


@SchedulerRegistry.register(SchedulerName.DATACENTER)
class DatacenterScheduler:
    """Scheduler that puts all hosts of one datacenter into the same group."""

    @staticmethod
    def schedule(hosts_info):
        """Assign every host a group id shared by the hosts of its datacenter.

        Group ids are consecutive integers starting at 0. They are handed out in the
        order in which datacenters first appear in the ``Host`` query result.

        :param hosts_info: mapping whose values are host states with an ``inv`` attribute.
        :return: the same ``hosts_info`` mapping, with each host's group set and its
            status switched to ``HostScenarioStatus.QUEUE`` (via ``set_group``).
        :raises KeyError: if some host in ``hosts_info`` is not returned by the ``Host`` query.
        """
        invs = [state.inv for state in hosts_info.values()]
        query = Host.objects(inv__in=invs).only("inv", "location.datacenter")

        dc_group_ids = {}  # datacenter -> group id, in first-seen order
        group_by_inv = {}
        for host in query:
            datacenter = host.location.datacenter
            if datacenter not in dc_group_ids:
                dc_group_ids[datacenter] = len(dc_group_ids)
            group_by_inv[host.inv] = dc_group_ids[datacenter]

        for state in hosts_info.values():
            set_group(state, group_by_inv[state.inv])

        return hosts_info


@SchedulerRegistry.register(SchedulerName.ALL)
class AllHostsScheduler:
    """Scheduler that puts every host into one single group."""

    @staticmethod
    def schedule(hosts_info):
        """Assign group 0 to every host.

        :param hosts_info: mapping whose values are host states.
        :return: the same ``hosts_info`` mapping, with each host's group set to 0 and
            its status switched to ``HostScenarioStatus.QUEUE`` (via ``set_group``).
        """
        only_group = 0
        for state in hosts_info.values():
            set_group(state, only_group)
        return hosts_info


@SchedulerRegistry.register(SchedulerName.ONLY_PREVIOUSLY_ACQUIRED)
class AllPreviouslyAcquiredHostsScheduler:
    """Scheduler that separates previously acquired hosts from all the others."""

    @staticmethod
    def schedule(hosts_info):
        """Assign group 0 to acquired hosts and group -1 to the rest.

        Every host, including those placed in group -1, still gets its status set
        to ``HostScenarioStatus.QUEUE`` by ``set_group``.

        :param hosts_info: mapping whose values are host states with an ``is_acquired`` attribute.
        :return: the same ``hosts_info`` mapping, updated in place.
        """
        for state in hosts_info.values():
            # NOTE(review): -1 presumably means "not part of any real group";
            # confirm with whatever consumes the group ids.
            set_group(state, 0 if state.is_acquired else -1)
        return hosts_info


class MaintenanceApproversSchedulerException(Exception):
    """Raised by MaintenanceApproversScheduler when approvers' logins cannot be obtained from ABC."""


@SchedulerRegistry.register(SchedulerName.MAINTENANCE_APPROVERS)
class MaintenanceApproversScheduler:
    """Scheduler that groups hosts by the maintenance approvers responsible for them.

    A host goes to the group of the first known project tag (in
    ``MaintenanceGroupProjectTags.ALL`` order) that it carries. A host with no known
    tag goes to a shared default group named "other". Groups are created lazily, get
    consecutive ids starting at 0, and are written to ``data_storage`` by ``schedule``.
    """

    def __init__(self, hosts_info, data_storage):
        """
        :param hosts_info: mapping whose values are host states with an ``inv`` attribute.
        :param data_storage: object whose ``write`` method receives the list of created
            ``MaintenanceApproversGroup`` objects.
        """
        self._hosts_info = hosts_info
        self._data_storage = data_storage

        self._inv_tags_map = get_inv_to_tags_map([host.inv for host in hosts_info.values()])

        self._groups = []  # MaintenanceApproversGroup objects, in creation (= id) order
        self._next_group_id = 0

        self._default_approvers_group_id = None  # created on first host without a known tag
        self._groups_ids_created_from_known_project_tag = {}  # {"known_tag": group_id} map

    def schedule(self):
        """Assign every host to its approvers group, then persist the created groups.

        :return: the same ``hosts_info`` mapping, with each host's group set and its
            status switched to ``HostScenarioStatus.QUEUE`` (via ``set_group``).
        """
        for host_state in self._hosts_info.values():
            set_group(host_state, self._get_host_group_id(host_state.inv))
        self._data_storage.write(self._groups)
        return self._hosts_info

    def _get_host_group_id(self, inv: int) -> int:
        """Return the group id for the host ``inv``, creating the group if needed."""
        # WALLE-4081 Order of known tags matters: the first tag from
        # MaintenanceGroupProjectTags.ALL that the host carries decides its group.
        for tag in MaintenanceGroupProjectTags.ALL:
            if tag in self._inv_tags_map[inv]:
                return self._get_approvers_group_id_for_known_project_tag(tag)

        return self._get_default_approvers_group_id()

    def _get_approvers_group_id_for_known_project_tag(self, tag: str) -> int:
        """Return the id of the group for known project ``tag``, creating it on first use."""
        if tag not in self._groups_ids_created_from_known_project_tag:
            logins = self._get_approvers_logins_by_known_project_tag(tag)
            self._groups_ids_created_from_known_project_tag[tag] = self._add_group(tag, logins)
        return self._groups_ids_created_from_known_project_tag[tag]

    def _get_default_approvers_group_id(self) -> int:
        """Return the id of the default ("other") group, creating it on first use."""
        if self._default_approvers_group_id is None:
            self._default_approvers_group_id = self._add_group(
                "other", FixedMaintenanceApproversLogins.DEFAULT_MAINTENANCE_APPROVERS_LOGINS
            )
        return self._default_approvers_group_id

    def _add_group(self, name: str, approvers_logins) -> int:
        """Create a group with the next free id, queue it for ``schedule`` to write, and return its id."""
        group_id = self._next_group_id
        self._groups.append(MaintenanceApproversGroup(group_id, name, approvers_logins))
        self._next_group_id += 1
        return group_id

    def _get_approvers_logins_by_known_project_tag(self, tag):
        """Return the approvers' logins for known project ``tag``.

        Returns ``None`` for a tag that has no entry in the mapping below.

        :raises MaintenanceApproversSchedulerException: if fetching logins from ABC fails.
            This is only reachable once the ABC lookup is re-enabled (WALLE-4021).
        """
        if tag == MaintenanceGroupProjectTags.YP_PROJECT_TAG:
            try:
                # WALLE-4021: the intended source is
                #   abc.get_service_on_duty_logins(<YP ABC service>)
                # TODO: the earlier draft passed self._YP_PROJECT_TAG, which this class
                # does not define; pick the correct ABC service identifier when re-enabling.
                return FixedMaintenanceApproversLogins.DEFAULT_MAINTENANCE_APPROVERS_LOGINS
            except abc.ABCInternalError as e:
                raise MaintenanceApproversSchedulerException("Error getting approvers logins from ABC: %s" % e) from e

        return {
            MaintenanceGroupProjectTags.YABS_PROJECT_TAG: FixedMaintenanceApproversLogins.YABS_MAINTENANCE_APPROVERS_LOGINS,
            MaintenanceGroupProjectTags.YT_PROJECT_TAG: FixedMaintenanceApproversLogins.YT_MAINTENANCE_APPROVERS_LOGINS,
            MaintenanceGroupProjectTags.YT_MASTERS_PROJECT_TAG: FixedMaintenanceApproversLogins.YT_MAINTENANCE_APPROVERS_LOGINS,
        }.get(tag)


def set_group(host_info: ScenarioHostState, group: int) -> None:
    """Put a host into scheduling group ``group`` and mark it as queued.

    :param host_info: host state, modified in place.
    :param group: group id to assign.
    """
    host_info.group = group
    # Every scheduled host enters the queue, whatever group it was given.
    host_info.status = HostScenarioStatus.QUEUE
