import logging
from abc import abstractmethod

from sepelib.core.exceptions import LogicalError
from walle.hosts import Host
from walle.maintenance_plot.exceptions import MaintenancePlotScenarioSettingsTypeDoesntExists
from walle.models import timestamp
from walle.projects import DEFAULT_CMS_NAME
from walle.scenario.constants import DEFAULT_STAGE_DESCRIPTION
from walle.scenario.host_stage_info import HostStageInfo
from walle.scenario.iteration_strategy import SequentialIterationStageStrategy, HostSequentialIterationStageStrategy
from walle.scenario.marker import Marker
from walle.scenario.scenario import Scenario
from walle.scenario.stage_info import StageInfo, StageStatus

# Module-level logger used by the stage handlers and stage classes in this module.
log = logging.getLogger(__name__)


class CommonParentStageHandler:
    """Mixin for stages that own an ordered list of child stages.

    Meant to be combined with ``BaseStage`` (directly or via a subclass):
    ``serialize`` and ``__repr__`` read ``self.name``, and ``__eq__`` delegates to
    the next ``__eq__`` in the MRO; this class defines neither of those itself.
    """

    # Child stage objects. May be None/empty when the stage is configured through
    # ``stage_map`` or ``stage_generator`` params instead of explicit children.
    children = None

    def __init__(self, children, **params):
        """Store the children and the stage params.

        :param children: child stage objects, executed in order.
        :param params: stage params (``children`` itself is not included).
        :raises LogicalError: if neither children nor a ``stage_map`` /
            ``stage_generator`` param is provided.
        """
        self.params = params
        if not (children or params.get('stage_map') or params.get('stage_generator')):
            raise LogicalError

        self.children = children
        self.iteration_strategy = SequentialIterationStageStrategy()

    def serialize(self, uid="0"):
        """Build a StageInfo for this stage whose ``stages`` are the serialized children.

        Children receive uids of the form ``"<uid>.<index>"``. When there are no
        explicit children (possible with ``stage_map``/``stage_generator``), the
        nested stage list is empty instead of failing on ``enumerate(None)``.
        """
        children = self.children or []
        return StageInfo(
            uid=uid,
            seq_num=0,
            name=self.name,
            params=self.params,
            status_time=timestamp(),
            description=get_description(self),
            stages=[child.serialize(uid="{}.{}".format(uid, idx)) for idx, child in enumerate(children)],
        )

    def __eq__(self, other):
        """Equal when the next ``__eq__`` in the MRO agrees and the children match."""
        base_equal = super().__eq__(other)
        if base_equal is NotImplemented:
            # Happens when no cooperating base (e.g. BaseStage) precedes ``object`` in the
            # MRO. Using NotImplemented as a bool is deprecated since Python 3.9 and a
            # TypeError since 3.14, so hand the comparison back to Python instead.
            return NotImplemented
        return base_equal and self.children == other.children

    def __repr__(self):
        format_str = "<{cls}(params={params!r}, children={children!r})>"
        if type(self).__name__ != self.name:
            # Only show the stage name when it differs from the class name.
            format_str = "<{cls}(name={name!r}, params={params!r}, children={children!r})>"

        return format_str.format(cls=type(self).__name__, name=self.name, params=self.params, children=self.children)

    def generate_stages(self, stage_descs, scenario_settings):
        """Instantiate every stage description whose conditions hold for ``scenario_settings``.

        Descriptions whose conditions do not match are skipped.
        """
        return [
            self.make_stage_from_desc(stage_desc, scenario_settings)
            for stage_desc in stage_descs
            if self.stage_desc_is_conditioned(stage_desc, scenario_settings)
        ]

    def make_stage_from_desc(self, stage_desc, scenario_settings):
        """Instantiate ``stage_desc.stage_class`` from its params, building nested children recursively.

        The description's params are copied, so ``stage_desc`` is never mutated.
        Returns None when the description's conditions do not hold.
        """
        if not self.stage_desc_is_conditioned(stage_desc, scenario_settings):
            return None

        stage_params = stage_desc.params or {}
        stage_params = stage_params.copy()

        if stage_params.get('children'):
            # Nested children are descriptions too; filter and instantiate them the same way.
            stage_params['children'] = self.generate_stages(stage_params['children'], scenario_settings)
        return stage_desc.stage_class(**stage_params)

    def stage_desc_is_conditioned(self, stage_desc, scenario_settings):
        """Return True if ``stage_desc`` should be included for ``scenario_settings``.

        A description without conditions always matches. Otherwise it matches when
        any single-key mapping in ``conditions['if_any']`` equals the corresponding
        scenario setting; conditions lacking an ``'if_any'`` list never match.
        ``scenario_settings`` is indexed directly, so a condition naming an absent
        setting raises (KeyError for a dict); a condition mapping with more or fewer
        than one key raises ValueError on unpacking.
        """
        if not stage_desc.conditions:
            return True
        for if_any_condition in stage_desc.conditions.get('if_any', []):
            ((condition_key, condition_val),) = if_any_condition.items()
            if scenario_settings[condition_key] == condition_val:
                return True
        return False


class ParentStageHandler(CommonParentStageHandler):
    """Parent-stage mixin that drives its children sequentially at the scenario level."""

    def execute_current_stage(self, stage_info, scenario, *args, **kwargs):
        """Run the active child (unless already FINISHED) and let the iteration strategy process the result.

        A child whose stage info is already FINISHED is not run; a success marker
        saying it was completed manually is used instead. The marker's message is
        stored on the child's stage info before the strategy sees the marker.
        """
        child, child_info = self.iteration_strategy.get_active_stage(stage_info, self.children)
        log.info("Executing %s with args=%s; kwargs=%s", child_info, args, kwargs)

        finished_already = child_info.status == StageStatus.FINISHED
        marker = (
            Marker.success(message="The stage was completed manually")
            if finished_already
            else child.run(child_info, scenario, *args, **kwargs)
        )

        log.info("Got result %s for stage %s (args=%s; kwargs=%s)", marker, child_info, args, kwargs)
        child_info.set_stage_msg(marker.message)
        return self.iteration_strategy.process_marker(marker, stage_info, child_info)

    @staticmethod
    def get_all_children_stage_infos(stage_info: StageInfo, scenario: Scenario) -> list[tuple[list[StageInfo], None]]:
        """Return the own child stage infos paired with None (no host stage info)."""
        own_children = stage_info.stages
        return [(own_children, None)]


class HostParentStageHandler(CommonParentStageHandler):
    """Parent-stage mixin that drives its children sequentially for a single host."""

    def __init__(self, children, **params):
        super().__init__(children, **params)
        # Swap the scenario-level strategy installed by the base class for the per-host one.
        self.iteration_strategy = HostSequentialIterationStageStrategy()

    def execute_current_stage(self, stage_info, scenario, host, scenario_stage_info):
        """Run the active child for ``host`` (unless already FINISHED) and process the resulting marker.

        The marker's message and status are pushed to the scenario via
        ``set_host_stage_stage_info_update`` before the strategy processes it.
        """
        active = self.iteration_strategy.get_active_stage(stage_info, self.children, scenario_stage_info)
        child, host_child_info, scenario_child_info = active
        log.info("Executing %s", host_child_info)

        if host_child_info.status == StageStatus.FINISHED:
            marker = Marker.success(message="The stage was completed manually")
        else:
            marker = child.run(host_child_info, scenario, host, scenario_child_info)

        log.info("Got result %s for stage %s", marker, host_child_info)
        scenario.set_host_stage_stage_info_update(scenario_child_info.uid, host.uuid, marker.message, marker.status)
        return self.iteration_strategy.process_marker(marker, stage_info, host_child_info)

    @staticmethod
    def get_all_children_stage_infos(
        stage_info: StageInfo, scenario: Scenario
    ) -> list[tuple[list[StageInfo], HostStageInfo]]:
        """Return the own child stage infos plus those of every HostStageInfo of the scenario.

        The first pair carries None instead of a HostStageInfo.
        """
        collected = [(stage_info.stages, None)]
        collected.extend(
            (host_stage_info.stage_info.stages, host_stage_info)
            for host_stage_info in HostStageInfo.objects(scenario_id=scenario.scenario_id)
        )
        return collected


class BaseStage:
    """Base class for scenario stages: keeps the stage params and serializes them into a StageInfo."""

    # Param names a stage accepts; subclasses override it. This class does not validate against it.
    allowed_stage_params = []

    def __init__(self, **stage_params):
        self.params = stage_params

    @property
    def name(self) -> str:
        """Stage name stored in StageInfo and shown in repr; defaults to the class name."""
        # TODO(rocco66): replace to register key?
        return self.__class__.__name__

    def serialize(self, uid: str) -> StageInfo:
        """Build a StageInfo with this stage's name, params, a fresh ``timestamp()`` and docstring description."""
        return StageInfo(
            uid=uid, name=self.name, params=self.params, status_time=timestamp(), description=get_description(self)
        )

    @staticmethod
    def get_all_children_stage_infos(stage_info: StageInfo, scenario: Scenario) -> list[tuple[list[StageInfo], None]]:
        """Return ``stage_info.stages`` paired with None (no host stage info) in a one-element list."""
        return [(stage_info.stages, None)]

    def __eq__(self, other: object) -> bool:
        """Stages are equal when they have exactly the same class, name and params.

        ``other`` is another stage object (not a StageInfo). Defining ``__eq__``
        without ``__hash__`` makes stage instances unhashable.
        """
        return type(self) is type(other) and self.name == other.name and self.params == other.params

    def __repr__(self) -> str:
        format_str = "<{cls}(params={params!r})>"
        if type(self).__name__ != self.name:
            # Only show the stage name when it differs from the class name.
            format_str = "<{cls}(name={name!r}, params={params!r})>"

        return format_str.format(cls=type(self).__name__, name=self.name, params=self.params)


class StageRunInterface:
    """Run/cleanup interface for scenario-level stages.

    NOTE(review): the class does not use ABCMeta, so ``@abstractmethod`` is not
    enforced at instantiation; it only documents that subclasses override ``run``.
    """

    @abstractmethod
    def run(self, stage_info, scenario):
        """Execute the stage for ``scenario``; implementations return a Marker."""
        pass

    def cleanup(self, stage_info, scenario):
        """Default cleanup: nothing to undo, report success."""
        # The marker used to be built and discarded (returning None); return it to the caller.
        return Marker.success(message="No cleanup action in StageRunInterface")


class HostStageRunInterface:
    """Run/cleanup interface for per-host stages.

    NOTE(review): the class does not use ABCMeta, so ``@abstractmethod`` is not
    enforced at instantiation; it only documents that subclasses override ``run``.
    """

    @abstractmethod
    def run(self, stage_info, scenario, host, scenario_stage_info):
        """Execute the stage for ``host`` within ``scenario``; implementations return a Marker."""
        pass

    def cleanup(self, stage_info, scenario, host, scenario_stage_info):  # noqa
        """Default cleanup: nothing to undo, report success."""
        return Marker.success(message="No cleanup action in HostStageRunInterface")


class HostGroupStageRunInterface:
    """Run/cleanup interface for stages executed per host group.

    NOTE(review): the class does not use ABCMeta, so ``@abstractmethod`` is not
    enforced at instantiation; it only documents that subclasses override ``run``.
    """

    @abstractmethod
    def run(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int):
        """Execute the stage for ``host_group_id`` within ``scenario``; implementations return a Marker."""
        pass

    def cleanup(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int):  # noqa
        """Default cleanup: nothing to undo, report success."""
        return Marker.success(message="No cleanup action in HostGroupStageRunInterface")


class Stage(BaseStage, StageRunInterface):
    """Scenario-level stage whose log helpers prefix messages with the scenario id and stage class."""

    LOG_TEMPLATE = "Scenario %s, stage '%s': %s"

    def _log_args(self, scenario, msg):
        """Return ``(template, *values)`` ready to be splatted into a ``log.*`` call."""
        stage_cls_name = type(self).__name__
        return (self.LOG_TEMPLATE, scenario.scenario_id, stage_cls_name, msg)

    def log_info(self, scenario, msg):
        template, *values = self._log_args(scenario, msg)
        log.info(template, *values)

    def log_error(self, scenario, msg):
        template, *values = self._log_args(scenario, msg)
        log.error(template, *values)


class HostStage(BaseStage, HostStageRunInterface):
    """Per-host stage with host-aware log helpers and a CMS-bypass decision helper."""

    LOG_TEMPLATE = "Scenario %s, host %s, stage '%s': %s"

    def log_info(self, scenario, host, msg):
        log.info(*self._log_args(scenario, host, msg))

    def log_error(self, scenario, host, msg):
        log.error(*self._log_args(scenario, host, msg))

    def _log_args(self, scenario, host, msg):
        """Return ``(template, *values)`` for a ``log.*`` call: scenario id, host uuid, stage class, message."""
        return self.LOG_TEMPLATE, scenario.scenario_id, host.uuid, self.__class__.__name__, msg

    @staticmethod
    def get_ignore_cms_value(scenario: Scenario, host: Host, stage_ignore_cms_param: bool = False) -> bool:
        """Decide whether host operations in this stage should bypass the CMS.

        :param scenario: running scenario; its ``scenario_type`` selects the maintenance plot settings.
        :param host: host being processed; its project supplies CMS settings and the maintenance plot.
        :param stage_ignore_cms_param: True when the stage itself demands ignoring the CMS.
        :returns: True if the stage demands it, or if the project's first CMS setting is the
            default CMS; otherwise the maintenance plot's ``ignore_cms_on_host_operations`` for the
            scenario type. Any failure while reading that setting yields False (go through the CMS).
        """
        if stage_ignore_cms_param:  # Stage clearly demands to ignore cms
            return True

        project = host.get_project()
        # Only the first CMS setting is consulted; the default CMS is always ignored.
        if project.cms_settings[0].cms == DEFAULT_CMS_NAME:
            return True

        try:  # try to get value from maintenance plot, in case of any error, return False and go to cms
            maintenance_plot_model = project.get_maintenance_plot()
            maintenance_plot = maintenance_plot_model.as_dataclass()
            settings = maintenance_plot.get_scenario_settings(scenario.scenario_type)
            return settings.ignore_cms_on_host_operations
        except MaintenancePlotScenarioSettingsTypeDoesntExists:
            # No settings for this scenario type in the plot: fall back to the CMS.
            return False
        except Exception as e:
            # Deliberate best-effort fallback; log.exception keeps the traceback for diagnosis.
            log.exception(
                "unknown exception during getting settings for scenario %s of type %s "
                "with maintenance plot id %s for host %s: %s",
                scenario.scenario_id,
                scenario.scenario_type,
                project.maintenance_plot_id,
                host.human_id(),
                e,
            )
            return False


class HostGroupStage(BaseStage, HostGroupStageRunInterface):
    """Stage executed per host group; log helpers prefix messages with scenario, group id and stage class."""

    # NOTE(review): "%d" needs a numeric host_group_id; e.g. None would make logging
    # report a formatting error instead of emitting the message.
    LOG_TEMPLATE = "Scenario %s, host_group_id %d, stage '%s': %s"

    def _log_args(self, scenario: Scenario, host_group_id: int, msg: str):
        """Return ``(template, *values)`` ready to be splatted into a ``log.*`` call."""
        stage_cls_name = type(self).__name__
        return (self.LOG_TEMPLATE, scenario.scenario_id, host_group_id, stage_cls_name, msg)

    def log_info(self, scenario: Scenario, host_group_id: int, msg: str):
        template, *values = self._log_args(scenario, host_group_id, msg)
        log.info(template, *values)

    def log_error(self, scenario: Scenario, host_group_id: int, msg: str):
        template, *values = self._log_args(scenario, host_group_id, msg)
        log.error(template, *values)


class HostStageWithSharedData(HostStage):
    """Host stage parameterised by ``shared_data_key``.

    NOTE(review): presumably a key into data shared between stages — confirm with the users of this class.
    """

    def __init__(self, shared_data_key, **params):
        # shared_data_key goes first into the stage params, so it is serialized and compared with them.
        stage_params = {"shared_data_key": shared_data_key, **params}
        super().__init__(**stage_params)
        self.shared_data_key = shared_data_key


class MultiActionStage(BaseStage):
    """Stage whose ``run`` dispatches to the action selected by an iteration strategy."""

    def __init__(self, iteration_strategy, **params):
        super().__init__(**params)
        # Kept out of self.params, so BaseStage neither serializes nor compares it.
        self.iteration_strategy = iteration_strategy

    def run(self, stage_info, scenario, *args, **kwargs):
        """Call the strategy's current action, then let the strategy turn its marker into a transition."""
        func = self.iteration_strategy.get_current_func(stage_info)
        marker = func(stage_info, scenario, *args, **kwargs)

        return self.iteration_strategy.make_transition(marker, stage_info, scenario)

    def cleanup(self, stage_info, scenario, *args, **kwargs):
        """Default cleanup: nothing to undo, report success."""
        # The marker used to be built and discarded (returning None); return it to the caller.
        return Marker.success(message="No cleanup action in MultiActionStage")


def get_description(stage):
    """Return a single-line description of ``stage`` built from its docstring.

    Each docstring line is stripped of surrounding whitespace, blank lines are
    dropped and the rest are joined with single spaces. Returns
    DEFAULT_STAGE_DESCRIPTION when the docstring is missing, empty, or whitespace-only.
    """
    docstring = stage.__doc__
    if not docstring:
        return DEFAULT_STAGE_DESCRIPTION

    # Strip all whitespace (not only spaces/newlines) so tab-indented or CRLF
    # docstrings do not leak '\t' / '\r' into the description.
    stripped_lines = (line.strip() for line in docstring.split("\n"))
    description = " ".join(line for line in stripped_lines if line)
    return description or DEFAULT_STAGE_DESCRIPTION
