"""MongoDB model base."""

import time
from uuid import UUID

import mongoengine
from mongoengine import BooleanField, IntField, LongField, StringField, EmbeddedDocument, DictField, ListField
from mongoengine import ValidationError, EmbeddedDocumentField
from mongoengine.base.fields import BaseField

import sepelib.mongo.util

sepelib.mongo.util.patch()

import sepelib.mongo.document

from sepelib.core.exceptions import Error
from sepelib.mongo.util import register_model
from sepelib.util.misc import doesnt_override
from walle.authorization import has_iam
from walle.expert.types import FAILURE_LIMITS_MAP


class DocumentValidationError(Error):
    """Raised when document data fails schema validation (see `Document.validate_obj_fields`)."""


class DocumentPostprocessor:
    """
    A base class for postprocessing objects that take a bunch of documents and transform them to API objects, possibly
    adding new fields with data from other collections.

    Subclasses must implement `process()`.

    Attention: `dict` fields aren't supported yet.
    """

    def __init__(self, extra_db_fields, extra_fields):
        """
        :param extra_db_fields: fields that are required for the postprocessor to work properly.
        :param extra_fields: extra fields that will be added to the result object.
        """

        self.extra_db_fields = set(extra_db_fields)
        self.extra_fields = set(extra_fields)

    def is_needed(self, requested_fields):
        """Returns True if postprocessing is required to display the requested fields.

        :param requested_fields: any iterable of field names (set, list, tuple, ...).
        """

        # set.isdisjoint() accepts an arbitrary iterable, so calling it on our own set (rather than on the argument)
        # works for non-set inputs too while giving the same answer for sets.
        return not self.extra_fields.isdisjoint(requested_fields)

    def process(self, iterable, requested_fields):
        """Consumes documents from the `iterable` and returns a list of generated API objects.

        Must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """

        raise NotImplementedError

    def process_one(self, object, requested_fields):
        """Processes a single document by delegating to `process()` and returning the first resulting API object.

        Note: the parameter name `object` shadows the builtin but is kept for keyword-argument compatibility.
        """

        return self.process([object], requested_fields)[0]


class Document(sepelib.mongo.document.Document):
    """Base class for all our models.

    API exposure is controlled by class attributes of the subclasses:

    * `api_fields` - field paths that `to_api_obj` / `api_query_fields` may return; dotted paths address fields of
      embedded documents;
    * `default_api_fields` (optional) - fields returned when the caller doesn't request specific ones (falls back to
      `api_fields`);
    * `iam_public_fields` (optional) - the subset of fields allowed for IAM public handlers.
    """

    meta = {"abstract": True}

    @classmethod
    @doesnt_override(mongoengine.Document)
    def validate_obj_fields(cls, obj):
        """Validates the specified dict object's fields against the schema.

        :param obj: mapping of field name to value. `None` values are accepted without validation.
        :returns: `obj` itself, unchanged.
        :raises DocumentValidationError: if a key doesn't name a field of this document, or a non-`None` value fails
            that field's validation.
        """

        for key, value in obj.items():
            # The class attribute must be a field object; methods, managers or missing attributes are rejected.
            field = getattr(cls, key, None)
            if not isinstance(field, BaseField):
                raise DocumentValidationError("{} doesn't have {} field", cls.__name__, key)

            if value is None:
                continue

            try:
                field._validate(value)
            except ValidationError as e:
                raise DocumentValidationError("Invalid {}['{}'] field: {}", cls.__name__.lower(), key, e) from e

        return obj

    @classmethod
    @doesnt_override(mongoengine.Document)
    def api_query_fields(cls, requested_fields=None):
        """Returns a list of fields that should be queried from MongoDB to return a requested API object.

        :param requested_fields: names of the requested API fields; `None` selects the default API fields.
        :returns: the matching `api_fields` entries, always followed by the document's id field if not already present.
        """

        requested_fields = get_requested_fields(cls, requested_fields)
        fields = [field for field in cls.api_fields if _filter_api_field(field, requested_fields)]

        id_field = cls._meta["id_field"]
        if id_field not in fields:
            fields.append(id_field)

        return fields

    @doesnt_override(mongoengine.Document)
    def to_api_obj(self, requested_fields=None, extra_fields=None, iam_public_handler=False):
        """Returns an API representation of the document.

        :param requested_fields: names of the fields to return; `None` selects the default API fields.
        :param extra_fields: a field:value mapping of extra fields that should be available to API object. Extra fields
            may override existing fields. If value is `None`, the object's field will be deleted if exists. Extra
            fields are filtered by `requested_fields` the same way as regular ones.
        :param iam_public_handler: whether the object is built for the YC installation public handler; when
            `has_iam()` is true this limits the result to the class's `iam_public_fields` (if it defines any).
        :returns: a dict. Dotted field paths become nested dicts, embedded documents (also inside lists) are converted
            to dicts and UUIDs to strings. Fields whose value is `None`, or whose path passes through a `None`
            embedded document, are omitted.

        Note: this method doesn't work properly with document inheritance and db_field=... for fields in
        EmbeddedDocument yet.

        Note: nested `extra_fields` aren't supported yet.

        TODO: this method doesn't work properly when user requests a field of an allowed document field.
        """

        requested_fields = get_requested_fields(self, requested_fields, iam_public_handler)

        api_obj = {}

        for field in self.api_fields:
            if not _filter_api_field(field, requested_fields):
                continue

            path = field.split(".")

            # Walk the dotted path; a None anywhere along it means the field is absent.
            value = self
            for attr in path:
                value = getattr(value, attr)
                if value is None:
                    break

            if value is None:
                continue

            if isinstance(value, EmbeddedDocument):
                value = value.to_mongo().to_dict()
            elif isinstance(value, list):
                value = [item.to_mongo().to_dict() if isinstance(item, EmbeddedDocument) else item for item in value]
            elif isinstance(value, UUID):
                value = str(value)

            # Rebuild the dotted path as nested dicts in the result.
            target = api_obj
            for attr in path[:-1]:
                target = target.setdefault(attr, {})
            target[path[-1]] = value

        if extra_fields:
            for field, value in extra_fields.items():
                if not _filter_api_field(field, requested_fields):
                    continue

                if value is None:
                    api_obj.pop(field, None)
                else:
                    api_obj[field] = value

        return api_obj


def _filter_api_field(field, requested_fields):
    for requested_field in requested_fields:
        if requested_field == field or field.startswith(requested_field + "."):
            return True

    return False


class FsmHandbrake(EmbeddedDocument):
    """Handbrake state, embedded in `Settings.fsm_handbrake` and `Settings.scenario_fsm_handbrake`."""

    # Presumably a Unix timestamp in seconds, comparable with `timestamp()` - TODO confirm.
    timeout_time = LongField(required=True, help_text="Time when handbrake expires")
    audit_log_id = StringField(help_text="ID of associated audit log entry")


class TimedLimitDocument(EmbeddedDocument):
    """A failure-count limit over a time period; the value type of `Settings.global_timed_limits_overrides`."""

    # The period string format isn't visible here - TODO confirm against the code that parses it.
    period = StringField(required=True, help_text="Time period")
    limit = LongField(required=True, help_text="Count of failures in period")


# Known failure limit names: the only keys accepted in `Settings.global_timed_limits_overrides`.
_LIMIT_NAMES = set(FAILURE_LIMITS_MAP.values())


def _timed_limits_validation(limits):
    """Return True if every key of the `limits` mapping is a known failure limit name.

    NOTE(review): this returns a bool, which matches mongoengine's pre-0.18 `validation=` contract. Since 0.18
    mongoengine expects the callable to return None and raise `ValidationError` on failure - confirm against the
    pinned mongoengine version (and any sepelib patching).
    """

    return _LIMIT_NAMES.issuperset(limits)


@register_model
class Settings(Document):
    """Stores dynamic settings."""

    id = StringField(primary_key=True, required=True)
    schema_version = IntField(required=True, default=0)

    disable_healing_automation = BooleanField(required=True, default=False)
    disable_dns_automation = BooleanField(required=True, default=False)
    failure_log_start_time = LongField(required=True, default=0)

    inventory_invalid_hosts_limit = IntField(required=False, default=None)
    inventory_updated_macs_hosts_limit = IntField(required=False, default=None)
    inventory_updated_inv_hosts_limit = IntField(required=False, default=None)
    fsm_handbrake = EmbeddedDocumentField(FsmHandbrake, required=False, default=None)
    scenario_fsm_handbrake = EmbeddedDocumentField(FsmHandbrake, required=False, default=None)

    # Container fields use callable defaults (`dict` / `list`) so that every document gets its own fresh container:
    # a literal `{}` / `[]` default is a single object that would be shared by all Settings instances.
    checks_percentage_overrides = DictField(required=False, default=dict)
    global_timed_limits_overrides = DictField(
        field=EmbeddedDocumentField(TimedLimitDocument),
        # Keys must be known failure limit names (values of FAILURE_LIMITS_MAP).
        validation=_timed_limits_validation,
        required=False,
        default=dict,
    )
    dmc_rules_switched = ListField(
        StringField(min_length=1),
        required=False,
        default=list,
        help_text="For DMC rules migration, see WALLE-4645",
    )

    # Fields that `to_api_obj` may return for this document.
    api_fields = ("disable_healing_automation", "disable_dns_automation")


def get_requested_fields(klass, requested_fields=None, iam_public_handler=False):
    """Resolve the set of API field names to return for a document class (or instance).

    :param klass: document class or instance providing `api_fields` and, optionally, `default_api_fields` and
        `iam_public_fields`.
    :param requested_fields: iterable of field names; `None` selects `default_api_fields`, falling back to
        `api_fields` when the former isn't defined.
    :param iam_public_handler: when true and `has_iam()` is true, narrow the result to `iam_public_fields` (only if
        the class defines a non-empty collection of them).
    :returns: a new set of field names.
    """

    if requested_fields is None:
        selected = set(getattr(klass, "default_api_fields", klass.api_fields))
    else:
        selected = set(requested_fields)

    if not (has_iam() and iam_public_handler):
        return selected

    public_fields = set(getattr(klass, "iam_public_fields", []))
    if public_fields:
        selected &= public_fields

    return selected


def timestamp():
    """Return the current time as a whole number of seconds.

    Delegates to the module-level `_timestamp`, resolved at call time so tests can substitute it.
    """

    current = _timestamp()
    return current


def _timestamp():
    """Returns current timestamp.

    Note: mocked in tests.
    """

    return int(time.time())


def monkeypatch_timestamp(monkeypatch, cur_time=None):
    """Replace `walle.models._timestamp` with a controllable fake clock.

    :param monkeypatch: presumably pytest's `monkeypatch` fixture (anything with a compatible `setattr`).
    :param cur_time: initial fake time; defaults to the current real time truncated to whole seconds.
    :returns: the fake clock; tests can read or assign its `cur_time` and advance it with `bump_time(inc)`.
    """

    import walle.models

    class TimeMocker:
        """Callable returning a fixed timestamp that is advanced only explicitly."""

        def __init__(self, cur_time=None):
            if cur_time is None:
                cur_time = int(time.time())
            self.cur_time = cur_time

        def __call__(self, *args, **kwargs):
            # Any arguments are accepted and ignored.
            return self.cur_time

        def bump_time(self, inc=1):
            self.cur_time += inc

    mocker = TimeMocker(cur_time)
    monkeypatch.setattr(walle.models, "_timestamp", mocker)
    return mocker
