import logging
from collections.abc import Callable

from walle.locks import ScenarioInterruptableLock
from walle.scenario.constants import ScenarioModifyAction
from walle.scenario.constants import StageName
from walle.scenario.mixins import BaseStage
from walle.scenario.scenario import Scenario
from walle.scenario.stage.scheduler_stage import HostSchedulerStage
from walle.scenario.stage_info import StageInfo, StageStatus
from walle.scenario.utils import BaseRegistry

log = logging.getLogger(__name__)


class ScenarioModifyActionRegistry(BaseRegistry):
    """Registry of scenario-modification actions.

    Actions are attached with ``ScenarioModifyActionRegistry.register(<ScenarioModifyAction>)``
    used as a decorator (e.g. ``skip_wait_stage_for_hosts_transfer`` in this module).
    """

    # NOTE(review): presumably the storage that BaseRegistry.register fills — confirm in BaseRegistry.
    ITEMS = {}


def modify_stage_info_by_uid(scenario: Scenario, uid: str, modify_function: Callable) -> str:
    """Apply ``modify_function`` to the stage info addressed by ``uid`` and persist the result.

    ``uid`` is a dot-separated path of non-negative indices, e.g. ``"0.2.1"``. The first
    index must be ``0`` (the root stage); each following index selects the next child
    below the previous stage. For a ``HostSchedulerStage`` the child stage is always its
    ``host_root_stage``, and the index is applied to every child stage-info list returned by
    ``get_all_children_stage_infos`` (NOTE(review): presumably one list per host — confirm).

    The traversal and the final ``scenario.save()`` run under ``ScenarioInterruptableLock``
    for the scenario. Child stage infos that belong to a separate container are saved as
    soon as they are modified, so a failure part-way through may leave some containers
    updated while the scenario itself is not saved.

    :param scenario: scenario whose stage tree is modified.
    :param uid: dot-separated index path of the target stage.
    :param modify_function: called with the target ``StageInfo``; its return value is
        ignored, so it must mutate the stage info in place.
    :return: a human-readable status message. Errors are reported through this message
        rather than raised.
    """

    def _modify(stage: BaseStage, stage_info: StageInfo, uid_signs: list[int]) -> StageInfo:
        # Descend one level per index. The remaining path is a fresh slice, so every
        # container's subtree is walked with the same remaining path.
        if not uid_signs:
            modify_function(stage_info)
            return stage_info
        child_id, remaining_signs = uid_signs[0], uid_signs[1:]
        if isinstance(stage, HostSchedulerStage):
            child_stage = stage.host_root_stage
        else:
            child_stage = stage.children[child_id]
        for stage_info_list, container in stage.get_all_children_stage_infos(stage_info, scenario):
            child_stage_info = stage_info_list[child_id]
            _modify(child_stage, child_stage_info, remaining_signs)
            stage_info_list[child_id] = child_stage_info
            if container:
                # Stage infos held by a separate container are persisted here, not by scenario.save().
                container.stage_info.stages = stage_info_list
                container.save()
        return stage_info

    # Validate the uid up front. Malformed input used to escape as an uncaught ValueError,
    # and negative indices silently addressed stages counted from the end.
    try:
        uid_digits = [int(digit) for digit in uid.split(".")]
    except ValueError:
        return "Invalid uid: {}".format(uid)
    if any(digit < 0 for digit in uid_digits):
        return "Invalid uid: {}".format(uid)
    if uid_digits[0] != 0:
        return "Root level has only stage with id 0"

    with ScenarioInterruptableLock(scenario.scenario_id):
        root_stage_info = scenario.stage_info
        root_stage = root_stage_info.deserialize()

        try:
            scenario.stage_info = _modify(root_stage, root_stage_info, uid_digits[1:])
            scenario.save()
            log.info("Finish modifiyng uid: %s", uid)
            return "Scenario was modified successfully"
        except IndexError:
            # An index below the root is out of range for the actual stage tree; this is
            # distinct from the root-level check above.
            return "Stage with uid {} does not exist".format(uid)
        except Exception as e:
            log.exception("Failed to modify stage with uid %s", uid)
            return "Unexpected error: {}".format(e)


@ScenarioModifyActionRegistry.register(ScenarioModifyAction.SKIP_WAIT_STAGE_FOR_HOSTS_TRANSFER)
def skip_wait_stage_for_hosts_transfer(scenario: Scenario):
    """Mark every ``WaitForLabelOrTimeStage`` stage in ``scenario`` as ``FINISHED``.

    The target uids are collected once from the passed-in scenario. Before each
    modification the scenario is re-fetched via ``Scenario.objects.get``, so each update
    starts from the state saved by the previous one. The status message returned for each
    uid is logged, because ``modify_stage_info_by_uid`` reports failures only through
    that message.
    """

    def modify_function(stage_info: StageInfo):
        stage_info.status = StageStatus.FINISHED

    wait_stages_uids = _get_uuids_of_all_stages_with_given_name(scenario.stage_info, StageName.WaitForLabelOrTimeStage)
    # Log the collected uids once, not once per loop iteration.
    if wait_stages_uids:
        log.info("Get uuids: %s", ", ".join(wait_stages_uids))
    for uid in wait_stages_uids:
        # Re-fetch so this modification sees what the previous iteration saved.
        scenario = Scenario.objects.get(scenario_id=scenario.scenario_id)
        result = modify_stage_info_by_uid(scenario, uid, modify_function)
        log.info("Result of modifying uid %s: %s", uid, result)


def _get_uuids_of_all_stages_with_given_name(stage_info: StageInfo, stage_name: str) -> list[str]:
    """Return the uids of all stages named ``stage_name`` in the tree rooted at ``stage_info``.

    The root itself is included. The tree is walked depth-first with an explicit stack:
    children are pushed in order and popped from the end, so siblings are visited
    last-to-first and the result follows that order.
    """
    found_uids = []
    pending = [stage_info]
    while pending:
        current = pending.pop()
        if current.name == stage_name:
            found_uids.append(current.uid)
        pending += current.stages
    return found_uids
