import copy
import inspect
import functools
import itertools as it
import threading
import traceback

import six
from bson import json_util

import mongoengine
import mongoengine.queryset
import mongoengine.document
import mongoengine.connection

import sandbox.common.types.database as ctd

# Thread-local storage. This module reads `tls.request` (its `read_preference`, `id`,
# `user.login` and `timeout` attributes) when present and falls back to defaults when it
# or an attribute is missing; presumably request-handling code outside this module assigns it.
tls = threading.local()


class QuerySet(mongoengine.queryset.QuerySet):
    """
    Project queryset: mongoengine's `QuerySet` plus two lightweight result modes.

    * `fast_scalar(*fields)` -- results are raw pymongo values of the given fields: a single value
      per document for one field, a tuple of values for several fields.
    * `lite()` -- results are `LiteDocument` instances wrapping the raw pymongo dictionaries.

    Both modes survive cloning (see `_clone_into`) and are applied to single-document indexing
    (``qs[i]``) and to iteration (`next` / `__next__`).
    """

    def __init__(self, *args, **kwargs):
        super(QuerySet, self).__init__(*args, **kwargs)
        # Dot-separated db-field paths requested via `fast_scalar()`; None when the mode is off.
        self._fast_scalar = None
        # True when `lite()` mode is on; None otherwise.
        self._lite = None

    @staticmethod
    def _attr_lookup(item, path):
        """
        Fetch a dot-separated key path from a (nested) dictionary.

        Return None if at least one key on the path does not exist, or if an intermediate value
        cannot be indexed by a string key (e.g. a sub-document stored as null).
        """
        try:
            for key in path.split('.'):
                item = item[key]
            return item
        except (KeyError, TypeError):
            # TypeError: an intermediate value is None/scalar/list, i.e. the path does not exist.
            return None

    def _clone_into(self, new_qs):
        """
        Make sure `._fast_scalar` and `._lite` survive queryset cloning.
        """
        new_qs = super(QuerySet, self)._clone_into(new_qs)
        setattr(new_qs, "_fast_scalar", copy.copy(self._fast_scalar))
        setattr(new_qs, "_lite", copy.copy(self._lite))
        return new_qs

    def limit(self, n):
        """Limit the number of returned documents to `n`. This may also be
        achieved using array-slicing syntax (e.g. ``User.objects[:5]``).

        :param n: the maximum number of objects to return
        """
        queryset = self.clone()
        queryset._limit = n

        # If a cursor object has already been created, apply the limit to it.
        if queryset._cursor_obj:
            queryset._cursor_obj.limit(queryset._limit)

        return queryset

    def count(self, with_limit_and_skip=False):
        """
        Return the number of documents matched by the query.

        Copied from an old version of Mongoengine so that `.limit(0).count()` still returns the
        total number of objects in the collection: a zero limit only matters when
        `with_limit_and_skip` is set.

        :param with_limit_and_skip: take `limit()`/`skip()` into account
        :return: 0 for `.none()` querysets, or for a zero limit with `with_limit_and_skip`;
                 otherwise the cursor's count
        """
        if (self._limit == 0 and with_limit_and_skip) or self._none:
            return 0
        return self._cursor.count(with_limit_and_skip=with_limit_and_skip)

    def fast_scalar(self, *fields):
        """Faster drop-in replacement of .scalar() using raw pymongo values.
        Suitable for primitive datatypes only; embedded objects are returned as raw dictionaries.
        Nested fields may be given mongoengine-style (``"a__b"``); values missing from a document
        come back as None.

        :raises ValueError: if no fields are given
        """
        if not fields:
            raise ValueError("`fields` cannot be empty")

        queryset = self().only(*fields).as_pymongo()
        queryset._fast_scalar = self._fields_to_dbfields(f.replace('__', '.') for f in fields)
        return queryset

    def lite(self):
        """Return a clone whose results are `LiteDocument` wrappers around raw pymongo dictionaries."""
        queryset = self().as_pymongo()
        queryset._lite = True
        return queryset

    def _doc_to_tuple(self, doc):
        """Extract the `fast_scalar` fields from raw `doc`: a bare value for one field, else a tuple."""
        fields = self._fast_scalar
        if len(fields) == 1:
            return self._attr_lookup(doc, fields[0])
        return tuple(self._attr_lookup(doc, f) for f in fields)

    def __getitem__(self, key):
        result = super(QuerySet, self).__getitem__(key)
        # Slicing returns a new queryset (which inherits `_lite`/`_fast_scalar` via `_clone_into`
        # and converts its items when iterated); only a single-document lookup is converted here.
        if result is None or not isinstance(key, int):
            return result
        if self._lite:
            doc_class = LiteDocument.generate_field_lite_class(self._document)
            return doc_class(bson=result)
        if self._fast_scalar:
            return self._doc_to_tuple(result)
        return result

    def next(self):
        """Return the next result, converted according to the `lite()` / `fast_scalar()` mode."""
        parent = super(QuerySet, self)
        # Use the base `next` when it exists (Python 2 era API); otherwise fall back to the
        # Python 3 protocol method so the alias below cannot hit a missing attribute.
        parent_next = getattr(parent, "next", None)
        doc = parent_next() if parent_next is not None else parent.__next__()
        if doc is not None and self._lite:
            doc_class = LiteDocument.generate_field_lite_class(self._document)
            return doc_class(bson=doc)
        return self._doc_to_tuple(doc) if self._fast_scalar else doc

    # On Python 3 the builtin `next()` looks up `__next__`, not `next`; without this alias the
    # base class's `__next__` would be used and the conversions above would be skipped there.
    __next__ = next


class QuerySetManager(mongoengine.queryset.QuerySetManager):
    """
    Queryset manager whose fallback queryset class is this module's `QuerySet`
    (the one providing `fast_scalar()` and `lite()`), rather than mongoengine's stock class.
    """

    # Queryset class the manager uses when the document does not configure its own.
    default = QuerySet


def queryset_manager(func):
    """
    Build this module's `QuerySetManager` around `func`.

    Intended as a decorator for a document-level function taking ``(doc_cls, queryset)`` and
    returning a queryset, as `ConnectionSwitcherMixin.objects` does.
    """
    manager = QuerySetManager(func)
    return manager


class ConnectionSwitcherMixin(object):
    """
    Mongoengine document mixin: picks the database connection by the current request's read
    preference and tags every query made through `objects` with a comment describing the request.
    @DynamicAttrs
    """
    @classmethod
    def _get_collection(cls):
        """
        Return the collection for the document.

        The read preference is taken from `tls.request.read_preference` (falling back to
        `ctd.ReadPreference.PRIMARY` when there is no request or it lacks the attribute) and is
        passed to `mongoengine.connection.get_db()` as the connection alias. Collections are cached
        per read preference in the class attribute `_rp2collection`.
        """
        try:
            rp = tls.request.read_preference
        except AttributeError:
            rp = ctd.ReadPreference.PRIMARY
        # NOTE(review): `getattr` also finds a cache defined on a parent class, which the subclass
        # then shares (it assigns the same dict below) -- fine while both map to one collection.
        rp2coll = getattr(cls, '_rp2collection', {})
        coll = rp2coll.get(rp)
        # Compare with None explicitly: pymongo collections do not support truth testing
        # (`bool(collection)` raises in pymongo 4).
        if coll is None:
            coll = mongoengine.connection.get_db(rp)[cls._get_collection_name()]
            cls._rp2collection = rp2coll
            rp2coll[rp] = coll
        return coll

    @queryset_manager
    def objects(doc_cls, queryset):
        """
        Default manager: attach a query comment built from the current request (request id prefix,
        user login, calling file and line) and, when the request carries a timeout, limit server-side
        execution time with `max_time_ms`.

        All request data is best effort: anything missing is simply left out.
        """
        try:
            req_id = tls.request.id[:8]
        except (AttributeError, TypeError):  # no request, or an id that cannot be sliced (e.g. None)
            req_id = None
        try:
            user = tls.request.user.login
        except AttributeError:
            user = None
        try:
            # mongoengine expects integer timeout
            timeout = int(tls.request.timeout)
        except Exception:  # best effort: no request/timeout or a value not convertible to int
            timeout = None

        try:
            frame = traceback.extract_stack()[-3]  # Last two frames are always the same
        except KeyError:  # Sometimes extract_stack raises KeyError.
            caller = "undefined"
        else:
            caller = "{}:{}".format(frame[0], frame[1])

            # Strip common prefix to make comment message shorter
            pos = caller.rfind("/sandbox/")
            if pos > 0:
                caller = caller[pos + 1:]

        # Request-Id should be always the first,
        # because comment can be stripped sometimes (e.g. mongodb logs).
        values = [
            ("Reqid", req_id),
            ("User", user),
            ("Caller", caller),
        ]
        result = queryset.comment(";".join("{}={}".format(k, v) for k, v in values if v is not None))
        if timeout:
            result = result.max_time_ms(1000 * timeout)
        return result


#: Field type for a reference to another object by its integer ID.
#: It is a plain `mongoengine.IntField` restricted to positive values (``min_value=1``):
#: unlike `mongoengine.ReferenceField` it stores just the number and does no dereferencing.
ReferenceField = functools.partial(mongoengine.IntField, min_value=1)


class Aggregatable(object):
    """
    Mixin that adds an `aggregate` classmethod to document classes.

    It is mixed into documents that need aggregation instead of being inherited from
    `mongoengine.Document` itself. The host class must provide a `_get_collection()`
    classmethod, as `mongoengine.Document` does.
    """

    # NOTE: there is deliberately no `_get_collection = None` placeholder here. A class-level
    # placeholder would shadow the real classmethod whenever this mixin is listed before the class
    # that provides it, e.g. `class Doc(Aggregatable, mongoengine.Document)`.

    @classmethod
    def aggregate(cls, pipeline, **kws):
        """
        Run an aggregation pipeline on the document's collection and return all result documents.

        The pipeline and keyword arguments are passed unchanged to pymongo's
        `collection.aggregate()`, and the returned cursor is drained into a list. No extra error
        checking is done here: exceptions raised by pymongo propagate to the caller.

        :param pipeline:    MongoDB aggregation pipeline (list of stages).
        :param kws:         Optional aggregate parameters, passed as keyword arguments.
        :return:            List of result documents.
        """
        collection = cls._get_collection()
        return list(collection.aggregate(pipeline, **kws))


class LiteDocument(object):
    """
    Lightweight stand-in for a mongoengine document that works directly on the raw pymongo dictionary.

    Concrete subclasses are produced by `generate_field_lite_class()` for a mongoengine document
    class and carry it in `__document_type__`; that class's `_fields` and `_db_field_map` drive
    attribute access:

    * reading a field converts the raw value on first access and caches the result in `_cache`
      (embedded documents become lite documents themselves, reference fields are fetched with
      `.objects.lite().with_id()`);
    * assigning a field stores its mongo representation in `_bson` and marks the db field name as
      changed (see `_get_changed_fields()`).

    Names that are not document fields behave like ordinary instance attributes.
    """

    def __init__(self, *args, **kwargs):
        """
        :param bson: raw document dictionary to wrap (keyword only); used as is, not copied.
                     Defaults to a new empty dictionary.
        :param kwargs: remaining keyword arguments are field values keyed by field name; a value that
                       has a `to_mongo()` method is converted with it before being stored.
        Positional arguments are accepted and ignored.
        :raises KeyError: for a keyword argument that is not a field of the document type
        """
        self._bson = kwargs.pop("bson", {})
        self._changed_fields = {}
        self._cache = {}
        document_type = self.__document_type__
        # Name of the field behind `document_type.id` (used as the primary key below),
        # or None when the document type has no `id` attribute.
        self._pk = getattr(document_type, "id", None)
        if self._pk is not None:
            self._pk = document_type.id.name
        for name, value in six.iteritems(kwargs):
            db_name = document_type._db_field_map[name]
            to_mongo = getattr(value, "to_mongo", None)
            # `to_mongo` is already bound to `value`, so call it without arguments
            # (passing `value` again breaks `LiteDocument.to_mongo()`, which takes none).
            self._bson[db_name] = to_mongo() if to_mongo is not None else value

    @classmethod
    def _from_son(cls, bson, _auto_dereference=True, only_fields=None, created=False):
        """
        Wrap raw document `bson` (not copied). Mirrors the signature of mongoengine's `_from_son`,
        presumably so it can be used in its place; every argument except `bson` is ignored.
        """
        new_class = cls(bson=bson)
        return new_class

    def to_mongo(self):
        """Return the wrapped raw dictionary itself (not a copy), including assigned values."""
        return self._bson

    def to_json(self):
        """Serialize the raw dictionary with `bson.json_util.dumps`."""
        return json_util.dumps(self.to_mongo())

    @classmethod
    def from_json(cls, json_data, created=False):
        """Build an instance from JSON produced by `to_json()`; `created` is ignored."""
        return cls._from_son(json_util.loads(json_data), created=created)

    @property
    def pk_value(self):
        """
        Raw stored value of the primary-key field.

        :raises KeyError: if the value is missing from `_bson` or the document type has no primary key
        """
        doc_type = type(self).__document_type__
        return self._bson[doc_type._db_field_map[self._pk]]

    def reload(self):
        """Re-read the raw document from the database by primary key and drop cached/changed state."""
        doc_type = type(self).__document_type__
        self._bson = doc_type.objects.as_pymongo().with_id(self.pk_value)
        self._changed_fields = {}
        self._cache = {}

    def delete(self):
        """Delete the database document whose primary key equals this document's."""
        doc_type = type(self).__document_type__
        doc_type.objects(**{self._pk: self.pk_value}).delete()

    def __repr__(self):
        return "<{}> | {}".format(type(self), self._bson)

    def __str__(self):
        return "<{}> | {}".format(type(self), self._bson)

    def _get_changed_fields(self):
        """
        Return the dotted db-field paths of changed fields.

        Includes fields assigned on this document and, for embedded lite documents read through it,
        the fields changed inside them (their `_changed_fields` dicts are nested into ours by
        `update_cache_value`). The order of the returned paths is not meaningful.
        """
        changed_fields = []
        stack = []
        stack.append(("", self._changed_fields))

        while stack:
            name, element = stack.pop()
            if isinstance(element, dict):
                for k, v in six.iteritems(element):
                    stack.append((".".join((name, k)), v))
            else:
                if name:
                    # Every path starts with "." (joined onto the empty root name); drop it.
                    changed_fields.append(name[1:])
        return changed_fields

    def _update_bson(self, key, name, value):
        """
        Store `value` under db field `name`, mark it changed and invalidate the cache for field `key`.

        A None value is cached as None right away; any other value is re-read from `_bson` on next access.
        """
        self._changed_fields[name] = True
        self._bson[name] = value
        if value is None:
            self._cache[key] = None
        else:
            self._cache.pop(key, None)

    def __setattr__(self, key, value):
        """
        Assign field `key` by storing its mongo representation in `_bson` and marking it changed;
        names that are not fields are set as ordinary instance attributes. `id` maps to the
        primary-key field.
        """
        doc_type = type(self).__document_type__
        if key == "id":
            key = self._pk or "id"
        name = doc_type._db_field_map.get(key)

        if name is None:
            super(LiteDocument, self).__setattr__(key, value)
            return

        field_class = doc_type._fields[key]
        if value is None:
            self._update_bson(key, name, value)
            return

        if isinstance(field_class, mongoengine.BinaryField):
            self._update_bson(key, name, mongoengine.BinaryField().to_mongo(value))
            return

        if isinstance(field_class, mongoengine.EmbeddedDocumentField):
            self._update_bson(key, name, value.to_mongo())
            return

        if (
            isinstance(field_class, mongoengine.ListField) and
            isinstance(field_class.field, mongoengine.EmbeddedDocumentField)
        ):
            self._update_bson(key, name, [doc.to_mongo() for doc in value])
            return

        if (
            isinstance(field_class, mongoengine.MapField) and
            isinstance(field_class.field, mongoengine.EmbeddedDocumentField)
        ):
            self._update_bson(key, name, {k: v.to_mongo() for k, v in six.iteritems(value)})
            return

        if isinstance(field_class, mongoengine.ReferenceField):
            # Store only the referenced document's id, not the document object itself.
            self._update_bson(key, name, value.id)
            return

        self._update_bson(key, name, value)

    def update_cache_value(self, item, key, value, embedded=False):
        """
        Cache `value` for field `item` and return it.

        For an embedded lite document (`embedded=True`), its `_changed_fields` dict is nested under
        db field `key` -- unless `key` is already tracked -- so its changes show up in
        `_get_changed_fields()`.
        """
        self._cache[item] = value
        if embedded and key not in self._changed_fields:
            self._changed_fields[key] = value._changed_fields
        return value

    def get_bson_value(self, item):
        """Return the raw stored value of field `item`, or None if it is not a field or not stored."""
        key = type(self).__document_type__._db_field_map.get(item)
        if key:
            return self._bson.get(key)
        return None

    def __getattribute__(self, item):
        """
        Return field `item` converted from `_bson` (cached after the first access), or an ordinary
        attribute when `item` is not a field. `id` maps to the primary-key field.
        """
        cache = object.__getattribute__(self, "_cache")
        if item in cache:
            return cache[item]

        doc_type = type(self).__document_type__
        bson = object.__getattribute__(self, "_bson")
        if item == "id":
            item = object.__getattribute__(self, "_pk") or "id"
        key = doc_type._db_field_map.get(item)
        if key is None:
            return object.__getattribute__(self, item)

        field_class = doc_type._fields[item]
        if key not in bson:
            # Field not stored: fall back to the field's default (cached, not written to `_bson`).
            default = field_class.default

            # NOTE(review): this tests the *default value* for being an EmbeddedDocumentField, which
            # is unusual; `field_class` may have been intended -- confirm before changing.
            if isinstance(default, mongoengine.EmbeddedDocumentField):
                doc_lite_class = self.generate_field_lite_class(default.document_type)
                return self.update_cache_value(item, key, doc_lite_class._from_son({}), embedded=True)

            if callable(default):
                return self.update_cache_value(item, key, default())

            return self.update_cache_value(item, key, default)

        value = bson[key]

        # A stored null stays None without conversion (mongoengine's own `_from_son` likewise skips
        # `to_python` for None); converting it would wrap None into an unusable lite document,
        # fail on list/map iteration, or e.g. invent a new ObjectId.
        if value is None:
            return self.update_cache_value(item, key, None)

        if isinstance(field_class, mongoengine.EmbeddedDocumentField):
            doc_lite_class = self.generate_field_lite_class(field_class.document_type)
            return self.update_cache_value(item, key, doc_lite_class._from_son(value), embedded=True)

        if isinstance(field_class, mongoengine.ListField):
            list_field_class = field_class.field
            if isinstance(list_field_class, mongoengine.EmbeddedDocumentField):
                doc_lite_class = self.generate_field_lite_class(list_field_class.document_type)
                return self.update_cache_value(item, key, list(map(doc_lite_class._from_son, value)))
            else:
                return self.update_cache_value(item, key, value)

        if isinstance(field_class, mongoengine.MapField):
            map_field_class = field_class.field
            if isinstance(map_field_class, mongoengine.EmbeddedDocumentField):
                doc_lite_class = self.generate_field_lite_class(map_field_class.document_type)
                return self.update_cache_value(
                    item, key, {k: doc_lite_class._from_son(v) for k, v in six.iteritems(value)}
                )
            else:
                return self.update_cache_value(item, key, value)

        if isinstance(field_class, mongoengine.ReferenceField):
            return self.update_cache_value(item, key, field_class.document_type.objects.lite().with_id(value))

        return self.update_cache_value(item, key, field_class.to_python(value))

    def __hasattr__(self, item):
        """
        Whether `item` is a field of the document type.

        NOTE: Python's built-in `hasattr()` does not call `__hasattr__`; it must be invoked explicitly.
        """
        return item in type(self).__document_type__._db_field_map

    def __len__(self):
        """Number of fields defined by the document type (not of values present in `_bson`)."""
        return len(type(self).__document_type__._db_field_map)

    @classmethod
    def generate_field_lite_class(cls, field_class):
        """
        Return the lite class for mongoengine document class `field_class`, creating it on first use.

        The generated class subclasses `cls`, has `__document_type__ = field_class`, and receives the
        document class's attributes selected by `need_to_copy` whose names are not also defined on
        `cls`, `mongoengine.Document` or `ConnectionSwitcherMixin`. It is cached on the document
        class as `__lite_class__`; the identity check makes a subclass build its own lite class
        instead of reusing one inherited from its parent.
        """
        doc_lite_class = getattr(field_class, "__lite_class__", None)
        if doc_lite_class is not None and field_class is doc_lite_class.__document_type__:
            return doc_lite_class

        overridden = set(it.chain(dir(cls), dir(mongoengine.Document), dir(ConnectionSwitcherMixin)))
        namespace_methods = set(dir(field_class)) - overridden

        # NOTE(review): `need_to_copy` relies on `inspect.ismethod`; on Python 3 plain functions looked
        # up on a class are not methods, so regular instance methods are not copied there (bound
        # classmethods, properties and classes are). Confirm whether that is intended.
        namespace = {
            name: cls.convert_field(getattr(field_class, name))
            for name in namespace_methods
            if cls.need_to_copy(getattr(field_class, name))
        }
        namespace["__document_type__"] = field_class
        doc_lite_class = field_class.__lite_class__ = type("Lite" + str(field_class), (cls, ), namespace)
        return doc_lite_class

    @staticmethod
    def need_to_copy(item):
        """Whether a document-class attribute is copied into the lite class: methods, properties, classes."""
        return inspect.ismethod(item) or isinstance(item, property) or inspect.isclass(item)

    @classmethod
    def convert_field(cls, item):
        """Unwrap a method to its underlying function (so it rebinds to lite instances); return others as is."""
        if inspect.ismethod(item):
            return item.__func__
        else:
            return item
