import logging
from collections import defaultdict
from collections.abc import Collection

# Looped import problem
import walle.scenario as sc
from sepelib.core import config
from walle.errors import RecoverableError
from walle.maintenance_plot.model import MaintenancePlotModel
from walle.scenario.constants import (
    ScenarioFsmStatus,
    StageName,
    DATACENTER_LABEL_NAME,
    ScenarioWorkStatus,
    WORK_STATUS_LABEL_NAME,
)
from walle.scenario.data_storage.types import HostGroupSource
from walle.scenario.definitions.base import get_data_storage
from walle.scenario.host_groups_builders.constants import HostGroupsSourcesTypes
from walle.scenario.marker import Marker
from walle.scenario.mixins import Stage
from walle.scenario.scenario import Scenario, get_map_maintenance_plot_ids_to_hosts_count, ScriptName
from walle.scenario.stage_info import StageRegistry, StageInfo

log = logging.getLogger(__name__)

# Template for the approvement reason attached to a maintenance-plot host group source.
# Positional placeholders, in order: maintenance plot id, allow limit for the plot,
# active hosts count of the active scenarios, hosts count of the current scenario,
# and the sum of the last two.
_REASON = "".join(
    (
        "Maintenance plot id: {}. \n",
        "Current limit for active hosts in maintenance plot: {}. \n",
        "Current active hosts count: {}. \n",
        "Additional number of active hosts from current scenario: {}. \n",
        "Total number of active hosts with current scenario: {}",
    )
)


class _ScenarioLimitsExceededError(RecoverableError):
    """Starting the scenario now would exceed one of the configured limits.

    Raised by the ``AcquirePermission`` limit checks; ``AcquirePermission.run``
    catches it and keeps the stage in progress with the error text as message.
    """


@StageRegistry.register(StageName.AcquirePermission)
class AcquirePermission(Stage):
    """Gate the start of a scenario on configured concurrency limits.

    Unless ``urgently`` is set, the stage checks, in this order:

    * per-maintenance-plot active-host limits (only with
      ``with_check_limits_by_maintenance_plots``);
    * the overall number of started scenarios;
    * the number of started scenarios of the same type (only when
      ``total_started_scenarios_limit_name`` is set);
    * the number of scenarios started by the same issuer;
    * the number of started scenarios of the same type in the same datacenter
      (only when ``dc_started_scenarios_limit`` is set and the scenario carries
      a datacenter label).

    While any limit is exceeded the stage stays in progress; once all checks
    pass, the scenario's work status label is set to ``STARTED``.
    """

    # Limit names; the actual values are read from the ``scenario.<name>`` config keys.
    _total_user_limit_name = "max_started_scenarios_per_user"
    _total_started_scenarios_limit_name = "max_started_scenarios"

    def __init__(
        self,
        total_started_scenarios_limit_name=None,
        dc_started_scenarios_limit=None,
        urgently=False,
        with_check_limits_by_maintenance_plots=False,
    ):
        """
        :param total_started_scenarios_limit_name: name of the config limit for started scenarios
            of the same type; the check is skipped when falsy.
        :param dc_started_scenarios_limit: name (not value) of the config limit for started
            scenarios of the same type in the same datacenter; the check is skipped when falsy.
        :param urgently: when true, every limit check is skipped.
        :param with_check_limits_by_maintenance_plots: also apply per-maintenance-plot host limits.
        """
        super().__init__(
            total_started_scenarios_limit_name=total_started_scenarios_limit_name,
            dc_started_scenarios_limit=dc_started_scenarios_limit,
            urgently=urgently,
            with_check_limits_by_maintenance_plots=with_check_limits_by_maintenance_plots,
        )

        self.total_started_scenarios_limit_name = total_started_scenarios_limit_name
        self.dc_started_scenarios_limit = dc_started_scenarios_limit
        self.urgently = urgently
        self.with_check_limits_by_maintenance_plots = with_check_limits_by_maintenance_plots

    def run(self, stage_info: StageInfo, scenario: Scenario) -> Marker:
        """Grant permission if every limit allows it, otherwise keep the stage in progress.

        :returns: ``Marker.in_progress`` carrying the limit-exceeded message, or
            ``Marker.success`` after marking the scenario's work status as ``STARTED``.
        """
        try:
            self._check_limits(scenario)
        except _ScenarioLimitsExceededError as e:
            log.info(
                "Delay processing of '%s' scenario_id: '%s' by '%s': %s",
                scenario.name,
                scenario.scenario_id,
                scenario.issuer,
                e,
            )
            return Marker.in_progress(message=str(e))
        scenario.set_works_status_label(ScenarioWorkStatus.STARTED)
        return Marker.success(message="Permission granted")

    @staticmethod
    def _scenario_limit_config_key(limit_name: str) -> str:
        """Return the config key that holds the value of ``limit_name``."""
        return "scenario.{}".format(limit_name)

    def _check_limits(self, scenario: Scenario):
        """Run every enabled limit check.

        :raises _ScenarioLimitsExceededError: on the first exceeded limit.
        """
        if self.urgently:
            return

        if self.with_check_limits_by_maintenance_plots:
            self._check_limits_by_maintenance_plots(scenario)

        self._check_limit(
            self._working_scenarios_query(scenario),
            self._total_started_scenarios_limit_name,
            "Too many started scenarios",
        )

        if self.total_started_scenarios_limit_name:
            self._check_limit(
                self._same_type_scenario_query(scenario),
                self.total_started_scenarios_limit_name,
                "Too many started scenarios of type {}".format(scenario.scenario_type),
            )

        self._check_limit(
            self._same_user_scenario_query(scenario),
            self._total_user_limit_name,
            "Too many scenarios are started by {}".format(scenario.issuer),
        )

        if self.dc_started_scenarios_limit and DATACENTER_LABEL_NAME in scenario.labels:
            self._check_limit(
                self._same_dc_same_type_scenario_query(scenario),
                self.dc_started_scenarios_limit,
                "Too many started scenarios of type {} in {}".format(
                    scenario.scenario_type, scenario.labels[DATACENTER_LABEL_NAME]
                ),
            )

    def _check_limits_by_maintenance_plots(self, scenario: Scenario):
        """Apply per-maintenance-plot active-host limits to ``scenario``.

        For each maintenance plot the scenario has hosts in, its host count is added to
        the host count of the scenarios returned by ``_get_all_active_scenarios_ids``
        in that plot. Then:

        * if the plot defines a block limit and the sum reaches it, raise;
        * if the plot defines an allow limit and the scenario has a maintenance-plot
          host group source for it, store the numbers as the approvement reason and,
          when the sum is below the allow limit, mark the group to skip approvement.

        All host group sources are then written back to the scenario's data storage.

        :raises _ScenarioLimitsExceededError: when a block limit would be reached.
        """
        own_hosts_count_by_plot = get_map_maintenance_plot_ids_to_hosts_count(scenario.scenario_id)
        plot_ids = own_hosts_count_by_plot.keys()
        allow_limits = self._get_maintenance_plots_allow_limits(plot_ids)
        block_limits = self._get_maintenance_plots_block_limits(plot_ids)

        # Hosts of all active scenarios, restricted to the plots this scenario touches.
        active_hosts_count_by_plot = self._get_map_maintenance_plot_id_to_hosts_count_for_all_active_scenarios(
            self._get_all_active_scenarios_ids(), plot_ids
        )

        for mpid, hosts_count in own_hosts_count_by_plot.items():
            if mpid not in block_limits:
                continue
            # NOTE(review): the setting is named "..._more_than", but the scenario is also
            # blocked when the total is exactly equal to the limit (>=) — confirm intent.
            if hosts_count + active_hosts_count_by_plot.get(mpid, 0) >= block_limits[mpid]:
                raise _ScenarioLimitsExceededError(
                    "Limit for maintenance plot will be exceeded, plot - {}, limit {}".format(
                        mpid, block_limits[mpid]
                    )
                )

        sources_by_plot, sources_without_plot = self._get_map_maintenance_plot_id_to_host_group_sources(scenario)

        for mpid, hosts_count in own_hosts_count_by_plot.items():
            if mpid not in allow_limits or mpid not in sources_by_plot:
                continue
            active_hosts_count = active_hosts_count_by_plot.get(mpid, 0)
            total_hosts_count = hosts_count + active_hosts_count
            decision = sources_by_plot[mpid].approvement_decision
            if total_hosts_count < allow_limits[mpid]:
                # NOTE(review): the flag is only ever set to True, never reset, so a value
                # persisted by an earlier run of this stage survives even when the plot is
                # over the allow limit now — confirm this is intended.
                decision.skip_approvement = True
            decision.reason = _REASON.format(
                mpid,
                allow_limits[mpid],
                active_hosts_count,
                hosts_count,
                total_hosts_count,
            )

        self._write_updated_hosts_group_sources_to_data_storage(scenario, sources_by_plot, sources_without_plot)

    def _get_map_maintenance_plot_id_to_hosts_count_for_all_active_scenarios(
        self, scenarios_ids, filter_maintenance_plot_ids: Collection[int]
    ) -> dict[int, int]:
        """Sum host counts per maintenance plot over ``scenarios_ids``.

        Only plots listed in ``filter_maintenance_plot_ids`` are counted. Note that this
        issues one ``get_map_maintenance_plot_ids_to_hosts_count`` call per scenario.

        :returns: a ``defaultdict(int)`` keyed by maintenance plot id.
        """
        wanted_plot_ids = set(filter_maintenance_plot_ids)  # O(1) membership even for list input
        result = defaultdict(int)
        for scenario_id in scenarios_ids:
            for mpid, hosts_count in get_map_maintenance_plot_ids_to_hosts_count(scenario_id).items():
                if mpid in wanted_plot_ids:
                    result[mpid] += hosts_count
        return result

    def _get_maintenance_plots_block_limits(self, maintenance_plot_ids: Collection[int]) -> dict[int, int]:
        """Return ``{plot_id: limit}`` for plots whose block limit setting is truthy."""
        result = defaultdict(int)
        for maintenance_plot_model in MaintenancePlotModel.objects(id__in=maintenance_plot_ids):
            maintenance_plot = maintenance_plot_model.as_dataclass()
            scenarios_settings = maintenance_plot.common_settings.common_scenarios_settings
            limit = scenarios_settings.dont_allow_start_scenario_if_total_number_of_active_hosts_more_than
            if limit:
                result[maintenance_plot.id] = limit
        return result

    def _get_maintenance_plots_allow_limits(self, maintenance_plot_ids: Collection[int]) -> dict[int, int]:
        """Return ``{plot_id: limit}`` for plots whose ``total_number_of_active_hosts`` is truthy."""
        result = defaultdict(int)
        for maintenance_plot_model in MaintenancePlotModel.objects(id__in=maintenance_plot_ids):
            maintenance_plot = maintenance_plot_model.as_dataclass()
            limit = maintenance_plot.common_settings.common_scenarios_settings.total_number_of_active_hosts
            if limit:
                result[maintenance_plot.id] = limit
        return result

    def _write_updated_hosts_group_sources_to_data_storage(
        self,
        scenario: Scenario,
        map_maintenance_plot_id_to_host_group_sources: dict[int, HostGroupSource],
        host_group_source_without_plots: list[HostGroupSource],
    ):
        """Store plot-bound sources followed by the remaining sources in the scenario's data storage."""
        data_storage = get_data_storage(scenario)
        host_group_sources_with_plots = list(map_maintenance_plot_id_to_host_group_sources.values())
        data_storage.write_host_groups_sources(host_group_sources_with_plots + host_group_source_without_plots)

    def _get_map_maintenance_plot_id_to_host_group_sources(
        self, scenario: Scenario
    ) -> tuple[dict[int, HostGroupSource], list[HostGroupSource]]:
        """Split the scenario's host group sources by whether they come from a maintenance plot.

        :returns: ``({plot_id: source}, [other sources])``; if several sources share a
            plot id, the last one read wins.
        """
        map_maintenance_plot_id_to_host_group_sources = {}
        host_group_sources_without_maintenance_plots = []
        data_storage = get_data_storage(scenario)
        for host_group_source in data_storage.read_host_groups_sources():
            if host_group_source.source.group_source_type == HostGroupsSourcesTypes.MAINTENANCE_PLOT:
                map_maintenance_plot_id_to_host_group_sources[
                    host_group_source.source.maintenance_plot_id
                ] = host_group_source
            else:
                host_group_sources_without_maintenance_plots.append(host_group_source)
        return map_maintenance_plot_id_to_host_group_sources, host_group_sources_without_maintenance_plots

    @staticmethod
    def _get_all_active_scenarios_ids():
        """Return ids of ITDC_MAINTENANCE / NOC_HARD scenarios that count as active.

        A scenario is active when its FSM status is STARTED or PAUSED and its work status
        label is neither CREATED nor ACQUIRING_PERMISSION (MongoDB ``$nin`` also matches
        documents that lack the label entirely).
        """
        return [
            scenario["_id"]
            for scenario in Scenario.get_collection().find(
                {
                    "labels.{}".format(WORK_STATUS_LABEL_NAME): {
                        "$nin": [ScenarioWorkStatus.CREATED, ScenarioWorkStatus.ACQUIRING_PERMISSION]
                    },
                    "status": {"$in": [ScenarioFsmStatus.STARTED, ScenarioFsmStatus.PAUSED]},
                    "scenario_type": {"$in": [ScriptName.ITDC_MAINTENANCE, ScriptName.NOC_HARD]},
                },
                {"_id": 1},
            )
        ]

    @staticmethod
    def _check_limit(query, limit_name, message):
        """Raise if more scenarios match ``query`` than config ``scenario.<limit_name>`` allows.

        :param query: filter kwargs passed to ``Scenario.objects``.
        :param limit_name: limit name resolved through ``_scenario_limit_config_key``.
        :param message: prefix of the error message.
        :raises _ScenarioLimitsExceededError: when the count is strictly greater than the
            limit; the overflow is reported as the "queue position".
        """
        started_scenarios = sc.scenario.Scenario.objects(**query).count()
        max_started_scenarios = config.get_value(AcquirePermission._scenario_limit_config_key(limit_name))

        if started_scenarios > max_started_scenarios:
            message = message.rstrip(".: ")
            # NOTE(review): relies on RecoverableError formatting the positional args into
            # the template — verify against walle.errors.
            raise _ScenarioLimitsExceededError(
                "{}: queue position: {}", message, started_scenarios - max_started_scenarios
            )

    @classmethod
    def _working_scenarios_query(cls, scenario):
        """Filter for STARTED/PAUSED scenarios whose action_time is not later than ``scenario``'s."""
        return {
            "status__in": [ScenarioFsmStatus.STARTED, ScenarioFsmStatus.PAUSED],
            "action_time__lte": scenario.action_time,
        }

    @classmethod
    def _same_type_scenario_query(cls, scenario):
        """``_working_scenarios_query`` restricted to ``scenario``'s type."""
        return dict(cls._working_scenarios_query(scenario), scenario_type=scenario.scenario_type)

    @classmethod
    def _same_user_scenario_query(cls, scenario):
        """``_working_scenarios_query`` restricted to ``scenario``'s issuer."""
        return dict(cls._working_scenarios_query(scenario), issuer=scenario.issuer)

    @classmethod
    def _same_dc_same_type_scenario_query(cls, scenario):
        """``_same_type_scenario_query`` restricted to ``scenario``'s datacenter label."""
        return dict(
            {"labels__{}".format(DATACENTER_LABEL_NAME): scenario.labels[DATACENTER_LABEL_NAME]},
            **cls._same_type_scenario_query(scenario)
        )
