import logging
import typing

from mongoengine import Q

from sepelib.core.exceptions import LogicalError
from walle.errors import InvalidHostStateError
from walle.hosts import Host
from walle.scenario.constants import HostScenarioStatus, StageName
from walle.scenario.error_handlers import HssErrorHandler
from walle.scenario.errors import ErrorList
from walle.scenario.host_stage_info import HostStageInfo
from walle.scenario.marker import Marker, MarkerStatus
from walle.scenario.mixins import CommonParentStageHandler, HostGroupStage
from walle.scenario.scenario import Scenario, ScenarioHostState, HostStageStatus
from walle.scenario.stage_info import StageInfo, StageRegistry, StageStatus
from walle.scenario.stages import HostRootStage

# Logger named after this module; used for informational and exception logging below.
log = logging.getLogger(name=__name__)


@StageRegistry.register(StageName.HostGroupSchedulerStage)
class HostGroupSchedulerStage(CommonParentStageHandler, HostGroupStage):
    """
    Acquires hosts of a given group and runs host stages on them.

    This stage must have exactly one child, and this child must be 'HostRootStage'.
    """

    children = None
    # Key in the stage's data storage marking that all hosts of the group were acquired by the scenario.
    are_all_hosts_acquired = "are_all_hosts_acquired"

    def __init__(self, children: typing.List[HostRootStage], **params):
        """
        :param children: child stages; must consist of exactly one 'HostRootStage' instance.
        :param params: remaining stage parameters, forwarded to the parent initializer.
        :raises LogicalError: if 'children' is not exactly one 'HostRootStage'.
        """
        # NOTE(review): 'super(CommonParentStageHandler, self)' skips CommonParentStageHandler.__init__ in the MRO;
        # presumably deliberate since children are validated and stored right here - confirm.
        super(CommonParentStageHandler, self).__init__(**params)

        if len(children) != 1 or not isinstance(children[0], HostRootStage):
            raise LogicalError

        self.children = children

    def run(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int):
        """
        1. Reads hosts of a given group that are not yet done.
        2. Acquires these hosts - binds them to this scenario ID and creates HostStageInfo documents in database.
        3. Executes host's stages (child stages of this one) on each host.

        Returns the acquisition marker while acquisition is incomplete, 'Marker.in_progress()' while some hosts
        are still executing their stages, and 'Marker.success()' once every host of the group is done.
        """
        # In '__init__' we ensured that HostRootStage is the only possible child stage.
        host_root_stage = stage_info.stages[0]

        # Get hosts of the group that still have work to do.
        hosts = self._get_hosts_of_the_group_that_are_not_done(scenario.hosts, host_group_id)

        # Acquire hosts once; the flag in stage data prevents re-acquiring on subsequent iterations.
        if not stage_info.get_data(self.are_all_hosts_acquired, False):
            result = self._acquire_hosts(hosts, host_root_stage, scenario)
            if result.status != MarkerStatus.SUCCESS:
                return result
            stage_info.set_data(self.are_all_hosts_acquired, True)

        # Run HostStageInfo's after all hosts are acquired.
        all_hosts_are_processed = self._execute_host_stages(hosts, host_root_stage, scenario)

        if not all_hosts_are_processed:
            return Marker.in_progress(message="Not all hosts are done, waiting...")

        # Clean up when all hosts are done: remove HostStageInfo's of all hosts of the group.
        self._remove_host_stage_info_documents(scenario, host_group_id)
        return Marker.success(message="All hosts are done.")

    @classmethod
    def _execute_host_stages(
        cls, hosts: typing.Iterable[Host], host_root_stage: StageInfo, scenario: Scenario
    ) -> bool:
        """
        Executes host's stages with error catching.
        Returns 'True' when all hosts of the group successfully executed all of their host stages,
        'False' otherwise.

        :raises ErrorList: if the error handler collected errors while processing hosts
            (raised after stage statuses have been updated).
        """
        host_root_stage.set_stage_processing()
        error_handler = HssErrorHandler()

        # Run host's stages FSM on each host. The results are collected into a list first on purpose:
        # feeding a generator to 'all()' would short-circuit and skip every host after the first unfinished one.
        per_host_results = [
            cls._process_single_host(scenario, host, error_handler, host_root_stage) for host in hosts
        ]
        all_hosts_are_processed = all(per_host_results)

        if all_hosts_are_processed:
            host_root_stage.set_stage_finished()

        cls._set_finished_status_on_host_root_stage_children(host_root_stage, hosts, all_hosts_are_processed)

        if error_handler.errors:
            raise ErrorList(error_handler.errors)

        return all_hosts_are_processed

    @classmethod
    def _process_single_host(
        cls, scenario: Scenario, host: Host, error_handler: HssErrorHandler, host_root_stage: StageInfo
    ) -> bool:
        """
        Executes host stages of a given host.

        Returns 'True' if the host finished all of its stages during this call. Failures are logged and
        reported as 'False' so that one broken host does not stop processing of the others.
        """
        scenario.set_host_info_status(host.uuid, HostScenarioStatus.PROCESSING)
        is_host_processed = False

        try:
            # Scope the lookup to this scenario, same as '_acquire_hosts' and '_update_host_stage_info_document':
            # a stale document left by another scenario must not be picked up (or make '.get()' ambiguous).
            host_stage_info_document = HostStageInfo.objects.get(
                host_uuid=host.uuid, scenario_id=scenario.scenario_id
            )
            host_stage_info = host_stage_info_document.stage_info
            host_stages = host_stage_info.deserialize()

            with error_handler(host_stage_info, scenario, host):
                marker = host_stages.run(host_stage_info, scenario, host, host_root_stage)
                if marker.status == MarkerStatus.SUCCESS:
                    host_stage_info.set_stage_finished()
                    scenario.set_host_info_status(host.uuid, HostScenarioStatus.DONE)
                    is_host_processed = True

            cls._update_host_stage_info_document(scenario, host_stage_info_document, host_stage_info)

        except InvalidHostStateError as e:
            log.info("Got host in wrong state/status, user will fix it: %s", e)

        except Exception as e:
            log.exception(
                "Got exception during scenario #%s processing for host #%s: %s", scenario.scenario_id, host.inv, e
            )

        return is_host_processed

    @staticmethod
    def _acquire_hosts(hosts: typing.Iterable[Host], host_root_stage: StageInfo, scenario: Scenario) -> Marker:
        """
        Acquires hosts by the scenario: sets scenario ID to Host's documents and creates new HostStageInfo documents
        in database. Checks if hosts are acquired by other scenarios.

        Returns 'Marker.success()' when all hosts of the group are acquired by this scenario,
        'Marker.in_progress()' when some hosts are held by other scenarios, and 'Marker.failure()' when a host
        could not be atomically bound to this scenario.
        """
        uuids_of_hosts_acquired_by_another_scenario = []

        # Try to acquire each host.
        for host in hosts:
            # Check if host is acquired by this scenario and its HostStageInfo document present in database.
            if (
                host.scenario_id == scenario.scenario_id
                and HostStageInfo.objects(host_uuid=host.uuid, scenario_id=scenario.scenario_id).count()
            ):
                # It is already acquired, so skip it.
                continue

            # Check if host is acquired by another scenario.
            if host.scenario_id and host.scenario_id != scenario.scenario_id:
                # Save its UUID to log number of such hosts later.
                uuids_of_hosts_acquired_by_another_scenario.append(host.uuid)
                continue

            # Acquire host. Try to select host by UUID with explicitly stated condition "no scenario ID set
            # or this scenario ID is set", and set scenario ID atomically.
            if Host.objects(Q(scenario_id__exists=None) | Q(scenario_id=scenario.scenario_id), uuid=host.uuid).modify(
                set__scenario_id=scenario.scenario_id
            ):
                # If Host object was found and scenario ID was written to it - save new host's HostStageInfo document to
                # database, and set host's status in scenario's 'hosts' to 'acquired'.
                HostStageInfo(host_uuid=host.uuid, scenario_id=scenario.scenario_id, stage_info=host_root_stage).save(
                    force_insert=True
                )
                scenario.set_host_info_status(host.uuid, HostScenarioStatus.ACQUIRED)
            else:
                # Fail stage if host was not selected and modified.
                return Marker.failure(
                    message="Can not set scenario ID to Host with UUID '%s', will try again on next "
                    "iteration." % host.uuid
                )

        # Issue log message and return 'Marker.in_progress()' if there are hosts acquired by other scenarios.
        if uuids_of_hosts_acquired_by_another_scenario:
            log.info(
                "There are %s hosts acquired by another scenarios, will try to acquire them again on next iteration",
                len(uuids_of_hosts_acquired_by_another_scenario),
            )
            return Marker.in_progress(message="Some hosts acquired by another scenario, try a little bit later.")

        return Marker.success(message="Successfully acquired all hosts.")

    @staticmethod
    def _get_hosts_of_the_group_that_are_not_done(
        host_infos: typing.Dict[str, ScenarioHostState], host_group_id: int
    ) -> typing.Iterable[Host]:
        """
        Returns hosts of a given group, which are not in 'HostScenarioStatus.DONE' status.

        The result is a Host queryset (empty when no such hosts exist); 'run' iterates it more than once.
        """
        uuids = [
            uuid
            for uuid, host_info in host_infos.items()
            if host_info.group == host_group_id and host_info.status != HostScenarioStatus.DONE
        ]

        if not uuids:
            return Host.objects.none()

        return Host.objects(uuid__in=uuids)

    # noinspection DuplicatedCode
    @staticmethod
    def _set_finished_status_on_host_root_stage_children(
        host_root_stage: StageInfo, hosts: typing.Iterable[Host], all_hosts_processed: bool
    ):
        """Sets 'HostRootStage' child stages status to 'StageStatus.FINISHED' if all hosts finished executing it."""
        uuids = {host.uuid for host in hosts}
        for child_stage_info in host_root_stage.stages:
            if all_hosts_processed:
                child_stage_info.set_stage_finished()
                continue

            if child_stage_info.status == StageStatus.FINISHED:
                continue

            # A child stage is finished only if it tracks every pending host of the group
            # and every host it tracks has finished it.
            if set(child_stage_info.hosts.keys()).issuperset(uuids) and all(
                host_info["status"] == HostStageStatus.FINISHED for host_info in child_stage_info.hosts.values()
            ):
                child_stage_info.set_stage_finished()

    @staticmethod
    def _update_host_stage_info_document(
        scenario: Scenario, host_stage_info_document: HostStageInfo, host_stage_info: StageInfo
    ):
        """Updates 'stage_info' field and revision in host's 'HostStageInfo' document."""
        revision = host_stage_info_document.revision
        query = dict(revision=revision, scenario_id=scenario.scenario_id, host_uuid=host_stage_info_document.host_uuid)
        # Optimistic concurrency: the document is modified only if its revision did not change since it was read.
        HostStageInfo.objects(**query).modify(set__stage_info=host_stage_info, set__revision=revision + 1)

    @staticmethod
    def _remove_host_stage_info_documents(scenario: Scenario, host_group_id: int):
        """Removes 'HostStageInfo' documents of hosts belonging to the host group."""
        host_group_uuids = [uuid for uuid, host_info in scenario.hosts.items() if host_info.group == host_group_id]
        if not host_group_uuids:
            return
        # One delete query for the whole group instead of one query per host.
        HostStageInfo.objects(scenario_id=scenario.scenario_id, host_uuid__in=host_group_uuids).delete()
