"""Host task stages management."""

import itertools
import logging
from collections import namedtuple

from mongoengine import EmbeddedDocument, StringField, LongField, ListField, DictField, EmbeddedDocumentField

from sepelib.core.exceptions import Error, LogicalError
from walle.models import timestamp

log = logging.getLogger(__name__)


class Stages:
    """Stage name constants to use in code for easier searching of stage usage."""

    # Each value is a stage name, i.e. the string kept in ``Stage.name``.
    # NOTE(review): since ``Stage.name`` is a stored StringField, these values are presumably
    # persisted in existing tasks — changing a value likely needs a migration; confirm.

    # Permissions and host state switching.
    ACQUIRE_PERMISSION = "acquire-permission"
    SET_DOWNTIME = "set-downtime"
    SET_MAINTENANCE = "set-maintenance"
    SET_ASSIGNED = "set-assigned"
    SET_PROBATION = "set-probation"

    # Hostname and BOT project assignment.
    SET_HOSTNAME = "set-hostname"
    ALLOCATE_HOSTNAME = "allocate-hostname"
    ASSIGN_HOSTNAME = "assign-hostname"
    ASSIGN_BOT_PROJECT = "assign-bot-project"
    WAIT_FOR_BOT_PREORDER = "wait-for-bot-preorder"

    # Reboot and power control.
    REBOOT = "reboot"
    SSH_REBOOT = "ssh-reboot"
    SSH_REBOOT_COMPOSITE = "ssh-reboot-composite"
    KEXEC_REBOOT = "kexec-reboot"
    KEXEC_REBOOT_COMPOSITE = "kexec-reboot-composite"
    POWER_OFF = "power-off"
    POWER_ON = "power-on"
    POWER_ON_COMPOSITE = "power-on-composite"

    # Bookkeeping.
    LOG_COMPLETED_OPERATION = "log-completed-operation"
    CANCEL_ADMIN_REQUESTS = "cancel-admin-requests"
    RESET_HEALTH_STATUS = "reset-health-status"

    # Profiling.
    PROFILE = "profile"
    EINE_PROFILE = "eine-profile"
    DROP_EINE_PROFILE = "drop-eine-profile"

    # Hardware repair.
    HW_REPAIR = "hw-repair"
    RESET_BMC = "reset-bmc"

    # Disk repair.
    HEAL_DISK = "heal-disk"
    CHANGE_DISK = "change-disk"

    # Network configuration.
    NETWORK = "network"
    WAIT_FOR_ACTIVE_MAC = "wait-for-active-mac"
    WAIT_FOR_SWITCH_PORT = "wait-for-switch-port"
    UPDATE_NETWORK_LOCATION = "update-network-location"
    SWITCH_VLANS = "switch-vlans"
    SETUP_DNS = "setup-dns"
    ADD_HOST_TO_CAUTH = "add-host-to-cauth"
    VERIFY_NETWORK_LOCATION = "verify-network-location"
    VERIFY_SWITCH_PORT = "verify-switch-port"

    # Deploy (LUI).
    DEPLOY = "deploy"
    ISSUE_CERTIFICATE = "issue-certificate"
    ASSIGN_LUI_CONFIG = "assign-lui-config"
    GENERATE_CUSTOM_DEPLOY_CONFIG = "generate-custom-deploy-config"
    LUI_INSTALL = "lui-install"
    LUI_SETUP = "lui-setup"
    LUI_REMOVE = "lui-remove"
    LUI_DEACTIVATE = "lui-deactivate"

    # Task completion and project switching.
    COMPLETE_PREPARING = "complete-preparing"
    COMPLETE_RELEASING = "complete-releasing"
    COMPLETE_DELETION = "complete-deletion"
    SWITCH_PROJECT = "switch-project"

    # Reporting.
    REPORT = "report"
    REPORT_RACK = "report-rack"
    REPORT_RACK_OVERHEAT = "report-rack-overheat"

    # Monitoring and deactivation.
    MONITOR = "monitor"
    DEACTIVATE = "deactivate"

    # CMS.
    DROP_CMS_TASK = "drop-cms-task"
    SWITCH_DEFAULT_CMS_PROJECT = "switch-default-cms-project"

    # Miscellaneous.
    FQDN_DEINVALIDATION = "fqdn-deinvalidation"
    PROVIDE_DIAGNOSTIC_HOST_ACCESS = "diagnostic-host-access"

    CLOUD_POST_PROCESSOR = "cloud-post-processor"

    # Foo and Bar stages to run FSM without doing anything to host
    FOO = "foo"
    BAR = "bar"


class StageTerminals:
    """Names of stage terminate handlers.

    ``ALL`` lists the values accepted for entries of ``Stage.terminators``, which overrides
    the default terminate handlers of a stage.
    """

    # General purpose, common to all stages.
    SUCCESS = "success"
    """Terminate stage that was completed successfully."""
    FAIL = "failure"
    """Terminate stage that was failed to complete."""
    CANCEL = "cancel"
    """Terminate stage for cancelled task."""
    SKIP = "skip"
    """Terminate stage that is failed but not required."""
    RETRY = "retry"
    """Retry stage after some grace period."""
    RETRY_ACTION = "retry-action"
    """Retry whole action (e.g. retry parent stage) after some grace period."""

    # General purpose, available on demand
    COMPLETE_PARENT = "complete-parent"
    """Complete parent stage skipping the rest of nested stages."""
    PROFILE = "profile"
    """Terminate stage and upgrade task to profile."""
    HIGHLOAD_AND_REDEPLOY = "highload"
    """Terminate stage and upgrade the task to highload eaas profile with redeploy."""
    # NOTE(review): unlike the other values this one contains spaces; it may already be
    # persisted in stored terminators, so do not "fix" it without a data migration.
    DISK_RW_AND_REDEPLOY = "disk-rw and redeploy"
    """Terminate stage and upgrade the task to disk-rw eaas profile with redeploy."""
    DELETE_HOST = "delete-host"
    """Terminate stage and delete host."""

    # Custom terminators
    SWITCH_MISSING = "switch-missing"
    """Terminate VLAN switching stage when host's switch can not be identified."""
    DEPLOY_FAILED = "deploy-failed"
    """Terminate deploy task when deploy fails."""
    NO_ERROR_FOUND = "no-error-found"
    """Terminate deploy extra-highload task no errors found during profiling."""

    # Presumably the handlers every stage gets without overrides — confirm against the FSM code.
    DEFAULTS = [SUCCESS, SKIP, CANCEL, FAIL]
    # Allowed values for Stage.terminators (used as the field's ``choices``).
    # NOTE(review): RETRY_ACTION is listed but RETRY is not, although both are described as
    # "common to all stages" — confirm that omitting RETRY is intentional.
    ALL = DEFAULTS + [
        COMPLETE_PARENT,
        PROFILE,
        HIGHLOAD_AND_REDEPLOY,
        DELETE_HOST,
        SWITCH_MISSING,
        DEPLOY_FAILED,
        DISK_RW_AND_REDEPLOY,
        NO_ERROR_FOUND,
        RETRY_ACTION,
    ]


class MissingStageIdError(Error):
    """Raised when a stage UID does not resolve to an existing stage in a stage list."""

    def __init__(self, uid):
        message = "Invalid stage ID: {}."
        super(MissingStageIdError, self).__init__(message, uid)


class UidNotSet(Error):
    """Raised when a stage UID is requested before one has been assigned to the stage."""


class _StageUidPromise:
    """Deferred reference to a stage's UID.

    The UID is read from the wrapped stage only at conversion time, so the promise can be
    created before UIDs are assigned.
    """

    def __init__(self, stage):
        self.stage = stage

    def to_mongo(self, *args):
        """Convert value to mongo-compatible. Executed right before saving data into a database.

        Extra positional arguments are accepted and ignored.

        :raises UidNotSet: if the stage still has no UID.
        """
        resolved_uid = self.stage.uid
        if resolved_uid is not None:
            return resolved_uid

        raise UidNotSet("uid is not set for stage {}".format(self.stage.name))

    def __str__(self):
        # Same resolution (and the same UidNotSet failure) as to_mongo().
        return self.to_mongo()


class Stage(EmbeddedDocument):
    """One step of a host task, stored as an embedded document.

    Stages form a tree through the ``stages`` field. Besides its identity (``uid``, ``name``)
    a stage carries three independent key/value stores: ``params`` (parameters), ``data``
    (persistent data) and ``temp_data`` (temporary data). Each store stays ``None`` until the
    first value is written into it.
    """

    # Sentinel meaning "no default supplied", so callers may pass ``None`` as a real default.
    _UNDEFINED = object()

    uid = StringField(required=True, help_text="Stage UID")
    name = StringField(required=True, help_text="Stage name")
    params = DictField(default=None, help_text="Optional parameters")

    status = StringField(help_text="Current stage status")
    status_time = LongField(help_text="Time when the current status has been set")
    data = DictField(default=None, help_text="Stage persistent data storage")
    temp_data = DictField(default=None, help_text="Stage temporary data storage (cleared on leaving the stage)")
    terminators = DictField(
        field=StringField(choices=StageTerminals.ALL),
        default=None,
        help_text="Overwrite default stage terminate handlers",
    )

    stages = ListField(EmbeddedDocumentField("self"), default=None, help_text="Child stages")

    def get_uid(self):
        """Return a lazy reference to this stage's UID.

        The reference reads ``self.uid`` only when converted (``to_mongo()`` / ``str()``), so it
        can be taken before UIDs are assigned.
        """
        return _StageUidPromise(self)

    @property
    def description(self):
        """Human-readable label: ``name``, or ``name:status`` when a status is set."""
        if self.status is None:
            return self.name
        return self.name + ":" + self.status

    def timed_out(self, timeout, key=None):
        """Return True if at least ``timeout`` has elapsed since a start time.

        :param timeout: duration in the units of ``timestamp()`` (presumably seconds — confirm).
        :param key: name of a temporary-data entry holding the start time; when omitted,
            ``status_time`` is used instead.
        :raises Error: if ``key`` is given but is not present in temporary data.
        """
        if key is not None:
            start_time = self.get_temp_data(key)
        else:
            start_time = self.status_time

        return timestamp() - start_time >= timeout

    # --- parameters -----------------------------------------------------------------------

    def has_param(self, name):
        """Return True if parameter ``name`` is set (always a bool, even when ``params`` is empty)."""
        return self.params is not None and name in self.params

    def get_param(self, name, default=_UNDEFINED):
        """Return parameter ``name``; without ``default``, raise Error if it is missing."""
        return self._get("parameter", self.params, name, default)

    def set_param(self, name, value):
        """Set parameter ``name`` (creating the store if needed) and return ``value``."""
        if self.params is None:
            self.params = {}

        self.params[name] = value
        return value

    # --- persistent data ------------------------------------------------------------------

    def has_data(self, name):
        """Return True if persistent data entry ``name`` is set."""
        return self.data is not None and name in self.data

    def get_data(self, name, default=_UNDEFINED):
        """Return persistent data entry ``name``; without ``default``, raise Error if missing."""
        return self._get("data", self.data, name, default)

    def set_data(self, name, value):
        """Set persistent data entry ``name`` (creating the store if needed) and return ``value``."""
        if self.data is None:
            self.data = {}

        self.data[name] = value
        return value

    def setdefault_data(self, name, value):
        """Like ``dict.setdefault``: return the existing entry, or store and return ``value``."""
        if self.has_data(name):
            return self.data[name]

        return self.set_data(name, value)

    # --- temporary data -------------------------------------------------------------------

    def has_temp_data(self, name):
        """Return True if temporary data entry ``name`` is set."""
        return self.temp_data is not None and name in self.temp_data

    def get_temp_data(self, name, default=_UNDEFINED):
        """Return temporary data entry ``name``; without ``default``, raise Error if missing."""
        return self._get("temporary data", self.temp_data, name, default)

    def set_temp_data(self, name, value):
        """Set temporary data entry ``name`` (creating the store if needed) and return ``value``."""
        if self.temp_data is None:
            self.temp_data = {}

        self.temp_data[name] = value
        return value

    def setdefault_temp_data(self, name, value):
        """Like ``dict.setdefault`` for temporary data."""
        if self.has_temp_data(name):
            return self.temp_data[name]

        return self.set_temp_data(name, value)

    def del_temp_data(self, name):
        """Remove temporary data entry ``name`` if present; a missing entry is not an error."""
        if self.has_temp_data(name):
            del self.temp_data[name]

    def _get(self, data_name, data, name, default=_UNDEFINED):
        """Look up ``name`` in the store ``data`` (which may be None).

        :param data_name: human-readable store name used only in the error message.
        :raises Error: if the entry is missing and no ``default`` was supplied.
        """
        if data:
            try:
                return data[name]
            except KeyError:
                pass

        if default is self._UNDEFINED:
            raise Error("'{}' {} is not defined for '{}' stage.", name, data_name, self.name)

        return default

    def __repr__(self):
        return "<{}: id={}, name='{}'>".format(self.__class__.__name__, id(self), self.name)


def get_by_uid(stages, uid):
    """Return the stage addressed by the dotted ``uid`` (e.g. ``"2.1"``) within ``stages``.

    :raises Error: if ``uid`` cannot be parsed.
    :raises MissingStageIdError: if ``uid`` does not address an existing stage.
    """
    stage_ids = _parse_uid(uid)
    return _get_by_ids(stages, stage_ids)


def get_by_name(stages, *names):
    """
    Returns stage by name or names. If multiple names are given, returns stages by path name1/name2/name3.

    Each name is looked up among the children of the stage matched by the previous name; the
    first matching stage at each level wins.

    :param stages: top-level list of stages (``None`` is treated as empty).
    :param names: path of stage names, outermost first.
    :return: Returns either Stage instance or None (also None when ``names`` is empty or any
        path component is not found).
    """
    if not names:
        return None

    level = stages
    found = None
    for name in names:
        for stage in level or ():
            if stage.name == name:
                found = stage
                break
        else:
            # A missing component means the path does not exist. Do not carry on matching the
            # remaining names against the same (wrong) level.
            return None

        level = found.stages or []

    return found


def get_parent(stages, uid):
    """Return the parent of the stage addressed by ``uid``, or None for a top-level stage.

    :raises Error: if ``uid`` cannot be parsed.
    :raises MissingStageIdError: if the parent path does not address an existing stage.
    """
    *parent_ids, _ = _parse_uid(uid)
    if not parent_ids:
        return None

    return _get_by_ids(stages, parent_ids)


def get_next(stages, uid):
    """Return the sibling that follows the stage addressed by ``uid``, or None if it is the last one.

    :raises Error: if ``uid`` cannot be parsed.
    :raises MissingStageIdError: if ``uid`` (or its parent path) does not address a stage.
    """
    *parent_ids, stage_id = _parse_uid(uid)

    siblings = _get_by_ids(stages, parent_ids).stages if parent_ids else stages

    if siblings is None or stage_id > len(siblings):
        raise MissingStageIdError(uid)

    # stage_id is 1-based, so as a 0-based index it points at the following sibling.
    if stage_id < len(siblings):
        return siblings[stage_id]
    return None


def is_descendant(stage, parent_stage):
    """Return True if ``stage`` is nested (at any depth) under ``parent_stage``, judged by UID prefix.

    A stage is not considered a descendant of itself.
    """
    child_prefix = parent_stage.uid + "."
    return stage.uid.startswith(child_prefix)


def set_uids(stages, prefix=""):
    """Assign 1-based dotted UIDs to ``stages`` and, recursively, to all nested stages.

    The n-th stage gets ``prefix + str(n)``; its children are numbered under ``<that uid>.``.
    Stages are modified in place and the same ``stages`` object is returned.
    """
    for position, stage in enumerate(stages, start=1):
        stage.uid = prefix + str(position)

        children = stage.stages
        if children:
            set_uids(children, stage.uid + ".")

    return stages


def iter_stages(stages):
    """Lazily yield every stage of the tree depth-first, each parent before its children.

    Uses an explicit stack of ``(siblings, resume_position)`` pairs instead of recursion. A
    stage's children are checked only after the stage has been yielded and the consumer resumes.
    """
    pending = [(stages, 0)]

    while pending:
        siblings, position = pending.pop()

        for stage in itertools.islice(siblings, position, None):
            position += 1
            yield stage

            if stage.stages:
                # Resume this level after the children are done: push the rest of the
                # siblings first, so the children (pushed last) are popped next.
                pending.append((siblings, position))
                pending.append((stage.stages, 0))
                break


def _parse_uid(uid):
    """Parse a dotted stage UID (e.g. ``"2.1.3"``) into a list of 1-based stage IDs.

    :raises Error: if any component is not an integer or is less than 1.
    """
    try:
        stage_ids = [int(stage_id) for stage_id in uid.split(".")]
    except ValueError:
        raise Error("Invalid stage ID: {}.", uid)

    # UIDs are numbered from 1 (see set_uids). Without this check a zero or negative
    # component would become a negative list index and silently resolve to a stage counted
    # from the end of the list.
    if any(stage_id < 1 for stage_id in stage_ids):
        raise Error("Invalid stage ID: {}.", uid)

    return stage_ids


def _format_ids(stage_ids):
    """Join integer stage IDs back into a dotted UID string (inverse of the UID parser)."""
    return ".".join(map(str, stage_ids))


def _get_by_ids(stages, stage_ids):
    """Resolve a path of 1-based stage IDs to the stage it addresses.

    :param stages: top-level list of stages to start from.
    :param stage_ids: non-empty sequence of 1-based IDs, one per nesting level.
    :return: the stage addressed by the last ID.
    :raises LogicalError: if ``stage_ids`` is empty.
    :raises MissingStageIdError: if any ID does not address an existing stage.
    """
    if not stage_ids:
        raise LogicalError

    for stage_id in stage_ids:
        # Check the lower bound too. Otherwise ``stages[stage_id - 1]`` with stage_id <= 0
        # would wrap around to the end of the list instead of failing.
        if stages is None or not 1 <= stage_id <= len(stages):
            raise MissingStageIdError(_format_ids(stage_ids))

        stage = stages[stage_id - 1]
        stages = stage.stages

    return stage
