from pymongo import MongoClient
from pymongo.read_preferences import ReadPreference

from django.conf import settings


class MongoConnection(object):
    """Lazily created handle to the MongoDB database configured in settings.

    No connection is made at construction time. The client is built from
    Django settings on the first access of :attr:`db`. The resulting
    database object is cached on the instance and returned on later calls.
    """

    # NOTE: login/password are interpolated verbatim. The MongoDB connection
    # string format requires reserved characters (e.g. ``@``, ``:``, ``/``)
    # in credentials to be percent-escaped, so such values must already be
    # escaped in settings.
    _URI_TEMPLATE = (
        'mongodb://{login}:{password}@{hosts}/{db}'
        '?replicaSet={rs_name}'
        '&connectTimeoutMS={connect_timeout}'
        # ``readPreference`` is the connection-string option name. The
        # previous ``read_preference`` is not a URI option and was ignored.
        '&readPreference={read_preference}'
    )

    def __init__(self):
        self._db = None

    @staticmethod
    def _build_uri(login, password, hosts, db, rs_name, connect_timeout,
                   read_preference):
        """Return the ``mongodb://`` connection string for the given parts."""
        return MongoConnection._URI_TEMPLATE.format(
            login=login,
            password=password,
            hosts=hosts,
            db=db,
            rs_name=rs_name,
            connect_timeout=connect_timeout,
            read_preference=read_preference,
        )

    @property
    def db(self):
        """Return the configured database, connecting on first access."""
        # Compare with None explicitly: PyMongo 4 Database objects raise on
        # truth-value testing, so ``not self._db`` would fail once cached.
        if self._db is None:
            uri = self._build_uri(
                login=settings.MONGO_USER,
                password=settings.MONGO_PASSWORD,
                hosts=settings.MONGO_HOSTS,
                db=settings.MONGO_DATABASE,
                rs_name=settings.MONGO_SET_NAME,
                connect_timeout=settings.MONGO_TIMEOUT,
                read_preference=ReadPreference.PRIMARY_PREFERRED.mongos_mode,
            )
            client = MongoClient(uri)
            self._db = client[settings.MONGO_DATABASE]

        return self._db


# Module-level shared instance. No connection is opened here; the database
# handle is created on the first access to ``mongo.db``.
mongo = MongoConnection()


class CollectionObjectGenerator(object):
    """Lazy, sliceable, sized view over the documents matched by a query.

    Documents are fetched with ``collection.find(q, ...)`` only when the
    object is iterated or indexed. Slicing maps onto ``skip``/``limit``, and
    ``len()`` counts the matching documents.

    :param collection: object exposing ``find`` (presumably a PyMongo
        ``Collection`` -- confirm against callers).
    :param q: query filter passed to ``find``.
    :param wraper: optional callable applied to every returned document.
    """

    def __init__(self, collection, q, wraper=None):
        self.q = q
        self.collection = collection
        self._sort = None
        self.wraper = wraper

    def iterator(self, skip=0, limit=0):
        """Return an iterator over the matching documents.

        :param skip: number of leading documents to skip.
        :param limit: maximum number of documents; ``0`` means no limit.
        :returns: the result of ``find``, or a generator applying
            ``self.wraper`` to each document when a wrapper is set.
        """
        result = self.collection.find(
            self.q,
            sort=self._sort,
            skip=skip,
            limit=limit,
        )
        if self.wraper:
            result = (self.wraper(obj) for obj in result)
        return result

    def __getitem__(self, key):
        """Return an iterator over a window of the matching documents.

        * ``start:stop`` -- ``start`` defaults to 0; an omitted ``stop``
          means "to the end". Any slice step is ignored.
        * non-negative int ``i`` -- an iterator yielding at most the single
          document at position ``i`` (not the document itself).

        :raises KeyError: for ``[:]``, a negative start or index,
            ``stop < start``, or any other key type.
        """
        if isinstance(key, slice):
            start, stop = key.start, key.stop
            # A fully open slice is rejected explicitly.
            if start is None and stop is None:
                raise KeyError()
            if start is None:
                start = 0
            if start < 0 or (stop is not None and stop < start):
                raise KeyError()
            if stop is None:
                limit = 0  # no limit: read to the end
            elif stop == start:
                # limit=0 would mean "no limit", so an empty window must not
                # reach find().
                return iter(())
            else:
                limit = stop - start
            return self.iterator(skip=start, limit=limit)

        if isinstance(key, int):
            # Positions are 0-based (skip=key), so 0 is a valid index.
            if key < 0:
                raise KeyError()
            return self.iterator(skip=key, limit=1)

        raise KeyError()

    def __iter__(self):
        return self.iterator()

    def __len__(self):
        """Return the number of documents matching ``self.q``."""
        # Cursor.count() was deprecated in PyMongo 3.7 and removed in 4.0.
        # Prefer count_documents() when available. It is looked up on the
        # type because PyMongo's Collection.__getattr__ turns unknown
        # attribute names into sub-collections.
        if getattr(type(self.collection), 'count_documents', None) is None:
            return self.collection.find(self.q).count()
        # count_documents() needs a mapping filter, whereas find() accepts None.
        return self.collection.count_documents(
            self.q if self.q is not None else {}
        )

    def sort(self, key):
        """Set the ``sort`` spec passed to ``find`` and return ``self``."""
        self._sort = key
        return self
