import typing
from datetime import datetime

from sepelib.core import constants
from walle.hosts import Host
from walle.maintenance_plot.model import MaintenancePlot
from walle.maintenance_plot.scenarios_settings.base import BaseScenarioMaintenancePlotSettings
from walle.models import timestamp
from walle.scenario import common
from walle.scenario.common import get_host_group_maintenance_plot
from walle.scenario.constants import StageName
from walle.scenario.data_storage.base import BaseScenarioDataStorage
from walle.scenario.definitions.base import get_data_storage
from walle.scenario.marker import Marker, MarkerStatus
from walle.scenario.mixins import HostGroupStage, HostStage
from walle.scenario.scenario import Scenario
from walle.scenario.stage_info import StageInfo, StageRegistry

# Type annotations of the two model fields below, reused in the return-type hints of
# CommonWaitBeforeRequestingCms. ``dict.get`` yields None when the class does not annotate the field.
MAINTENANCE_PLOT_ID_TYPE = typing.get_type_hints(MaintenancePlot).get("id")
REQUEST_CMS_X_SECONDS_BEFORE_MAINTENANCE_START_TIME_TYPE = typing.get_type_hints(
    BaseScenarioMaintenancePlotSettings
).get("request_cms_x_seconds_before_maintenance_start_time")


class CommonWaitBeforeRequestingCms:
    """Wait until a configured offset before the scenario's maintenance start time is reached.

    The offset is read from the scenario settings of the host group's maintenance plot; the name of the
    settings attribute holding it is supplied by the caller. The value read from the maintenance plot is
    cached in the stage data for ``_cache_ttl`` seconds.
    """

    # Cache values read from the maintenance plot for 15 minutes to avoid unnecessary load on the database.
    _cache_field_name = "cache_maintenance_plot_data"
    _cache_ttl = constants.MINUTE_SECONDS * 15

    def run(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int, time_field_name: str) -> MarkerStatus:
        """Return in-progress until ``maintenance_start_time - offset`` is reached, success afterwards.

        Succeeds immediately when the maintenance start time is not set in the scenario parameters or when the
        offset is not configured (None) in the maintenance plot.

        :param stage_info: stage info used to cache the data read from the maintenance plot.
        :param scenario: scenario whose parameters hold the maintenance start time.
        :param host_group_id: host group whose maintenance plot provides the offset.
        :param time_field_name: name of the scenario settings attribute holding the offset that is subtracted
            from the maintenance start timestamp.
        :raises RuntimeError: if the scenario settings have no attribute ``time_field_name`` (only checked when
            the cached value is missing or expired).
        """
        now = timestamp()
        data_storage = get_data_storage(scenario)
        scenario_parameters = data_storage.read_scenario_parameters()

        if scenario_parameters.maintenance_start_time is None:
            return Marker.success(message="Maintenance start time is not set in scenario parameters.")

        maintenance_plot_id, required_offset = self._get_required_offset(
            stage_info, scenario, host_group_id, data_storage, now, time_field_name
        )

        if required_offset is None:
            message = "Waiting before requesting CMS is not configured in maintenance plot '%s'." % maintenance_plot_id
            return Marker.success(message=message)

        wait_until_ts = scenario_parameters.maintenance_start_time - required_offset
        wait_until_str = datetime.fromtimestamp(wait_until_ts).strftime("%d.%m.%Y %H:%M")

        if now < wait_until_ts:
            return Marker.in_progress(message="Waiting until '%s'." % wait_until_str)

        return Marker.success(message="Had waited until '%s'." % wait_until_str)

    def _get_required_offset(
        self,
        stage_info: StageInfo,
        scenario: Scenario,
        host_group_id: int,
        data_storage: BaseScenarioDataStorage,
        now: int,
        time_field_name: str,
    ) -> typing.Tuple[MAINTENANCE_PLOT_ID_TYPE, REQUEST_CMS_X_SECONDS_BEFORE_MAINTENANCE_START_TIME_TYPE]:
        """Return ``(maintenance_plot_id, offset)``, served from the stage data cache while it is fresh.

        On a cache miss (or when ``valid_until`` is earlier than ``now``) the values are re-read from the
        maintenance plot and stored back with a new expiration time of ``now + _cache_ttl``.
        """
        cache = stage_info.get_data(self._cache_field_name)

        if cache is None or cache["valid_until"] < now:
            maintenance_plot_id, required_offset = self._read_required_offset_from_maintenance_plot(
                scenario, host_group_id, data_storage, time_field_name
            )
            stage_info.set_data(
                self._cache_field_name,
                {
                    "maintenance_plot_id": maintenance_plot_id,
                    "required_offset": required_offset,
                    "valid_until": now + self._cache_ttl,
                },
            )
            return maintenance_plot_id, required_offset

        return cache["maintenance_plot_id"], cache["required_offset"]

    def _read_required_offset_from_maintenance_plot(
        self, scenario: Scenario, host_group_id: int, data_storage: BaseScenarioDataStorage, time_field_name: str
    ) -> typing.Tuple[MAINTENANCE_PLOT_ID_TYPE, REQUEST_CMS_X_SECONDS_BEFORE_MAINTENANCE_START_TIME_TYPE]:
        """Load the host group's maintenance plot and return its id with the offset stored in ``time_field_name``."""
        maintenance_plot = get_host_group_maintenance_plot(host_group_id, data_storage)
        scenario_settings = maintenance_plot.get_scenario_settings(scenario.scenario_type)
        return maintenance_plot.id, self._get_wait_time_value(scenario, scenario_settings, time_field_name)

    def _get_wait_time_value(
        self,
        scenario: Scenario,
        scenario_settings: typing.Type[BaseScenarioMaintenancePlotSettings],
        time_field_name: str,
    ):
        """Return attribute ``time_field_name`` of the scenario settings.

        :raises RuntimeError: if the settings do not have such an attribute; the original
            ``AttributeError`` is chained as the cause.
        """
        try:
            return getattr(scenario_settings, time_field_name)
        except AttributeError as e:
            msg = "Settings of scenario with type '{}' does not have option '{}': '{}'"
            raise RuntimeError(msg.format(scenario.scenario_type, time_field_name, str(e))) from e


class WaitUsingYpSla:
    """Wait until the CMS request time derived from the maintenance start time using YP SLA.

    The target timestamp is computed on the first run that finds it missing and is stored in the stage data,
    so subsequent runs reuse the stored value.
    """

    _cache_field_name = "start_timestamp"

    def run(self, stage_info: StageInfo, scenario: Scenario) -> MarkerStatus:
        """Return in-progress until the stored start timestamp is reached, success afterwards.

        Succeeds immediately when the scenario parameters have no maintenance start time.
        """
        params = get_data_storage(scenario).read_scenario_parameters()
        maintenance_start = params.maintenance_start_time
        if maintenance_start is None:
            return Marker.success(message="Maintenance start time is not set in scenario parameters, starting now.")

        target_ts = stage_info.get_data(self._cache_field_name)
        if not target_ts:
            # Compute once and remember it in the stage data.
            target_ts = self._get_start_timestamp(maintenance_start)
            stage_info.set_data(self._cache_field_name, target_ts)

        target_str = datetime.fromtimestamp(target_ts).strftime("%d.%m.%Y %H:%M")
        if timestamp() >= target_ts:
            return Marker.success(message="Had waited until '%s'." % target_str)
        return Marker.in_progress(message="Waiting until '%s'." % target_str)

    def _get_start_timestamp(self, maintenance_start_time: int) -> int:
        """Compute the CMS request timestamp for the given maintenance start time via the YP SLA helper."""
        return common.get_request_to_maintenance_starting_time_using_yp_sla(maintenance_start_time)


@StageRegistry.register(StageName.HostGroupWaitBeforeRequestingCmsStage)
class HostGroupWaitBeforeRequestingCmsStage(HostGroupStage):
    """Waits until X hours left before maintenance start time.

    The waiting strategy is chosen from the ``use_yp_sla`` option of the host group's maintenance plot
    scenario settings.
    - When the option is falsy, or missing, the ``request_cms_x_seconds_before_maintenance_start_time``
      offset is used.
    - Otherwise the start time is computed using YP SLA.
    """

    # Stage data key caching the chosen strategy: True means "use the offset from the maintenance plot".
    _cache_field_name = "is_offset_used"

    def run(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int) -> MarkerStatus:
        """Delegate to the offset-based or the YP-SLA-based waiter and return its marker."""
        if self._get_workflow_type(stage_info, scenario, host_group_id):
            return CommonWaitBeforeRequestingCms().run(
                stage_info, scenario, host_group_id, "request_cms_x_seconds_before_maintenance_start_time"
            )
        return WaitUsingYpSla().run(stage_info, scenario)

    def _get_workflow_type(self, stage_info: StageInfo, scenario: Scenario, host_group_id: int) -> bool:
        """Return True if the offset-based workflow must be used, False to wait using YP SLA.

        The decision is cached in the stage data once ``use_yp_sla`` has been read successfully.
        """
        is_offset_used = stage_info.get_data(self._cache_field_name)
        if is_offset_used is not None:
            return is_offset_used

        data_storage = get_data_storage(scenario)
        maintenance_plot = get_host_group_maintenance_plot(host_group_id, data_storage)
        scenario_settings = maintenance_plot.get_scenario_settings(scenario.scenario_type)

        try:
            use_yp_sla = scenario_settings.use_yp_sla
        except AttributeError:
            # Temporary fallback for settings without the ``use_yp_sla`` option: use the offset workflow.
            # The fallback result is not cached, so the option is re-read on the next run.
            return True

        is_offset_used = not use_yp_sla
        stage_info.set_data(self._cache_field_name, is_offset_used)
        return is_offset_used


@StageRegistry.register(StageName.HostWaitBeforeRequestingCmsStage)
class HostWaitBeforeRequestingCmsStage(HostStage):
    """Per-host stage: waits until the configured offset before the maintenance start time is reached.

    The offset is looked up in the maintenance plot of the host's group under the settings attribute
    named by ``time_field_name``.
    """

    def __init__(self, time_field_name="request_cms_x_seconds_before_maintenance_start_time"):
        # The option is forwarded to the base stage too (presumably so it is kept with the stage params).
        super().__init__(time_field_name=time_field_name)
        self.time_field_name = time_field_name

    def run(self, stage_info: StageInfo, scenario: Scenario, host: Host, scenario_stage_info: StageInfo):
        """Run the common waiting logic for the group the host belongs to in this scenario."""
        group_id = scenario.get_host_info_by_host_obj(host).group
        waiter = CommonWaitBeforeRequestingCms()
        return waiter.run(stage_info, scenario, group_id, self.time_field_name)
