"""Various MongoDB utilities"""

from __future__ import unicode_literals

import ssl
import pymongo
import pymongo.errors
import mongoengine.connection
import six

from mongoengine import OperationError, NotUniqueError, InvalidDocumentError, InvalidQueryError, BooleanField, QuerySet
from mongoengine.fields import ComplexBaseField, StringField
from mongoengine.connection import get_db
from mongoengine.document import Document, BaseDocument
from mongoengine.queryset import transform
from mongoengine.queryset.base import BaseQuerySet
from mongoengine.common import _import_class

from sepelib.core import config
from sepelib.core.exceptions import Error

# Document classes registered with register_model(), kept in registration order.
_MODELS = []

# Flipped to True by patch() once MongoEngine has been monkeypatched.
_PATCHED = False

# True when running with PyMongo 2.x, whose connection options differ from PyMongo 3+.
IS_PYMONGO_2 = int(pymongo.version.partition('.')[0]) < 3


def register_model(cls):
    """Class decorator: add *cls* to the model registry consumed by :func:`ensure_all_indexes`.

    The class itself is returned untouched, so the decorator can be stacked on a model definition.
    """
    _MODELS.append(cls)
    return cls


def get_registered_models():
    """Returns a list of all models registered with :func:`register_model`, in registration order.

    A fresh list is returned on every call so that callers cannot accidentally add to or remove from the
    registry itself (use :func:`register_model` to register new models).
    """

    return list(_MODELS)


def register_database(config_name, alias=mongoengine.connection.DEFAULT_CONNECTION_NAME,
                      use_greenlets=True, tz_aware=False,
                      read_preference=pymongo.ReadPreference.PRIMARY_PREFERRED, connect=True):
    """Registers MongoEngine database according to the specified configuration.

    The following keys are read from the configuration under the ``config_name`` prefix:

    * ``uri`` - MongoDB connection string (no default);
    * ``db`` - database name (no default);
    * ``ssl_ca_certs`` - optional CA bundle path; when set, the server certificate is required
      (``ssl.CERT_REQUIRED``) and checked against it;
    * ``connect_timeout`` - connection timeout in seconds (default 3);
    * ``socket_timeout`` - socket send/receive timeout in seconds (default 30);
    * ``max_pool_size`` - connection pool size (default 100).

    :param config_name: configuration prefix of the keys listed above
    :param alias: MongoEngine connection alias to register the connection under
    :param use_greenlets: passed to the driver with PyMongo 2 only; ignored with PyMongo 3+
    :param tz_aware: passed through to the connection
    :param read_preference: read preference of the connection
    :param connect: if true, open the connection right away so that errors surface here
    :raises Error: if the connection cannot be registered or opened; credentials are masked in the message
    """

    uri = config.get_value(config_name + ".uri")
    name = config.get_value(config_name + ".db")
    ssl_ca_certs = config.get_value(config_name + ".ssl_ca_certs", None)

    # How long a connection can take to be opened before timing out
    connect_timeout = config.get_value(config_name + ".connect_timeout", 3)

    # How long a send or receive on a socket can take before timing out
    socket_timeout = config.get_value(config_name + ".socket_timeout", 30)

    # PyMongo max_pool_size tuning, defaults to 100
    max_pool_size = config.get_value(config_name + ".max_pool_size", 100)

    try:
        connect_options = {"host": uri, "name": name, "connectTimeoutMS": connect_timeout * 1000,
                           "socketTimeoutMS": socket_timeout * 1000, "alias": alias, "tz_aware": tz_aware,
                           "read_preference": read_preference}

        if ssl_ca_certs is not None:
            connect_options["ssl_ca_certs"] = ssl_ca_certs
            connect_options["ssl_cert_reqs"] = ssl.CERT_REQUIRED

        # The pool size option was renamed in PyMongo 3, and greenlet support only exists in PyMongo 2.
        if IS_PYMONGO_2:
            connect_options["max_pool_size"] = max_pool_size
            connect_options["use_greenlets"] = use_greenlets
        else:
            connect_options["maxPoolSize"] = max_pool_size

        mongoengine.connection.register_connection(**connect_options)

        if connect:
            mongoengine.connection.get_connection(alias=alias)
    except Exception as e:
        # The URI - and possibly the driver's error text - may embed "user:password@" credentials;
        # mask them so they don't end up in logs and tracebacks.
        raise Error("Unable to connect to MongoDB ({}): {}",
                    _redact_uri(uri), _redact_uri(six.text_type(e)).replace("\n", " "))


def _redact_uri(uri):
    """Return *uri* with the password of every ``scheme://user:password@`` credential replaced by ``***``.

    Values that are not strings are returned unchanged.
    """
    import re

    if not isinstance(uri, six.string_types):
        return uri

    # ":", "/" and "@" must be percent-encoded inside connection string credentials, so they unambiguously
    # delimit the user name and the password.
    return re.sub(r"://([^/@:\s]*):[^/@\s]*@", r"://\1:***@", uri)


def ensure_all_indexes():
    """Create the indexes of every model that was registered via register_model().

    Models are processed in registration order. Call this on service deployment or startup, once all
    MongoEngine models have been registered.
    """
    for model_cls in _MODELS:
        model_cls.ensure_indexes()


def is_patched():
    """Tell whether MongoEngine has already been monkeypatched.

    :returns: ``True`` once :func:`patch` has been called, ``False`` before that.
    """
    return _PATCHED


def patch():
    """Monkeypatch MongoEngine; only the first call has any effect.

    * ``modify()`` is added to ``BaseQuerySet`` and ``Document`` (find_and_modify() support), but only where
      MongoEngine doesn't already provide a ``modify`` attribute;
    * ``BaseDocument._clear_changed_fields`` is always replaced so that sets can be used as field values.
    """
    global _PATCHED

    if not _PATCHED:
        for target, implementation in ((BaseQuerySet, _modify), (Document, _modify_document)):
            # Never shadow a native implementation from a newer MongoEngine.
            if not hasattr(target, "modify"):
                target.modify = implementation

        BaseDocument._clear_changed_fields = _clear_changed_fields
        _PATCHED = True


def _clear_changed_fields(self):
    """Replacement for ``BaseDocument._clear_changed_fields()``, installed by :func:`patch`.

    Resets the document's change tracking, then does the same recursively for embedded documents stored either
    directly in an embedded-document field or inside a container field (list, dict, ...) of embedded documents.

    The upstream implementation reaches container elements by index, which fails when the value is a set; this
    version iterates over the elements (or over the values, for dicts) instead.
    """
    self._changed_fields = []
    embedded_field_cls = _import_class("EmbeddedDocumentField")

    for name, field in six.iteritems(self._fields):
        holds_embedded_docs = isinstance(field, ComplexBaseField) and isinstance(field.field, embedded_field_cls)

        if holds_embedded_docs:
            container = getattr(self, name, None)
            if not container:
                continue
            children = six.itervalues(container) if isinstance(container, dict) else container
            for child in children:
                child._clear_changed_fields()
        elif isinstance(field, embedded_field_cls):
            embedded = getattr(self, name, None)
            if embedded:
                embedded._clear_changed_fields()


def _modify(self, upsert=False, sort=None, full_response=False, remove=False, new=False, **update):
    """PyMongo's find_and_modify() for querysets (installed as ``BaseQuerySet.modify`` by :func:`patch`).

    Atomically updates - or, with ``remove``, deletes - a single document matching the queryset.

    :param upsert: passed to find_and_modify(): insert a document if none matches
    :param sort: sort specification passed to find_and_modify()
    :param full_response: return the raw command response instead of the document
    :param remove: remove the matched document instead of updating it; can't be combined with ``new`` or
        with update arguments
    :param new: passed to find_and_modify(): return the document as it is after the modification
    :param update: Django-style update keyword arguments (transformed with MongoEngine's ``transform.update``)
    :returns: the matched document converted to the queryset's document class, or ``None`` if nothing matched;
        with ``full_response``, the response dict whose ``"value"`` is converted the same way
    :raises OperationError: on conflicting or missing parameters, or if the server rejects the command
    :raises NotUniqueError: if the modification causes a duplicate key error
    """

    if remove and new:
        raise OperationError("Conflicting parameters: remove and new")

    # PyMongo rejects this combination too, but with a bare ValueError; report it like the other conflicts.
    if remove and update:
        raise OperationError("Conflicting parameters: remove and update")

    if not update and not upsert and not remove:
        raise OperationError("No update parameters, must either update or remove")

    queryset = self.clone()
    query = queryset._query
    update = transform.update(queryset._document, **update)

    try:
        result = queryset._collection.find_and_modify(
            query, update, upsert=upsert, sort=sort, remove=remove, new=new,
            full_response=full_response, **self._cursor_args)
    # DuplicateKeyError is a subclass of OperationFailure, so it has to be caught first.
    except pymongo.errors.DuplicateKeyError as e:
        raise NotUniqueError("find_and_modify() failed ({})".format(e))
    except pymongo.errors.OperationFailure as e:
        raise OperationError("find_and_modify() failed ({})".format(e))

    if full_response:
        if result["value"] is not None:
            result["value"] = self._document._from_son(result["value"])
    else:
        if result is not None:
            result = self._document._from_son(result)

    return result


def _modify_document(self, query={}, **update):
    """Perform an atomic update of the document in the database and reload the document object using updated version.

    Returns True if the document has been updated or False if the document in the database doesn't match the query.

    .. note:: All unsaved changes that has been made to the document are rejected if the method returns True.

    :param query: the update will be performed only if the document in the database matches the query
    :param update: Django-style update keyword arguments
    """

    if self.pk is None:
        raise InvalidDocumentError("The document does not have a primary key.")

    id_field = self._meta["id_field"]
    query = query.copy() if isinstance(query, dict) else query.to_query(self)

    if id_field not in query:
        query[id_field] = self.pk
    elif query[id_field] != self.pk:
        raise InvalidQueryError("Invalid document modify query: it must modify only this document.")

    updated = self._qs(**query).modify(new=True, **update)
    if updated is None:
        return False

    for field in self._fields_ordered:
        setattr(self, field, self._reload(field, updated[field]))

    self._changed_fields = updated._changed_fields
    self._created = False

    return True


class UndeletedQuerySet(QuerySet):
    """QuerySet whose initial query only matches documents with ``is_deleted`` set to False.

    Used by :class:`SoftDeleteMixin` to hide soft-deleted documents.
    """

    def __init__(self, document, collection):
        super(UndeletedQuerySet, self).__init__(document, collection)
        # Added to the queryset's initial query (presumably merged into every query by MongoEngine).
        self._initial_query.update({'is_deleted': False})


class SoftDeleteMixin(object):
    """Document mixin for flagging documents as deleted instead of removing them.

    Documents get an ``is_deleted`` flag and use :class:`UndeletedQuerySet` as their queryset class, so that
    flagged documents are excluded from all the queries.
    """

    # Soft-deletion flag; documents start out as not deleted.
    is_deleted = BooleanField(required=True, default=False)

    meta = {
        'queryset_class': UndeletedQuerySet,
    }

    def soft_delete(self):
        """Flag the document as deleted.

        Only the in-memory attribute is changed; the document is not saved here.
        """
        self.is_deleted = True


class _SequentialIdField(StringField):
    """String primary key field whose value is assigned by :class:`SequentialIdMixin`.

    A primary key is required, and model validation only passes without one if the field knows how to generate
    itself, so this field claims to be self-generating. It can't really build the id - that needs the document
    instance - so :meth:`generate` returns the :attr:`UNSET` placeholder and the actual id is produced by
    ``SequentialIdMixin.to_mongo()``.
    """

    # Placeholder returned by generate(); a private object() can't be mistaken for a real id.
    UNSET = object()

    # Pretend we can generate the value on the fly (see the class docstring).
    _auto_gen = True

    def generate(self):
        """Return :attr:`UNSET` rather than ``None`` so MongoEngine doesn't generate an id of its own."""
        return self.UNSET


class SequentialIdMixin(object):
    """Document mixin providing string primary keys of the form ``<discriminator><n>``.

    ``n`` comes from a counter that is atomically incremented (and created on first use) in the
    ``<collection>.counters`` collection, one counter per discriminator. Subclasses must implement
    :meth:`get_id_discriminator`.
    """

    id = _SequentialIdField(primary_key=True)

    def get_id_discriminator(self):
        """Return the prefix for this document's generated id. Must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def _generate_next_id(cls, prefix):
        """Atomically increment the counter of *prefix* and return its new value as an int."""
        collection = cls._meta['collection']
        counters_collection_name = '{}.counters'.format(collection)
        # NOTE(review): get_db() uses the default connection alias even if the model sets meta['db_alias'], so
        # counters live in the default database. Switching to cls._get_db() would restart existing counters -
        # that would need a data migration first.
        counters_collection = get_db()[counters_collection_name]
        sequence_id = '{}.{}'.format(collection, prefix)
        query = {'_id': sequence_id}
        update = {'$inc': {'next': 1}}

        if IS_PYMONGO_2:
            counter = counters_collection.find_and_modify(query=query, update=update, new=True, upsert=True)
        else:
            # find_and_modify() is deprecated since PyMongo 3.0 and removed in PyMongo 4.0.
            counter = counters_collection.find_one_and_update(
                query, update, upsert=True, return_document=pymongo.ReturnDocument.AFTER)

        return int(counter['next'])

    def to_mongo(self, *args, **kwargs):
        """Serialize the document, assigning a freshly generated sequential id if it doesn't have one yet.

        Any arguments are forwarded to the parent implementation, so the override stays compatible with
        MongoEngine versions whose ``to_mongo()`` accepts options.
        """
        data = super(SequentialIdMixin, self).to_mongo(*args, **kwargs)
        # _SequentialIdField.generate() yields the UNSET placeholder: fall back to an 'id' key in the serialized
        # data, if any, and generate a new id when there is none.
        if data['_id'] is _SequentialIdField.UNSET:
            data['_id'] = data.get('id', None)
        if data['_id'] is None:
            discriminator = self.get_id_discriminator()
            data['_id'] = '{}{}'.format(discriminator, self._generate_next_id(discriminator))
        return data
