"""Wall-E scenario operations."""

import logging
import typing as tp
from collections import defaultdict

from mongoengine import (
    StringField,
    DictField,
    LongField,
    IntField,
    EmbeddedDocumentField,
    EmbeddedDocument,
    MapField,
    BooleanField,
    ListField,
)

from sepelib.mongo import monotonic
from sepelib.mongo.util import register_model
from walle.errors import ErrorType, WalleError
from walle.hosts import Host
from walle.models import Document, timestamp
from walle.projects import get_by_id
from walle.scenario.constants import (
    ScenarioFsmStatus,
    HostScenarioStatus,
    WORK_STATUS_LABEL_NAME,
    ScenarioWorkStatus,
    ScriptName,
    FixedMaintenanceApproversLogins,
    WORKMATE_STATUS_LABEL_NAME,
    WORKMATE_STATUS_TARGET_LABEL_VALUE,
)
from walle.scenario.data_storage.types import HostGroupSource, MaintenanceApproversGroup
from walle.scenario.errors import DEFAULT_ERROR_ADVANCE_PERIOD, ScenarioValidationError
from walle.scenario.host_groups_builders.hosts_list_splitters import BaseHostsListSplitter
from walle.scenario.host_stage_info import HostStageInfo
from walle.scenario.marker import MarkerStatus
from walle.scenario.stage_info import StageInfo

log = logging.getLogger(name=__name__)

# Delay added to `timestamp()` when computing the next check time of a scenario
# (presumably seconds — confirm against `walle.models.timestamp`).
NEXT_CHECK_INCREMENT: int = 5


class ScenarioHostState(EmbeddedDocument):
    """State of a single host inside a scenario; stored as a value of ``Scenario.hosts``."""

    inv = IntField(help_text="Inventory number", min_value=0)
    status = StringField(help_text="Host status in scenario", default=HostScenarioStatus.QUEUE)
    timestamp = IntField(help_text="Time when the current state has been set")
    group = IntField(help_text="Group defined by scheduler")
    # Set to True by Scenario.set_host_info_status / set_host_status_kwargs.
    is_acquired = BooleanField(default=False)  # TODO naming
    enabled_checks = ListField(
        field=StringField(help_text="Check name"),
        help_text="List of checks for use",
        default=None,
    )


class StageError(EmbeddedDocument):
    """An error recorded for a scenario; stored as a value of ``Scenario.errors``."""

    id = StringField(help_text="Error id", required=True)
    repeats = IntField(help_text="Repeats of error", default=0, required=True)
    type = StringField(help_text="Error type", required=True)
    is_visible = BooleanField(default=False)
    retry_period = IntField(default=DEFAULT_ERROR_ADVANCE_PERIOD)
    exc_info = DictField(help_text="Exception info", default=None)
    stacktrace = ListField(help_text="Exception stacktrace")


# Maybe refactor it to `ScenarioModel`.
@register_model
class Scenario(Document):
    """Represents a scenario record.

    Holds the scenario's FSM status, the per-host state map (``hosts``), stage
    progress (``stage_info``) and current scenario errors (``errors``).
    """

    #  WALLE-4109: Batch update of Mongoengine Document is much-much faster
    def __init__(self, *args, **kwargs):
        Document.__init__(self, *args, **kwargs)
        # In-memory buffer of per-host stage info updates, filled by
        # `set_host_stage_stage_info_update` and applied in one pass by
        # `update_stage_info_hosts_for_all_child_stages`:
        # "map_stage_uid_to_stage_info_hosts_update": {
        #     "stage_info_uid": {
        #         "host_uuid": {
        #             <data>
        #         }
        #     }
        # }
        self.map_stage_uid_to_stage_info_hosts_update = defaultdict(lambda: defaultdict(dict))

    scenario_id = LongField(required=True, primary_key=True, help_text="Scenario ID (increments automatically)")
    name = StringField(required=True, min_length=1, max_length=100, help_text="Name", unique=True)
    scenario_type = StringField(required=True, choices=ScriptName.ALL_EXISTING, help_text="Scenario type")
    issuer = StringField(required=True, help_text="Issuer of scenario")
    status = StringField(required=True, default=ScenarioFsmStatus.CREATED, help_text="Scenario status")
    message = StringField(help_text="Scenario message for user")
    ticket_key = StringField(required=False, help_text="Related ticket key", unique=False)

    next_check_time = LongField(required=True, default=0, help_text="Timestamp the next check should be processed at")
    action_time = LongField(required=True, default=0, help_text="Timestamp of last action")
    creation_time = LongField(required=True, help_text="Timestamp of scenario creation")

    labels = DictField(help_text="Field that contains everything that the user send")  # WALLE-2435
    stage_info = EmbeddedDocumentField(StageInfo, required=True, help_text="State of stages")
    errors = MapField(field=EmbeddedDocumentField(StageError), help_text="Current scenario errors")

    # Keys are host UUIDs when `uses_uuid_keys` is set, otherwise stringified inventory numbers.
    hosts = MapField(
        field=EmbeddedDocumentField(ScenarioHostState), help_text="Current states of operations under hosts"
    )
    current_group = IntField(help_text="Currently processing group of hosts")

    revision = LongField(required=True, default=0, help_text="Revision of scenario")
    uses_uuid_keys = BooleanField(default=False)

    data_storage = DictField(help_text="Scenario-specific data storage")

    # FIXME Remove after migration
    script_args = DictField(default=None, help_text="Additional script arguments")

    api_fields = (
        "scenario_id",
        "name",
        "scenario_type",
        "issuer",
        "status",
        "ticket_key",
        "action_time",
        "creation_time",
        "labels",
        "script_args",
        "hosts",
        "errors",
        "message",
        "data_storage",
    )

    meta = {
        "collection": "scenarios",
        "indexes": [
            {"name": "name", "fields": ["name"], "unique": True},
            {"name": "ticket_key", "fields": ["ticket_key"], "unique": False},
            {"name": "labels", "fields": ["labels"], "unique": False},
        ],
    }

    def __str__(self):
        return "<Scenario #{}: \"{}\">".format(self.id, self.scenario_type)

    def _host_key(self, host: Host) -> str:
        """Return the key under which `host` is stored in `self.hosts`.

        Map keys are strings: the host UUID when `uses_uuid_keys` is set, otherwise
        the stringified inventory number (matching `create_list_of_host_states`).
        """
        return host.uuid if self.uses_uuid_keys else str(host.inv)

    def _set_statuses(self, fsm_status, work_status):
        """Set the FSM status, action time and work-status label in one `modify` call."""
        update_query = Scenario.set_status_kwargs(fsm_status)
        update_query["set__labels__{}".format(WORK_STATUS_LABEL_NAME)] = work_status
        self.modify(**update_query)

    def mark_as_finished(self):
        log.info("Scenario %s (#%s) has finished", self.name, self.id)
        self._set_statuses(ScenarioFsmStatus.FINISHED, ScenarioWorkStatus.FINISHED)

    def mark_as_cancelled(self):
        log.info("Scenario %s (#%s) has been cancelled", self.name, self.id)
        self._set_statuses(ScenarioFsmStatus.CANCELED, ScenarioWorkStatus.CANCELED)

    def start(self):
        self._set_statuses(ScenarioFsmStatus.STARTED, ScenarioWorkStatus.STARTED)

    def cancel(self):
        self._set_statuses(ScenarioFsmStatus.CANCELING, ScenarioWorkStatus.CANCELING)

    def pause(self):
        # Only the FSM status changes here; the work-status label is left untouched.
        self.modify(**Scenario.set_status_kwargs(ScenarioFsmStatus.PAUSED))

    def dismiss_hosts(self):
        """Detach all hosts from this scenario and drop their stage infos."""
        Host.objects(scenario_id=self.scenario_id).update(unset__scenario_id=True)
        HostStageInfo.objects(scenario_id=self.scenario_id).delete()

    def clean_hosts_stage_info(self):
        HostStageInfo.objects(scenario_id=self.scenario_id).delete()

    def get_invs_of_hosts(self) -> tp.List[int]:
        """Return inventory numbers of the scenario's hosts.

        With UUID keys the numbers are read from the Host collection, so only hosts
        found there are returned; otherwise the map keys are converted to ints.
        """
        if self.uses_uuid_keys:
            return list(Host.objects(uuid__in=list(self.hosts.keys())).values_list("inv"))
        else:
            return [int(host) for host in self.hosts.keys()]

    def get_hosts_queryset(self) -> tp.Iterable[Host]:
        """Return a Host queryset selecting the scenario's hosts by their map keys."""
        if self.uses_uuid_keys:
            return Host.objects(uuid__in=list(self.hosts.keys()))
        else:
            return Host.objects(inv__in=list(self.hosts.keys()))

    def get_host_info_by_host_obj(self, host: Host) -> ScenarioHostState:
        """Return the scenario state of `host`; raises KeyError if it is not in the scenario."""
        # Map keys are strings, so an inventory number must be stringified before lookup.
        return self.hosts[self._host_key(host)]

    @staticmethod
    def next_id() -> int:
        return monotonic.get_next("scenario_id")

    @staticmethod
    def get_new_next_check_time() -> int:
        next_check = timestamp() + NEXT_CHECK_INCREMENT
        return next_check

    # TODO: Possible update revision?
    @staticmethod
    def set_status_kwargs(status: str) -> tp.Dict[str, tp.Any]:
        """Build `modify`/`update` kwargs setting `status` and refreshing `action_time`."""
        kwargs = dict(set__status=status, set__action_time=timestamp())
        return kwargs

    def set_host_status_kwargs(self, host: Host, status: str) -> tp.Dict[str, tp.Union[bool, str, int]]:
        """Build update kwargs marking `host` as acquired with the given `status`."""
        host_key = self._host_key(host)
        kwargs = {
            "set__hosts__{}__is_acquired".format(host_key): True,
            "set__hosts__{}__status".format(host_key): status,
            "set__hosts__{}__timestamp".format(host_key): timestamp(),
        }
        return kwargs

    def set_host_info_status(self, host_key: str, status: str):
        """In-memory counterpart of `set_host_status_kwargs`, addressed by map key."""
        host_info = self.hosts[host_key]
        host_info.is_acquired = True
        host_info.status = status
        host_info.timestamp = timestamp()

    def work_completed_by_workmate(self) -> bool:
        # NOTE(rocco66): workmate is 'NOC' or 'ITDC'
        return self.labels.get(WORKMATE_STATUS_LABEL_NAME) == WORKMATE_STATUS_TARGET_LABEL_VALUE

    def set_host_info_checks(self, host_key: str, checks: tp.List[str]):
        self.hosts[host_key].enabled_checks = checks

    def get_works_status(self):
        """Return the work-status label; raises KeyError if it has never been set."""
        return self.labels[WORK_STATUS_LABEL_NAME]

    def set_works_status_label(self, status: str):
        self.labels[WORK_STATUS_LABEL_NAME] = status
        self.action_time = timestamp()

    @staticmethod
    def create_list_of_host_states(
        invs: tp.List[int], resolve_uuids: bool = False, hosts_group_id_map: tp.Optional[tp.Dict[int, int]] = None
    ) -> tp.Dict[str, ScenarioHostState]:
        """Build the initial `hosts` map for the given inventory numbers.

        :param invs: inventory numbers of the scenario hosts.
        :param resolve_uuids: key the map by host UUID (looked up in the Host
            collection; invs not found there are omitted) instead of ``str(inv)``.
        :param hosts_group_id_map: optional inv -> group id mapping; when it is not
            given every host gets group 0, when a host is missing from it the group
            is None.
        :return: map key -> fresh `ScenarioHostState` in QUEUE status.
        """
        if resolve_uuids:
            return {
                host.uuid: ScenarioHostState(
                    inv=host.inv,
                    timestamp=timestamp(),
                    status=HostScenarioStatus.QUEUE,
                    group=hosts_group_id_map.get(host.inv) if hosts_group_id_map is not None else 0,
                )
                for host in Host.objects.filter(inv__in=invs).only("uuid", "inv")
            }
        else:
            return {
                str(inv): ScenarioHostState(
                    inv=inv,
                    timestamp=timestamp(),
                    status=HostScenarioStatus.QUEUE,
                    group=hosts_group_id_map.get(inv) if hosts_group_id_map is not None else 0,
                )
                for inv in invs
            }

    @staticmethod
    def split_scenario_hosts_to_groups(
        host_inv_list: tp.List[int], uses_uuids: bool, hosts_list_splitter: BaseHostsListSplitter
    ) -> tp.Tuple[tp.Dict[str, ScenarioHostState], tp.List[HostGroupSource]]:
        """Split hosts into groups and build their scenario states.

        Groups produced by `hosts_list_splitter` get sequential ids starting at 0.

        :return: (hosts map with group ids assigned, list of `HostGroupSource`
            in group-id order).
        """
        hosts_group_id_map = {}
        host_groups_sources = []

        for group_id, group in enumerate(hosts_list_splitter.split(host_inv_list)):
            host_groups_sources.append(HostGroupSource(group_id, group.group_source))
            for host_inv in group.hosts_invs:
                hosts_group_id_map[host_inv] = group_id

        hosts_states = Scenario.create_list_of_host_states(
            host_inv_list, resolve_uuids=uses_uuids, hosts_group_id_map=hosts_group_id_map
        )
        return hosts_states, host_groups_sources

    @staticmethod
    def get_maintenance_approvers_groups_from_host_groups_sources(
        host_groups_sources: tp.List[HostGroupSource],
    ) -> tp.List[MaintenanceApproversGroup]:
        """Build one approvers group per host group, all with the default approver logins."""
        return [
            MaintenanceApproversGroup(
                group_id=group.group_id,
                name=group.source.get_group_source_name(),
                logins=FixedMaintenanceApproversLogins.DEFAULT_MAINTENANCE_APPROVERS_LOGINS,
            )
            for group in host_groups_sources
        ]

    def to_api_obj(
        self,
        requested_fields: tp.Optional[tp.List[str]] = None,
        extra_fields: tp.Optional[tp.Dict[str, tp.Any]] = None,
        *args,
        **kwargs
    ) -> tp.Dict[tp.Any, tp.Any]:
        """Serialize the scenario for the API.

        `hosts` is rendered as a flat list of host states and `errors` as a plain
        dict. These overrides are passed to the base implementation as extra fields.
        The caller's `extra_fields` dict is copied, never mutated. Extra positional
        and keyword arguments are accepted but not forwarded.
        """
        hosts_field = Scenario.hosts.name
        errors_field = Scenario.errors.name

        # Work on a copy so that a dict passed in by the caller is not modified.
        extra_fields = dict(extra_fields) if extra_fields is not None else {}

        if requested_fields is None or hosts_field in requested_fields:
            # flat view for UI: WALLE-2270
            extra_fields[hosts_field] = (
                [host_info.to_mongo().to_dict() for host_info in self.hosts.values()] if self.hosts else []
            )

        if (requested_fields is None or errors_field in requested_fields) and self.errors:
            extra_fields[errors_field] = {
                error_id: error_info.to_mongo().to_dict() for error_id, error_info in self.errors.items()
            }

        return super().to_api_obj(requested_fields, extra_fields)

    def get_scenario_error(self, exc_id: str, raw_error: Exception) -> StageError:
        """Return the stored error `exc_id`, or a new (unsaved) one typed from `raw_error`."""
        if exc_id in self.errors:
            return self.errors[exc_id]

        if isinstance(raw_error, WalleError):
            error_type = raw_error.error_type
        else:
            error_type = ErrorType.FIXABLE

        return StageError(id=exc_id, type=error_type)

    @staticmethod
    def validate_name(scenario_name: str):
        """Raise ScenarioValidationError if the name is empty or its length is out of bounds."""
        if not scenario_name:
            raise ScenarioValidationError("Scanario name is empty")
        if len(scenario_name) < Scenario.name.min_length or len(scenario_name) > Scenario.name.max_length:
            raise ScenarioValidationError(
                "Scanario name length must be between {} and {}".format(
                    Scenario.name.min_length, Scenario.name.max_length
                )
            )

    def set_host_stage_stage_info_update(self, stage_info_uid, host_uuid, msg, marker_status):
        """Buffer a host update for stage `stage_info_uid`.

        The host is FINISHED when `marker_status` is SUCCESS and PROCESSING
        otherwise. Updates are applied by `update_stage_info_hosts_for_all_child_stages`.
        """
        host_status = HostStageStatus.FINISHED if marker_status == MarkerStatus.SUCCESS else HostStageStatus.PROCESSING

        self.map_stage_uid_to_stage_info_hosts_update[stage_info_uid][host_uuid] = {
            "status_time": timestamp(),
            "status": host_status,
            "msg": msg,
        }

    def update_stage_info_hosts_for_all_child_stages(self):
        """Apply buffered host updates to the matching nested stage infos.

        Each buffer key is a dotted stage uid. Its components after the first are
        used as successive indices into `stages`, starting from `self.stage_info`.
        This is best-effort: a uid that cannot be resolved or applied is logged and
        skipped, and the remaining uids are still applied.
        """

        def _get_stage_info_by_uid(stage_info: StageInfo, uids: tp.List[int]) -> StageInfo:
            # An empty path does not address a child stage; treat it as an error.
            if not uids:
                raise ValueError("stage uid path is empty")
            for cur_uid in uids:
                stage_info = stage_info.stages[cur_uid]
            return stage_info

        for uid, update in self.map_stage_uid_to_stage_info_hosts_update.items():
            try:
                uid_in_list_fmt = [int(num) for num in uid.split(".")]
                stage_info_to_update = _get_stage_info_by_uid(self.stage_info, uid_in_list_fmt[1:])
                stage_info_to_update.hosts.update(update)
            except Exception:  # dont want to fail on any strange cases, but don't hide them either
                log.exception(
                    "Scenario #%s: failed to apply stage info hosts update for stage uid %s",
                    self.scenario_id,
                    uid,
                )


def get_map_maintenance_plot_ids_to_hosts_count(scenario_id: int):
    """Count the hosts of a scenario per maintenance plot.

    Loads the scenario, fetches the projects of its hosts (matched by inventory
    number), and sums host counts per project's ``maintenance_plot_id``.

    :return: ``defaultdict(int)`` mapping maintenance plot id -> host count.
    """
    scenario = Scenario.objects.get(scenario_id=scenario_id)
    scenario_invs = [state.inv for state in scenario.hosts.values()]

    per_project = defaultdict(int)
    for host in Host.objects(inv__in=scenario_invs).only("project"):
        per_project[host.project] += 1

    per_plot = defaultdict(int)
    for project_id in per_project:
        plot_id = get_by_id(project_id).maintenance_plot_id
        per_plot[plot_id] += per_project[project_id]

    return per_plot


class HostStageStatus:
    """Host status values written into stage info host updates."""

    FINISHED = "finished"
    PROCESSING = "processing"

    # Every status, processing first.
    ALL = [PROCESSING, FINISHED]
