"""Document class is a helper that should ease up migration from MongoEngine to pyMongo."""

import cachetools

from sepelib.core.exceptions import Error, LogicalError
from walle.util.misc import first, drop_none

# Module-wide LRU cache of _FieldPathTrie objects, keyed by a sorted tuple of dotted field names.
# Filled by MongoDocument._hash_fields_impl and shared by every generated document wrapper class.
_FIELD_PATH_TRIE_CACHE = cachetools.LRUCache(maxsize=25)


def _set_dict_path(target, path, value):
    for token in path[:-1]:
        target = target.get(token, {})

    target[path[:-1]] = value


class DocumentDoesNotExist(Error):
    """Raised when no document of ``model`` matches the requested parameters."""

    def __init__(self, model):
        model_name = model.__name__
        super(DocumentDoesNotExist, self).__init__("{} with requested params does not exist.", model_name)


class MongoDocument:
    """Small wrapper that maps items of a pymongo document to the fields of a MongoEngine model.

    The raw document is stored exactly as it was received from pymongo; names are converted on the fly.
    Use :meth:`for_model` to obtain the wrapper class bound to a particular model.
    """

    _model = None  # MongoEngine model class this wrapper is bound to
    _model_fields = None  # {model field name: model field object}
    _model_fields_map = None  # {db field name: model field name}
    _key_field_name = None  # model field name of the primary key, taken from model._meta["id_field"]
    _id_field_name = "_id"  # key under which raw pymongo documents carry the primary key
    _api_fields = None  # _FieldPathTrie built from model.api_fields
    _default_api_fields = None  # _FieldPathTrie built from model.default_api_fields
    _resolved_fields = None  # per-wrapper memo for resolve_field(): dotted model name -> dotted db name
    _classes_cache = {}  # shared by all wrappers: generated class name -> wrapper class

    @classmethod
    def for_model(cls, model):
        """Return the wrapper class for MongoEngine ``model``, generating and caching it on first use.

        The generated class gets one read-only property per model field (see ``_generate_fields``)
        plus precomputed db-name maps and API field tries for the model.
        """
        # NOTE(review): the cache is keyed by the model's class name, so two models sharing a __name__
        # share one wrapper. Also, the class is cached only after _generate_fields has returned, so an
        # embedded document that (transitively) embeds itself appears to recurse without end.
        name = "Api{}Document".format(model.__name__)

        try:
            return cls._classes_cache[name]
        except KeyError:
            model_fields = dict(model._fields)
            doc = cls._classes_cache[name] = type(
                name,
                (MongoDocument,),
                dict(
                    _model=model,
                    _model_fields=model_fields,
                    _model_fields_map={field.db_field: field_name for field_name, field in model_fields.items()},
                    _key_field_name=model._meta.get("id_field", None),
                    _api_fields=cls._hash_fields_impl(tuple(sorted(getattr(model, "api_fields", [])))),
                    _default_api_fields=cls._hash_fields_impl(tuple(sorted(getattr(model, "default_api_fields", [])))),
                    _resolved_fields={},
                    **_generate_fields(model_fields)
                ),
            )
            return doc

    def __init__(self, data):
        # raw pymongo document (db field names)
        self._data = data

    @classmethod
    def find(cls, query=None, fields=None, sort=None, read_preference=None, limit=None):
        """Query the model's collection and return a generator of wrapped documents.

        :param query: pymongo filter document (None matches every document).
        :param fields: pymongo projection, in db field names (None fetches whole documents).
        :param sort: pymongo sort specification; filtered through ``drop_none`` (presumably omitted when None).
        :param read_preference: forwarded to the model's ``get_collection``.
        :param limit: maximum number of documents; filtered through ``drop_none`` like ``sort``.
        :raises LogicalError: if the class is not bound to a model (use ``for_model``).
        """
        if not cls._model:
            raise LogicalError

        collection = cls._model.get_collection(read_preference=read_preference)
        kwargs = drop_none(dict(sort=sort, limit=limit))

        # The cursor is created immediately; documents are wrapped lazily while the generator is consumed.
        return (cls(o) for o in collection.find(query, fields, **kwargs))

    @classmethod
    def find_and_modify(cls, query=None, update=None, sort=None, fields=None, read_preference=None):
        """Find one document matching ``query``, apply ``update`` to it and wrap the document pymongo returns.

        :returns: a wrapped document, or None when pymongo returned nothing (e.g. no match).
        :raises LogicalError: if the class is not bound to a model (use ``for_model``).
        """
        # NOTE(review): Collection.find_and_modify is deprecated since PyMongo 3.0 and removed in 4.0;
        # it will have to be replaced with find_one_and_update/find_one_and_delete on a driver upgrade.
        if not cls._model:
            raise LogicalError

        if not query:
            query = {}

        collection = cls._model.get_collection(read_preference=read_preference)
        kwargs = drop_none(dict(sort=sort))

        res = collection.find_and_modify(query=query, update=update, fields=fields, **kwargs)
        return cls(res) if res else None

    @classmethod
    def find_one(cls, query=None, fields=None, read_preference=None, silent=False, sort=None):
        """Fetch the first document matching ``query`` (in ``sort`` order when given).

        :returns: the wrapped document, or None when nothing matched and ``silent`` is true.
        :raises DocumentDoesNotExist: when nothing matched and ``silent`` is false.
        """
        # Only the first document is used, so don't let the cursor fetch more than one.
        res = first(cls.find(query, fields, read_preference=read_preference, sort=sort, limit=1))
        if res is None and not silent:
            raise DocumentDoesNotExist(cls._model)

        return res

    @classmethod
    def key_field(cls):
        """Return the model field name of the primary key (None when the model has no ``id_field``)."""
        return cls._key_field_name

    def fields(self):
        """Yield ``(model field name, value)`` for every db field of the document that maps to a model field.

        Db fields unknown to the model are skipped.
        """
        for db_field, value in self._data.items():
            if db_field in self._model_fields_map:
                yield self._model_fields_map[db_field], value

    @classmethod
    def _get_db_field_name(cls, field_name, model=False):
        """Return the db field name of model field ``field_name``.

        With ``model=True`` return ``(db field name, embedded document model or None)`` instead.

        :raises KeyError: if the model has no field called ``field_name``.
        """
        model_field = cls._model_fields[field_name]
        if model:
            return model_field.db_field, getattr(model_field, "document_type", None)
        else:
            return model_field.db_field

    @classmethod
    def _all_field_paths(cls, field_name, narrowers=None):
        """Yield every leaf field path (a tuple of model field names) reachable from ``field_name``.

        ``narrowers`` is the rest of a dotted path requested by the user. Each item selects exactly one
        child field of the embedded document on the way down. Unknown fields yield nothing, and so does
        narrowing into a field that is not an embedded document.
        """
        try:
            model_field = cls._model_fields[field_name]
        except KeyError:
            return

        model = getattr(model_field, "document_type", None)
        if model is None:
            # a leaf field cannot be narrowed any further
            if not narrowers:
                yield (field_name,)
            return

        doc_wrapper = cls.for_model(model)
        if narrowers:
            # descend into the single requested child and keep narrowing with the remaining segments
            child_field, rest = narrowers[0], narrowers[1:]
            for child_path in doc_wrapper._all_field_paths(child_field, rest):
                yield (field_name,) + child_path
        else:
            for child_field in doc_wrapper._model_fields:
                for child_path in doc_wrapper._all_field_paths(child_field):
                    yield (field_name,) + child_path

    @classmethod
    def fields_to_api_fields(cls, fields):
        """Convert list of fields as requested by user into list of fields that are explicitly allowed by API.
        e.g. if only host.task.status is allowed, then [host.task] will be converted into [host.task.status]
        """
        for requested in fields:
            head, *narrowers = requested.split(".")

            for path in cls._all_field_paths(head, narrowers):
                if path in cls._api_fields:
                    yield ".".join(path)

    @classmethod
    def fields_to_query_fields(cls, fields):
        """Convert list of fields as requested by user into list of fields that are explicitly allowed by API.
        e.g. if only host.task.status is allowed, then [host.task] will be converted into [host.task.status]
        Fields are resolved into a db-form, e.g. these fields should be used to query database with pymongo,
        which means that [host.inv] will be replaced with [host._id].
        """
        return [cls.resolve_field(field) for field in cls.fields_to_api_fields(fields)]

    @classmethod
    def resolve_field(cls, field_name):
        """Convert a dotted model field name into the real dotted db field name (memoized per wrapper).

        :raises KeyError: if a segment is not a field of the corresponding (embedded) model.
        """
        # NB: There are no workarounds or special cases yet, but you are free to add any as you come across.
        # NOTE(review): a path that continues past a non-embedded field fails with AttributeError on None.
        if field_name not in cls._resolved_fields:
            db_field_path = []
            doc_wrapper = cls
            for field in field_name.split("."):
                db_field, model = doc_wrapper._get_db_field_name(field, model=True)
                db_field_path.append(db_field)
                if model:
                    doc_wrapper = cls.for_model(model)
                else:
                    doc_wrapper = None
            cls._resolved_fields[field_name] = ".".join(db_field_path)

        return cls._resolved_fields[field_name]

    def reload(self, *fields):
        """Fetch ``fields`` (model field names; all fields when none given) from the database.

        The stored document is replaced with the fetched one.

        :returns: self, to allow chaining.
        :raises Error: if the document no longer exists in the database.
        """
        key_field_name = self._model._meta["id_field"]
        # Raw pymongo documents carry the primary key under "_id"; fall back to the model name of the key field.
        if self._id_field_name in self._data:
            key = self._data[self._id_field_name]
        else:
            key = self._data[key_field_name]

        db_fields = [self._get_db_field_name(f) for f in fields]

        # pymongo's find_one returns None (it does not raise) when nothing matches;
        # a non-dict filter is treated as an "_id" lookup.
        data = self._model.get_collection().find_one(key, db_fields or None)
        if data is None:
            raise Error("{} {} is missing from database.", self._model.__name__, key)

        self._data = data
        return self

    def to_api_obj(self, requested_fields=None, extra_fields=None):
        """Build a new dict for the API response from the stored document.

        Only requested (and API-allowed) fields are kept, ``_id`` is renamed into the key field name,
        and ``extra_fields`` are merged in at the end. The stored document is not modified.

        :param requested_fields: None for the model's default API fields, a prebuilt ``_FieldPathTrie``
            (used as-is), or an iterable of dotted field names (filtered against the allowed API fields).
        :param extra_fields: optional dict merged into the result; None values remove fields.
        """
        if requested_fields is None:
            requested_fields = self._default_api_fields
        elif not isinstance(requested_fields, _FieldPathTrie):
            requested_fields = self.hash_fields(requested_fields)

        api_obj = self._filter_fields_deep(self._data, requested_fields)

        if extra_fields:
            self._add_extra_fields(api_obj, extra_fields)

        return api_obj

    def to_api_object_shallow(self, requested_fields=None, extra_fields=None):
        """This is a fast but not safe variant of to_api_obj:
        only safe if you use `fields_to_query_fields` to fetch objects.

        Only top-level keys are filtered; nested values are copied as they are.
        """

        if requested_fields is None:
            requested_fields = self._default_api_fields.shallow()
        elif isinstance(requested_fields, _FieldPathTrie):
            requested_fields = requested_fields.shallow()
        else:
            requested_fields = self.hash_fields(requested_fields).shallow()

        api_obj = {key: value for key, value in self._data.items() if key in requested_fields}

        if self._id_field_name in self._data and self._key_field_name in requested_fields:
            api_obj[self._key_field_name] = self._data[self._id_field_name]

        if extra_fields:
            self._add_extra_fields(api_obj, extra_fields)

        return api_obj

    @classmethod
    def hash_fields(cls, fields):
        """Build (or reuse from the cache) a ``_FieldPathTrie`` of the API-allowed subset of ``fields``."""
        fields = tuple(sorted(cls.fields_to_api_fields(fields)))
        return cls._hash_fields_impl(fields)

    @staticmethod
    def _hash_fields_impl(fields):
        """Return the cached ``_FieldPathTrie`` for the tuple ``fields``, building it on a cache miss."""
        trie = _FIELD_PATH_TRIE_CACHE.get(fields, None)

        if trie is None:
            trie = _FieldPathTrie(fields)
            _FIELD_PATH_TRIE_CACHE[fields] = trie

        return trie

    @classmethod
    def _filter_fields_deep(cls, data, requested_fields, path=()):
        """Return a filtered copy of ``data`` with only the leaf values whose path is in ``requested_fields``.

        ``_id`` keys are renamed into the model's key field name. Empty lists/tuples are dropped, as are
        nested dicts left empty after filtering. ``data`` itself is never modified.
        """
        # speed-up: no renaming except for _id fields. copy only fields that are allowed.
        # we do not use this feature with db fields having different names than object fields.
        result = {}
        for field, value in data.items():
            if field == cls._id_field_name and cls._key_field_name is not None:
                field = cls._key_field_name
            field_path = path + (field,)

            if isinstance(value, dict):
                filtered = cls._filter_fields_deep(value, requested_fields, field_path)
                # don't check allowed fields for branches, only for leafs
                if filtered:
                    result[field] = filtered

            # filter out empty iterables (dicts are handled by the branch above)
            elif (not isinstance(value, (tuple, list)) or value) and field_path in requested_fields:
                result[field] = value

        return result

    @classmethod
    def _add_extra_fields(cls, data, extra):
        """Merge ``extra`` into ``data`` in place; a None value removes the field instead of setting it."""
        for field, value in extra.items():
            # extra fields allow to remove field value by replacing in with None.
            if value is None:
                data.pop(field, None)
            else:
                data[field] = value

    def __getitem__(self, item):
        """Return the raw value stored under db field name ``item``."""
        return self._data[item]


def _generate_fields(model_fields):
    fields = {}
    for field_name, model_field in model_fields.items():
        db_field = model_field.db_field
        embedded_doc_field_model = getattr(model_field, "document_type", None)

        if embedded_doc_field_model:
            _field_model = MongoDocument.for_model(embedded_doc_field_model)

            def get_field_value(self, _db_field=db_field, _field_model=_field_model):
                value = self._data.get(_db_field)
                if isinstance(value, dict):
                    return _field_model(value)
                else:
                    return value

        else:

            def get_field_value(self, _db_field=db_field):
                return self._data.get(_db_field)

        fields[field_name] = property(get_field_value)

    return fields


class _FieldPathTrie:
    """Use the modified "prefix trie" algorithm to check if requested field is allowed to be sent to the user.
    Modification to the original algorithm is: instead of storing characters in the nodes,
    this algorithm stores whole model field names (path segments),
    because those segments are the minimum chunks of a prefix, not the characters.
    This algorithm also does not use canonical way of hashing node values,
    instead, it stores trie in a standard python dict, which makes it's implementation much more transparent
    and which is fast enough for our purpose.
    """

    _STOP_MARKER = "."

    def __init__(self, fields):
        stop = self._STOP_MARKER
        trie = dict()
        for f in fields:
            branch = trie
            for token in f.split(stop):
                branch = branch.setdefault(token, {})
            branch[stop] = stop

        self._trie = trie

    def shallow(self):
        # produce top level fields for shallow filtering, see `MongoDocument.to_api_object_shallow`
        return set(self._trie)

    def __contains__(self, field_path):
        stop = self._STOP_MARKER

        branch = self._trie
        for token in field_path:
            if token in branch:
                branch = branch[token]
            elif stop in branch:
                # prefix of requested field is allowed
                return True
            else:
                # some other field with current prefix is allowed, but not this field.
                return False

        if stop in branch:
            # this field is allowed explicitly
            return True
        else:
            # this is a weird case: field is not a final path to a value, but is a prefix itself.
            # this may happen if, for example, some garbage has been received in requested_fields.
            return False
