import collections
import dataclasses
import typing

from walle.scenario.constants import HostScenarioStatus, StageName
from walle.scenario.data_storage.data_storage import DefaultDataStorage
from walle.scenario.definitions.base import get_data_storage
from walle.scenario.error_handlers import GroupStageErrorHandler
from walle.scenario.marker import Marker, MarkerStatus
from walle.scenario.mixins import ParentStageHandler, Stage
from walle.scenario.scenario import Scenario
from walle.scenario.scheduler import MaintenanceApproversScheduler
from walle.scenario.stage_info import StageInfo, StageRegistry, StageStatus


@dataclasses.dataclass
class GroupStagesContainer:
    """
    Progress of one maintenance approvers group through its own list of stages.

    Containers are persisted as plain dicts produced by `serialize()` and rebuilt with
    `GroupStagesContainer(**container_dict)`.
    """

    # Group identifier; hosts belong to the group when `host_info.group == group_id`.
    group_id: int = 0
    # Name of the group, if known.
    group_name: typing.Optional[str] = None
    # Number of scenario hosts assigned to this group.
    number_of_hosts: int = 0
    # The group's own stages, executed one after another.
    stages: typing.List[StageInfo] = dataclasses.field(default_factory=list)
    # Index in `stages` of the stage to run next; reaches `len(stages)` once every stage is done.
    current_stage_index: int = 0

    def all_stages_done(self) -> bool:
        """Return True when there is no stage left to run for this group."""
        # `>=` instead of `==`: an index that ran past the end (e.g. after `stages` was edited by hand)
        # must count as "done", otherwise `get_current_stage()` would be called and raise IndexError.
        return self.current_stage_index >= len(self.stages)

    def get_current_stage(self) -> StageInfo:
        """
        Return the stage at `current_stage_index`.

        Raises IndexError when all stages are done; check `all_stages_done()` first.
        """
        return self.stages[self.current_stage_index]

    def serialize(self):
        """
        Return a dict representation accepted back by `GroupStagesContainer(**result)`.

        `stages` is passed through as-is rather than deep-copied: `dataclasses.asdict()` would deep-copy it,
        and copying these stage objects with `copy.deepcopy()` is reported to fail under the MongoEngine
        version in use.
        """
        return {
            "group_id": self.group_id,
            "group_name": self.group_name,
            "number_of_hosts": self.number_of_hosts,
            "stages": self.stages,
            "current_stage_index": self.current_stage_index,
        }


@StageRegistry.register(StageName.MaintenanceApproversWorkflowStage)
class MaintenanceApproversWorkflowStage(ParentStageHandler, Stage):
    """
    Runs the child stages separately for every maintenance approvers group of the scenario's hosts.

    The first iteration splits hosts into groups (unless that was already done when the scenario was created),
    stores one `GroupStagesContainer` per group in this stage's 'data', and returns "in progress".
    Every following iteration advances each unfinished group by running its current stage.
    The stage succeeds once every group has finished all of its stages.
    """

    _group_stages_containers_data_field_name = "group_stages_containers"

    def run(self, stage_info: StageInfo, scenario: Scenario):
        group_stages_containers = self._read_group_stage_containers(stage_info)

        # An empty container list means "not initialized yet".
        # NOTE(review): if a scenario ends up with zero groups, this re-initializes on every iteration and the
        # stage never finishes - confirm that is intended.
        if not group_stages_containers:
            return self._initialize_group_stages_containers(stage_info, scenario)

        return self._process_group_stages_containers(stage_info, scenario, group_stages_containers)

    def _initialize_group_stages_containers(self, stage_info: StageInfo, scenario: Scenario):
        """
        Split hosts into groups if needed, then create and store one stages container per group.

        Always returns an in-progress marker: the stored containers are processed on the next iteration.
        """
        scenario_data_storage = get_data_storage(scenario)

        # 'MaintenanceApproversGroups' in data_storage may have been created earlier, when the scenario was created.
        # If so, hosts are already split into groups, and here we only (re)set their `HostScenarioStatus` to `QUEUE`.
        if scenario_data_storage.read():
            self._set_host_scenario_statuses_to_queue(scenario)
        else:
            # Old behavior: split hosts into groups and set `HostScenarioStatus`'es to `QUEUE`.
            self._split_hosts_into_maintenance_approvers_groups(scenario, scenario_data_storage)

        if self.params.get('stage_map'):
            group_stages_containers = self._generate_group_stages_containers(
                scenario_data_storage, scenario, stage_info
            )
            message = "Dynamically created groups"
        else:
            # Every container references the same 'stages' here. They are written into this stage's own 'data',
            # and the current iteration stops. On the next iteration they are read back as separate copies,
            # one per group.
            # Explicitly copying 'stages' with 'copy.deepcopy()' results in a MongoEngine error. The issue is fixed
            # in https://github.com/MongoEngine/mongoengine/pull/2495, but the fix is only available since
            # mongoengine 0.23.0, while at the time of writing (15.04.2021) contrib had 0.17.0.
            group_stages_containers = self._create_group_stages_containers(
                stage_info["stages"], scenario_data_storage, scenario
            )
            message = "Created groups"

        self._write_group_stage_containers(stage_info, group_stages_containers)
        return Marker.in_progress(message=message)

    def _process_group_stages_containers(
        self,
        stage_info: StageInfo,
        scenario: Scenario,
        group_stages_containers: typing.List[GroupStagesContainer],
    ):
        """
        Run the current stage of every unfinished group and persist the groups' progress.

        Per-group errors are collected by `GroupStageErrorHandler` and re-raised (presumably only if any occurred)
        after progress has been written back.
        """
        group_error_handler = GroupStageErrorHandler()

        for group_stage_container in group_stages_containers:
            if group_stage_container.all_stages_done():
                continue

            # The group's own copy of its current stage: its stored info and the stage object built from it.
            active_stage_info = group_stage_container.get_current_stage()
            active_stage_obj = active_stage_info.deserialize()

            with group_error_handler(active_stage_info, scenario, group_stage_container.group_id):
                if active_stage_info.status == StageStatus.FINISHED:
                    result_marker = Marker.success(message="The stage was completed manually")
                else:
                    active_stage_info.set_stage_processing()
                    result_marker = active_stage_obj.run(active_stage_info, scenario, group_stage_container.group_id)

                active_stage_info.set_stage_msg(result_marker.message)
                if result_marker.status == MarkerStatus.SUCCESS:
                    active_stage_info.set_stage_finished()
                    # The group's next stage runs on the next iteration, not in this one.
                    group_stage_container.current_stage_index += 1

        # Save progress before re-raising collected errors so that groups which advanced keep their progress.
        self._write_group_stage_containers(stage_info, group_stages_containers)

        group_error_handler.raise_exception()

        if any(not container.all_stages_done() for container in group_stages_containers):
            return Marker.in_progress(message="Processing child stages")

        return Marker.success(message="All groups finished")

    def _read_group_stage_containers(self, stage_info: StageInfo) -> typing.List[GroupStagesContainer]:
        """
        Reads group stage containers from StageInfo's 'data' field (an empty list if none are stored).
        """
        return [
            GroupStagesContainer(**container_dict)
            for container_dict in stage_info.get_data(self._group_stages_containers_data_field_name, [])
        ]

    def _write_group_stage_containers(
        self, stage_info: StageInfo, group_stage_containers: typing.List[GroupStagesContainer]
    ):
        """
        Writes group stage containers to StageInfo's 'data' field, replacing any previously stored ones.
        """
        containers_dicts = [container.serialize() for container in group_stage_containers]
        stage_info.set_data(self._group_stages_containers_data_field_name, containers_dicts)

    @staticmethod
    def _split_hosts_into_maintenance_approvers_groups(scenario: Scenario, data_storage: DefaultDataStorage):
        """
        Use the scheduler to split hosts into groups; the scheduled hosts replace `scenario.hosts`.
        """
        scheduler = MaintenanceApproversScheduler(scenario.hosts, data_storage)
        scenario.hosts = scheduler.schedule()

    @staticmethod
    def _set_host_scenario_statuses_to_queue(scenario: Scenario):
        """
        Sets scenario's host statuses to `QUEUE`.
        """
        for host_info in scenario.hosts.values():
            host_info.status = HostScenarioStatus.QUEUE

    @staticmethod
    def _count_hosts_by_group(scenario: Scenario) -> collections.Counter:
        """
        Count the scenario's hosts per `host_info.group`; groups without hosts map to 0.
        """
        # One pass over the hosts instead of rescanning all of them for every group.
        return collections.Counter(host_info.group for host_info in scenario.hosts.values())

    @staticmethod
    def _create_group_stages_containers(
        stages: typing.List[StageInfo], data_storage: DefaultDataStorage, scenario: Scenario
    ) -> typing.List[GroupStagesContainer]:
        """
        Create a container for each maintenance approvers group, all referencing the given 'stages'.

        The containers become independent copies only after being written to and read back from 'data'.
        """
        hosts_per_group = MaintenanceApproversWorkflowStage._count_hosts_by_group(scenario)
        return [
            GroupStagesContainer(
                group_id=group.group_id,
                group_name=group.name,
                number_of_hosts=hosts_per_group[group.group_id],
                stages=stages,
                current_stage_index=0,
            )
            for group in data_storage.read()
        ]

    def _generate_group_stages_containers(
        self,
        data_storage: DefaultDataStorage,
        scenario: Scenario,
        stage_info: StageInfo,
    ) -> typing.List[GroupStagesContainer]:
        """
        Create a container for each maintenance approvers group with stages generated from 'stage_map'.
        """
        hosts_per_group = self._count_hosts_by_group(scenario)
        return [
            GroupStagesContainer(
                group_id=host_group_source.group_id,
                group_name=host_group_source.source.get_group_source_name(),
                number_of_hosts=hosts_per_group[host_group_source.group_id],
                stages=self.generate_stage_infos(scenario, stage_info, host_group_source),
                current_stage_index=0,
            )
            for host_group_source in data_storage.read_host_groups_sources()
        ]

    def generate_stage_infos(self, scenario, stage_info, host_group_source):
        """
        Build serialized child stages for one host group.

        The stages come from `self.params['stage_map']` combined with the scenario settings of the group's
        maintenance plot. Child uids are "<parent stage uid>.<child index>".
        """
        maintenance_plot = host_group_source.source.get_maintenance_plot()
        scenario_settings = maintenance_plot.get_scenario_settings(scenario.scenario_type).to_dict()

        stages = self.generate_stages(self.params['stage_map'], scenario_settings)
        return [child.serialize(uid="{}.{}".format(stage_info.uid, idx)) for idx, child in enumerate(stages)]
