import logging

from mongoengine import Q
from pymongo import UpdateOne

from sepelib.core.exceptions import LogicalError
from walle.errors import InvalidHostStateError
from walle.hosts import Host, HostState, HostStatus
from walle.models import timestamp
from walle.scenario.constants import HostScenarioStatus, SchedulerName, StageName
from walle.scenario.definitions.base import get_data_storage
from walle.scenario.error_handlers import HssErrorHandler
from walle.scenario.host_stage_info import HostStageInfo
from walle.scenario.iteration_strategy import SequentialActionsStrategyWithRepeat
from walle.scenario.marker import Marker, MarkerStatus
from walle.scenario.mixins import MultiActionStage, CommonParentStageHandler
from walle.scenario.scenario import HostStageStatus
from walle.scenario.scheduler import SchedulerRegistry
from walle.scenario.stage_info import StageRegistry, StageAction, StageStatus

log = logging.getLogger(__name__)


def _host_is_assigned_ready(host, scenario):
    """Readiness check: the host is in ASSIGNED state with READY status.

    :param host: host whose ``state`` and ``status`` are inspected
    :param scenario: unused; present so all readiness checkers share one signature
    """
    if host.state != HostState.ASSIGNED:
        return False
    return host.status == HostStatus.READY


def _host_is_acquired_by_scenario(host, scenario):
    """Readiness check: the host is already owned by the given scenario.

    :param host: host whose ``scenario_id`` is inspected
    :param scenario: scenario whose ``scenario_id`` the host must carry
    """
    owner_id = host.scenario_id
    return owner_id == scenario.scenario_id


# String identifiers of the host readiness checkers; these are the values
# accepted as ``host_readiness_str`` by the stage factories and AcquireHosts.
HOST_IS_ASSIGNED_READY = "_host_is_assigned_ready"
HOST_IS_ACQUIRED_BY_SCENARIO = "_host_is_acquired_by_scenario"

# Maps a readiness identifier to the ``(host, scenario) -> bool`` callable
# that AcquireHosts uses to decide whether a host may be acquired.
HOST_READINESS_MAPPING = {
    HOST_IS_ASSIGNED_READY: _host_is_assigned_ready,
    HOST_IS_ACQUIRED_BY_SCENARIO: _host_is_acquired_by_scenario,
}


@StageRegistry.register(StageName.HSS)
@StageRegistry.register(StageName.THSS)
def host_scheduler_stage(
    children, schedule_type=SchedulerName.DATACENTER, execution_time=None, greedy=None, host_readiness_str=None
):
    """Build a HostSchedulerStage registered under the HSS and THSS stage names.

    :param children: a CommonParentStageHandler or a list holding exactly one such handler
    :param schedule_type: scheduler name used to split hosts into groups
    :param execution_time: optional time limit for the stage
    :param greedy: forwarded to the stage and to both of its actions
    :param host_readiness_str: optional key of HOST_READINESS_MAPPING selecting the readiness check
                               that a host must pass before it is acquired
    :raises LogicalError: when ``children`` is not a single parent stage handler
    """
    host_root_stage = _host_root_stage(children)
    group_scheduler = GroupScheduler(schedule_type)

    return HostSchedulerStage(
        name=StageName.HSS,
        schedule_type=schedule_type,
        execution_time=execution_time,
        greedy=greedy,
        host_root_stage=host_root_stage,
        group_scheduler=group_scheduler,
        host_readiness_str=host_readiness_str,
        actions=[
            # The readiness checker has to reach AcquireHosts: storing it in the stage
            # params alone does not make the acquisition step honour it.
            AcquireHosts(
                greedy,
                host_root_stage,
                group_scheduler,
                host_readiness_str=host_readiness_str,
            ),
            ProcessHostGroup(greedy, group_scheduler),
        ],
    )


@StageRegistry.register("NocMaintenanceHostScheduler")
def noc_maintenance_scheduler_stage(
    children, host_readiness_str=None, schedule_type=SchedulerName.ALL, execution_time=None, greedy=True
):
    """Build the HostSchedulerStage registered as "NocMaintenanceHostScheduler".

    :param children: a CommonParentStageHandler or a list holding exactly one such handler
    :param host_readiness_str: optional key of HOST_READINESS_MAPPING, forwarded to AcquireHosts
    :param schedule_type: scheduler name used to split hosts into groups
    :param execution_time: optional time limit for the stage
    :param greedy: forwarded to the stage and to both of its actions
    :raises LogicalError: when ``children`` is not a single parent stage handler
    """
    root_stage = _host_root_stage(children)
    scheduler = GroupScheduler(schedule_type)

    acquire_action = AcquireHosts(greedy, root_stage, scheduler, host_readiness_str=host_readiness_str)
    process_action = ProcessHostGroup(greedy, scheduler)

    return HostSchedulerStage(
        host_root_stage=root_stage,
        group_scheduler=scheduler,
        actions=[acquire_action, process_action],
        name="NocMaintenanceHostScheduler",
        execution_time=execution_time,
        greedy=greedy,
        schedule_type=schedule_type,
        host_readiness_str=host_readiness_str,
    )


def _host_root_stage(children_stages):
    """Extract the single parent stage handler from a stage's children.

    :param children_stages: either a CommonParentStageHandler itself or a sequence
                            containing exactly one CommonParentStageHandler
    :return: the parent stage handler
    :raises LogicalError: for any other input (the offending value is logged)
    """
    if isinstance(children_stages, CommonParentStageHandler):
        return children_stages

    try:
        # The length is checked first, so indexing only happens for sequences of at most
        # one element; an empty sequence fails on the index and lands in the handler too.
        is_single_handler = len(children_stages) <= 1 and isinstance(children_stages[0], CommonParentStageHandler)
        if not is_single_handler:
            raise ValueError
    except Exception:
        log.exception("Expected list of single ParentStageHandler, got %r", children_stages)
        raise LogicalError

    return children_stages[0]


class HostSchedulerStage(MultiActionStage):
    """Stage that runs its actions over scenario hosts, batched into groups by a scheduler.

    The actions are wired into a SequentialActionsStrategyWithRepeat together with this
    stage's ``prepare``, ``stage_completed`` and ``has_more`` callbacks.
    """

    # Key in ``stage_info.data`` holding the timestamp after which the stage is timed out.
    STAGE_END_TIME = "stage_end_time"

    def __init__(
        self,
        host_root_stage,
        group_scheduler,
        actions,
        name=None,
        execution_time=None,
        greedy=None,
        schedule_type=SchedulerName.DATACENTER,
        host_readiness_str=None,
    ):
        """Execute HostRootStage for hosts batched by scheduler.
        :param host_root_stage: HostRootStage serializable instance
        :param group_scheduler: GroupScheduler used to schedule hosts and to check for remaining groups
        :param actions: actions handed over to the iteration strategy
        :param name: optional stage name; the class name is used when it is not given
        :param schedule_type: type of scheduler from SchedulerName.CHOICES
        :param execution_time: time limit for this stage execution. on timeout this stage returns Marker.SUCCESS.
        :param greedy: If True, start executing hosts as soon as they are acquired.
                       If False, acquire all hosts from current group before starting to process any.
        :param host_readiness_str: type of host readiness checker
        """
        if name:
            # Left unset otherwise, which makes the ``name`` property fall back to the class name.
            self.__stage_name = name

        self.host_root_stage = host_root_stage
        self.execution_time = execution_time
        self.group_scheduler = group_scheduler

        strategy = SequentialActionsStrategyWithRepeat(
            actions=actions,
            stage_prepare=self.prepare,
            stage_completed=self.stage_completed,
            stage_has_more=self.has_more,
        )
        super().__init__(
            strategy,
            execution_time=execution_time,
            schedule_type=schedule_type,
            greedy=greedy,
            host_readiness_str=host_readiness_str,
        )

    def run(self, stage_info, scenario, *args, **kwargs):
        """Run the stage; once it reports success, clean the scenario's host stage infos."""
        marker = super().run(stage_info, scenario, *args, **kwargs)
        if marker.status == MarkerStatus.SUCCESS:
            scenario.clean_hosts_stage_info()
        return marker

    def serialize(self, uid="0"):
        """Serialize the stage with the host root stage as its only child (uid ``<uid>.0``)."""
        serialized = super().serialize(uid=uid)
        serialized.action_type = StageAction.PREPARE
        child_uid = "{}.0".format(uid)
        serialized.stages = [self.host_root_stage.serialize(uid=child_uid)]
        return serialized

    def prepare(self, stage_info, scenario, *args, **kwargs):
        """Schedule the scenario's hosts into groups and, if a time limit is set, store the deadline."""
        # Run scheduler and set initial group number
        self.group_scheduler.schedule_hosts_for_scenario(scenario)

        if self.execution_time is not None:
            deadline = timestamp() + self.execution_time
            stage_info.data[self.STAGE_END_TIME] = deadline

        return Marker.success()

    def has_more(self, stage_info, scenario):
        """Tell whether the scenario still has a host group to work on."""
        # TODO: we need to find a way to not use scenario here
        return self.group_scheduler.is_group_available(scenario)

    def _timed_out(self, stage_info):
        """Return True when a time limit is configured and the stored deadline has passed."""
        if self.execution_time is None:
            return False
        return timestamp() >= stage_info.data[self.STAGE_END_TIME]

    def stage_completed(self, stage_info, scenario):
        """Report success once the stage has timed out, in-progress otherwise."""
        return Marker.success() if self._timed_out(stage_info) else Marker.in_progress()

    @property
    def name(self):
        """Explicit stage name when one was given to the constructor, else the class name."""
        try:
            return self.__stage_name
        except AttributeError:
            return self.__class__.__name__

    def __eq__(self, other):
        return super().__eq__(other) and self.host_root_stage == other.host_root_stage

    def __repr__(self):
        return "<{}(name='{}', params={!r}, host_root_stage={!r})>".format(
            type(self).__name__, self.name, self.params, self.host_root_stage
        )


class GroupScheduler:
    """Splits scenario hosts into groups and answers questions about the current group."""

    def __init__(self, schedule_type):
        """
        :param schedule_type: name under which the scheduler class is looked up in SchedulerRegistry
        """
        self.schedule_type = schedule_type

    def schedule_hosts_for_scenario(self, scenario):
        """Replace ``scenario.hosts`` with the scheduler's result and reset the current group to 0."""
        scheduler_cls = SchedulerRegistry.get(self.schedule_type)
        if self.schedule_type != SchedulerName.MAINTENANCE_APPROVERS:
            scenario.hosts = scheduler_cls.schedule(scenario.hosts)
        else:
            # This scheduler is instantiated with the scenario's data storage
            # instead of being called through a class-level ``schedule``.
            data_storage = get_data_storage(scenario)
            scenario.hosts = scheduler_cls(scenario.hosts, data_storage).schedule()
        scenario.current_group = 0

    @staticmethod
    def get_current_hosts_group(host_infos, current_group_number, uses_uuid_keys=False):
        """Return a Host queryset for hosts of the given group that are not DONE yet.

        :param host_infos: mapping of host key to host info (with ``group`` and ``status``)
        :param current_group_number: group to select
        :param uses_uuid_keys: when True the mapping keys are matched against ``uuid``, otherwise against ``inv``
        :return: Host queryset; an empty one when no host of the group is pending
        """
        pending_keys = []
        for key, info in host_infos.items():
            if info.group != current_group_number:
                continue
            if info.status == HostScenarioStatus.DONE:
                continue
            pending_keys.append(key)

        if not pending_keys:
            return Host.objects.none()

        lookup_field = "uuid" if uses_uuid_keys else "inv"
        return Host.objects(**{"{}__in".format(lookup_field): pending_keys})

    @staticmethod
    def is_group_available(scenario):
        """Return True while ``scenario.current_group`` does not exceed the highest host group number."""
        if not scenario.hosts:
            return False

        # TODO think about saving total group count
        highest_group = max(info.group for info in scenario.hosts.values())
        return scenario.current_group <= highest_group


class AcquireHosts:
    """Stage action that claims the hosts of the scenario's current group."""

    name = StageAction.ACTION

    def __init__(self, greedy, host_root_stage, group_scheduler, host_readiness_str=None):
        """
        :param greedy: when truthy, an incomplete acquisition still yields a success marker
        :param host_root_stage: stage serialized into every acquired host's HostStageInfo
        :param group_scheduler: provides the hosts of the current group
        :param host_readiness_str: optional key of HOST_READINESS_MAPPING; without it every host counts as ready
        """
        self.greedy = greedy
        self.host_root_stage = host_root_stage
        self.group_scheduler = group_scheduler

        if host_readiness_str:
            # Instance attribute shadows the always-True static default defined on the class.
            self._host_ready_for_processing = HOST_READINESS_MAPPING[host_readiness_str]

    def acquire_hosts(self, stage_info, scenario):
        """Try to acquire every not yet scheduled host of the current group.

        :return: ``Marker.success()`` when nothing was left behind, otherwise ``greedy_action_marker()``
        """
        # Get hosts group from scheduler and save serialized HRS
        unavailable = {}
        candidates = []

        group = self._get_current_hosts_group(scenario)
        for host in group:
            if self._is_host_scheduled(host, scenario):
                continue

            # current strategy: store this information, proceed for other hosts, retry until all hosts available
            taken_elsewhere = self._is_host_acquired_by_another_scenario(host, scenario)
            if taken_elsewhere or not self._host_ready_for_processing(host, scenario):
                unavailable[host.uuid] = host.scenario_id
            else:
                candidates.append((host.uuid, host.inv))

        host_stage_info = self._get_host_stage_info(stage_info)

        acquired = 0
        for host_uuid, host_inv in candidates:
            # Claim the host only if it has no scenario or already belongs to this one.
            free_or_ours = Q(scenario_id__exists=None) | Q(scenario_id=scenario.scenario_id)
            claimed = Host.objects(free_or_ours, uuid=host_uuid).modify(set__scenario_id=scenario.scenario_id)
            if not claimed:
                continue

            HostStageInfo(host_uuid=host_uuid, scenario_id=scenario.scenario_id, stage_info=host_stage_info).save(
                force_insert=True
            )

            scenario.set_host_info_status(
                host_uuid if scenario.uses_uuid_keys else host_inv, HostScenarioStatus.ACQUIRED
            )
            acquired += 1

        if acquired:
            log.info("Acquired %s hosts from total %s for scenario %s.", acquired, group.count(), scenario.scenario_id)

        if acquired != len(candidates):
            # we lost some hosts on the way.
            log.info(
                "Tried to acquire %s hosts from %s for scenario %s, got only %s hosts. Will try again later.",
                len(candidates),
                group.count(),
                scenario.scenario_id,
                acquired,
            )
            return self.greedy_action_marker()

        if unavailable:
            # we missed these.
            log.info(
                "Can't acquire %s hosts from %s for scenario %s: occupied by other tasks. Will try again later.",
                len(unavailable),
                group.count(),
                scenario.scenario_id,
            )
            return self.greedy_action_marker()

        return Marker.success()

    def _get_current_hosts_group(self, scenario):
        """Return the Host queryset of the scenario's current group."""
        return self.group_scheduler.get_current_hosts_group(
            scenario.hosts, scenario.current_group, scenario.uses_uuid_keys
        )

    @staticmethod
    def _is_host_scheduled(host, scenario):
        """Truthy when the host belongs to this scenario and already has a HostStageInfo for it."""
        if host.scenario_id != scenario.scenario_id:
            return False
        return HostStageInfo.objects(host_uuid=host.uuid, scenario_id=scenario.scenario_id).count()

    @staticmethod
    def _is_host_acquired_by_another_scenario(host, scenario):
        """Truthy when the host carries a scenario id different from this scenario's."""
        return host.scenario_id and host.scenario_id != scenario.scenario_id

    @staticmethod
    def _host_ready_for_processing(host, scenario):
        """Default readiness check: every host is ready."""
        return True

    def _get_host_stage_info(self, stage_info):
        """Serialize the host root stage as the first child (``<uid>.0``) of the given stage info."""
        child_uid = "{}.0".format(stage_info.uid)
        return self.host_root_stage.serialize(uid=child_uid)

    def greedy_action_marker(self):
        """Success marker in greedy mode, in-progress marker otherwise."""
        if self.greedy:
            return Marker.success()
        return Marker.in_progress()

    def __call__(self, stage_info, scenario):
        return self.acquire_hosts(stage_info, scenario)


class ProcessHostGroup:
    """Stage action that runs the host stages on acquired hosts of the current group."""

    name = StageAction.CHECK

    def __init__(self, greedy, group_scheduler):
        """
        :param greedy: when truthy, a pass over an unfinished group still yields a success marker
        :param group_scheduler: provides the hosts of the current group
        """
        self.greedy = greedy
        self.group_scheduler = group_scheduler

    def process_current_group(self, stage_info, scenario):
        """Execute host stages for the current group, or advance to the next group when it is empty.

        :return: ``Marker.success()`` after switching to the next group, otherwise ``greedy_action_marker()``
        """
        # Run HRS on hosts in current group that are in status "queue" or "processing"
        hosts = self._get_current_hosts_group(scenario)

        # Each count() call may issue its own query, so every count is taken exactly once.
        total_count = hosts.count()
        if not total_count:
            # Group is processed, proceed to the next one
            log.info("All hosts for scenario %s group %s processed.", scenario.scenario_id, scenario.current_group)

            scenario.current_group += 1
            return Marker.success()

        # Only hosts that already have a HostStageInfo for this scenario can be executed.
        uuids = [x.host_uuid for x in HostStageInfo.objects(scenario_id=scenario.scenario_id).only("host_uuid")]
        hosts_to_execute = hosts.filter(uuid__in=uuids)
        executable_count = hosts_to_execute.count()
        if executable_count:
            if executable_count < total_count:
                log.info(
                    "Executing stages on %s out of %s hosts for scenario %s group %s",
                    executable_count,
                    total_count,
                    scenario.scenario_id,
                    scenario.current_group,
                )
            self._execute_host_stages(list(hosts_to_execute), stage_info, scenario)
        return self.greedy_action_marker()

    @classmethod
    def _execute_host_stages(cls, hosts, stage_info, scenario):
        """Run the host root stage for every given host, persist the results and update stage statuses.

        Errors collected by the HssErrorHandler are raised at the very end, after
        the host stage infos have been written and the statuses updated.
        """
        hss_error_handler = HssErrorHandler()

        hrs_stage_info = stage_info.stages[0]
        hrs_stage_info.set_stage_processing()
        all_hsis = {hsi.host_uuid: hsi for hsi in HostStageInfo.objects(host_uuid__in=[host.uuid for host in hosts])}
        hms = hrs_stage_info.deserialize()

        # A list (not a generator) on purpose: every host must be processed
        # even after one of them has reported that it is not finished.
        all_hosts_processed = all(
            [
                cls._process_single_host(scenario, host, all_hsis, hms, hss_error_handler, hrs_stage_info)
                for host in hosts
            ]
        )
        if all_hosts_processed:
            hrs_stage_info.set_stage_finished()

        cls._bulk_update_host_stage_infos(scenario, all_hsis)
        cls._set_statuses_for_host_stages(hrs_stage_info, hosts, all_hosts_processed)

        hss_error_handler.raise_exception()

    @classmethod
    def _process_single_host(cls, scenario, host, all_hsis, hms, hss_error_handler, scenario_stage_info):
        """Run the host stages for one host.

        The host is marked PROCESSING first and DONE once the run returns a success marker.
        InvalidHostStateError is logged at info level, any other exception is logged with
        its traceback; neither is propagated from here.

        :return: True when the host finished successfully, False otherwise
        """
        host_key = host.uuid if scenario.uses_uuid_keys else host.inv
        scenario.set_host_info_status(host_key, HostScenarioStatus.PROCESSING)

        is_host_processed = False

        try:
            hsi_document = all_hsis[host.uuid]
            host_stage_info = hsi_document.stage_info

            with hss_error_handler(host_stage_info, scenario, host):
                marker = hms.run(host_stage_info, scenario, host, scenario_stage_info)

                if marker.status == MarkerStatus.SUCCESS:
                    host_stage_info.set_stage_finished()

                    scenario.set_host_info_status(host_key, HostScenarioStatus.DONE)

                    is_host_processed = True
        except InvalidHostStateError as e:
            log.info("Got host in wrong state/status, user will fix it: %s", e)
        except Exception as e:
            log.exception(
                "Got exception during scenario #%s processing for host #%s: %s", scenario.scenario_id, host.inv, e
            )

        return is_host_processed

    def _get_current_hosts_group(self, scenario):
        """Return the Host queryset of the scenario's current group."""
        return self.group_scheduler.get_current_hosts_group(
            scenario.hosts, scenario.current_group, scenario.uses_uuid_keys
        )

    @staticmethod
    def _update_host_stage_info(scenario, hsi, host_stage_info):
        """Store ``host_stage_info`` for a single host, matching on and bumping the document revision."""
        revision = hsi.revision
        query = dict(revision=revision, scenario_id=scenario.scenario_id, host_uuid=hsi.host_uuid)
        HostStageInfo.objects(**query).modify(set__stage_info=host_stage_info, set__revision=revision + 1)

    @staticmethod
    def _bulk_update_host_stage_infos(scenario, all_hsis):
        """Write the stage infos of all given documents back in one unordered bulk operation.

        Every update matches on the revision that was read and increments it.
        Nothing is sent when there are no documents to update.
        """
        bulk_operations = [
            UpdateOne(
                {"_id": hsi.host_uuid, "revision": hsi.revision, "scenario_id": scenario.scenario_id},
                {"$set": {"stage_info": hsi.stage_info.to_mongo().to_dict()}, "$inc": {"revision": 1}},
            )
            for hsi in all_hsis.values()
        ]

        if not bulk_operations:
            # pymongo rejects an empty request list with InvalidOperation, which would
            # abort the status update and error reporting that follow this call.
            return

        HostStageInfo._get_collection().bulk_write(bulk_operations, ordered=False)

    @staticmethod
    def _set_statuses_for_host_stages(stage_info, hosts, all_hosts_processed=False):
        """Mark child stages of ``stage_info`` as finished where appropriate.

        With ``all_hosts_processed`` every child is finished. Otherwise a child that is not
        finished yet is finished only when its ``hosts`` mapping covers all given hosts and
        every entry in that mapping has the FINISHED status.
        """
        hosts_uuids = {host.uuid for host in hosts}
        for child_stage_info in stage_info.stages:
            if all_hosts_processed:
                child_stage_info.set_stage_finished()
                continue

            if child_stage_info.status == StageStatus.FINISHED:
                continue

            if set(child_stage_info.hosts.keys()).issuperset(hosts_uuids):
                if all(
                    [host_info["status"] == HostStageStatus.FINISHED for host_info in child_stage_info.hosts.values()]
                ):
                    child_stage_info.set_stage_finished()

    def greedy_action_marker(self):
        """Success marker in greedy mode, in-progress marker otherwise."""
        return Marker.success() if self.greedy else Marker.in_progress()

    def __call__(self, stage_info, scenario):
        return self.process_current_group(stage_info, scenario)


def _scenario_label_setter(label_name, label_value, allowed_values=None):
    """Create a callable that overwrites one scenario label.

    The returned callable takes a scenario and sets ``scenario.labels[label_name]`` to
    ``label_value``, but only when the label is already present and its current value is
    one of ``allowed_values``.

    :param label_name: label to overwrite
    :param label_value: value to store
    :param allowed_values: current values for which the overwrite is permitted
    :return: ``callable(scenario)`` performing the conditional overwrite
    """
    # NOTE(review): with the default allowed_values=None the returned setter never
    # changes anything - confirm that this is the intended default.

    def _set_scenario_label(scenario, _label_name=label_name, _label_value=label_value, _allowed_values=allowed_values):
        if not _allowed_values:
            return
        if _label_name not in scenario.labels:
            return
        if scenario.labels[_label_name] not in _allowed_values:
            return

        log.info("Setting '%s'='%s' for scenario #%s.", _label_name, _label_value, scenario.scenario_id)
        scenario.labels[_label_name] = _label_value

    return _set_scenario_label
