# -*- coding: utf-8 -*-

import dataclasses
import logging
from typing import Dict, Iterable, List, Set

from travel.hotels.content_manager.data_model.storage import (
    DispatchableEntity, StageStatus
)
from travel.hotels.content_manager.data_model.types import StageResult
from travel.hotels.content_manager.lib.common import str_from_set, str_to_set


# Module-level logger, named after this module.
LOG = logging.getLogger(name=__name__)


class CycleStageRequirementsException(Exception):
    """Raised when a stage's requirement chain revisits a stage already on the current path."""


class Dispatcher:
    """Route dispatchable entities to the stages they still have to pass.

    All bookkeeping is read from (and written to) the entity itself:

    * ``required_stages`` / ``finished_stages`` -- serialized stage sets,
      decoded with ``str_to_set`` and encoded with ``str_from_set``;
    * ``<stage>_required_stages`` -- serialized set of stages that must be
      finished before ``<stage>`` can be sent;
    * dataclass fields annotated with ``StageStatus`` (named ``status_<stage>``)
      -- any value other than ``StageStatus.NOTHING_TO_DO`` marks the stage
      as running;
    * dataclass fields annotated with ``StageResult`` (named ``<stage>_result``)
      -- ``StageResult.FAILED`` marks the stage as failed.
    """

    def dispatch_entities(
        self,
        entities: Iterable[DispatchableEntity],
        final_stage: str,
    ) -> Dict[str, List[DispatchableEntity]]:
        """Group entities by the stage each of them should be sent to next.

        An entity goes to ``final_stage`` when one of its failed stages is
        also finished, or when all of its required stages are finished.
        Otherwise it is sent to every stage that is ready to run on the way to
        its unfinished required stages (see ``get_entity_stages_to_send``),
        except the stages already running for it.

        :param entities: entities to dispatch.
        :param final_stage: stage that receives entities with nothing left to do.
        :return: mapping from stage name to the entities to send there; stages
            that receive no entities are absent from the mapping.
        :raises Exception: if an entity's per-stage requirements form a cycle.
        """
        LOG.info('Start dispatching')
        entity_by_stage: Dict[str, List[DispatchableEntity]] = {}

        for entity in entities:
            LOG.info('Dispatching %s', entity)
            required_stages = str_to_set(entity.required_stages)
            finished_stages = str_to_set(entity.finished_stages)
            failed_stages = self.get_entity_failed_stages(entity)
            stages_to_wait = required_stages - finished_stages

            # A finished stage that failed ends processing just like having no
            # outstanding required stages: either way the entity is done.
            if finished_stages & failed_stages or not stages_to_wait:
                LOG.info('No more stages to wait')
                entity_by_stage.setdefault(final_stage, []).append(entity)
                continue

            stages_to_send: Set[str] = set()
            for stage in stages_to_wait:
                try:
                    # Seed the path with the root stage so a requirement chain
                    # that loops back to it is reported at the first revisit.
                    stages_to_send |= self.get_entity_stages_to_send(
                        entity, stage, {stage}, finished_stages,
                    )
                except CycleStageRequirementsException as exc:
                    raise Exception(f'Cycle stage requirements for {entity}') from exc

            running_stages = self.get_entity_running_stages(entity)
            for stage in stages_to_send - running_stages:
                entity_by_stage.setdefault(stage, []).append(entity)

        return entity_by_stage

    @staticmethod
    def get_entity_running_stages(entity: DispatchableEntity) -> Set[str]:
        """Return stages whose ``StageStatus`` field is not ``NOTHING_TO_DO``.

        Stage names are derived by stripping the ``status_`` prefix from the
        field name.
        """
        return {
            field.name[len('status_'):]
            for field in dataclasses.fields(entity)
            if field.type is StageStatus
            and getattr(entity, field.name) != StageStatus.NOTHING_TO_DO
        }

    @staticmethod
    def get_entity_failed_stages(entity: DispatchableEntity) -> Set[str]:
        """Return stages whose ``StageResult`` field equals ``StageResult.FAILED``.

        Stage names are derived by stripping the ``_result`` suffix from the
        field name.
        """
        return {
            field.name[:-len('_result')]
            for field in dataclasses.fields(entity)
            if field.type is StageResult
            and getattr(entity, field.name) == StageResult.FAILED
        }

    def get_entity_stages_to_send(
        self,
        entity: DispatchableEntity,
        stage: str,
        path_required_stages: Set[str],
        finished_stages: Set[str],
    ) -> Set[str]:
        """Return the stages that can be sent now on the way to ``stage``.

        If every prerequisite of ``stage`` is finished, ``stage`` itself is
        returned. Otherwise each unfinished prerequisite is resolved
        recursively (depth first) and the results are merged.

        :param entity: entity whose ``<stage>_required_stages`` fields are read.
        :param stage: stage to resolve.
        :param path_required_stages: stages already on the current resolution
            path, used for cycle detection; not modified.
        :param finished_stages: stages already finished for the entity.
        :return: names of the stages that are ready to be sent.
        :raises CycleStageRequirementsException: if an unfinished prerequisite
            is already on the current resolution path.
        """
        still_required_stages = self.get_stage_required_stages(entity, stage) - finished_stages
        if not still_required_stages:
            return {stage}

        stages_to_send: Set[str] = set()
        for current_stage in still_required_stages:
            if current_stage in path_required_stages:
                raise CycleStageRequirementsException(
                    f'Stage {current_stage} is required again on its own requirement path'
                )
            stages_to_send |= self.get_entity_stages_to_send(
                entity,
                current_stage,
                path_required_stages | {current_stage},
                finished_stages,
            )

        return stages_to_send

    @staticmethod
    def get_stage_required_stages(entity: DispatchableEntity, stage: str) -> Set[str]:
        """Decode the entity's ``<stage>_required_stages`` field into a set."""
        return str_to_set(getattr(entity, f'{stage}_required_stages'))

    @staticmethod
    def add_entity_required_stages(entity: DispatchableEntity, stages: Iterable[str]) -> None:
        """Add ``stages`` to the entity's ``required_stages`` (in place)."""
        required_stages = str_to_set(entity.required_stages)
        required_stages.update(stages)
        entity.required_stages = str_from_set(required_stages)

    @staticmethod
    def add_stage_required_stages(entity: DispatchableEntity, stage: str, stages: Iterable[str]) -> None:
        """Add ``stages`` to the entity's ``<stage>_required_stages`` (in place)."""
        stage_field = f'{stage}_required_stages'
        required_stages = str_to_set(getattr(entity, stage_field))
        required_stages.update(stages)
        setattr(entity, stage_field, str_from_set(required_stages))

    @staticmethod
    def is_entity_processed(entity: DispatchableEntity) -> bool:
        """Return True if the entity has any failed stage or no unfinished required stage.

        NOTE(review): any failed stage counts as terminal here, while
        ``dispatch_entities`` and ``is_processing_successful`` only consider
        failed stages that are also finished -- confirm which is intended.
        """
        required_stages = str_to_set(entity.required_stages)
        finished_stages = str_to_set(entity.finished_stages)
        if Dispatcher.get_entity_failed_stages(entity):
            return True
        return not (required_stages - finished_stages)

    @staticmethod
    def is_processing_successful(entity: DispatchableEntity) -> bool:
        """Return True if no finished stage of the entity is marked as failed."""
        finished_stages = str_to_set(entity.finished_stages)
        failed_stages = Dispatcher.get_entity_failed_stages(entity)
        return finished_stages.isdisjoint(failed_stages)
