import pymongo
import pymongo.errors
import time
import logging

# Module-level logger used by the retry wrappers and by connection setup.
log = logging.getLogger(__name__)

# Names of the databases managed by this module.  Each name is used both as
# the key into ``connections`` and as the key of that database's section in
# the config mapping passed to ``configure()``.
CACUS = 'cacusdb'
REPOS = 'reposdb'
HISTORY = 'histdb'

# Every database that ``configure()`` sets up, in configuration order.
DATABASES = [CACUS, REPOS, HISTORY]
# Total number of attempts (not additional retries) made when a call fails
# with pymongo's AutoReconnect, before AutoReconnectMaxRetries is raised.
DB_RETRIES = 5

# Registry: database name -> MongoDatabaseConnection, populated by configure().
connections = {}


class AutoReconnectMaxRetries(Exception):
    """Raised when a MongoDB call keeps failing with AutoReconnect after all attempts."""


class GetAttrWrapper(object):
    """Proxy that retries method calls on the wrapped object after AutoReconnect.

    Every callable attribute fetched through this proxy is returned wrapped in
    a function that calls it up to ``DB_RETRIES`` times.  Between failed
    attempts it sleeps ``attempt ** 3`` seconds (0, 1, 8, 27, ...).  If every
    attempt fails, AutoReconnectMaxRetries is raised.  Non-callable attributes
    are returned unchanged.

    NOTE(review): every callable is retried, including writes.  If a write
    reached the server before AutoReconnect was raised, retrying may apply it
    twice; confirm callers only route idempotent operations through here.
    """

    def __init__(self, obj):
        self._obj = obj

    def __getattr__(self, name):
        # Only invoked for names not found on the proxy itself.
        if name == '_obj':
            # ``_obj`` can only be missing on a half-built instance (e.g. during
            # copy/unpickle); without this guard ``self._obj`` below would
            # recurse forever.
            raise AttributeError(name)
        attr = getattr(self._obj, name)
        if not callable(attr):
            # Plain data attributes cannot be "called with retries"; hand them
            # back as-is instead of as a function that must be called.
            return attr

        def retry_wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(DB_RETRIES):
                try:
                    return attr(*args, **kwargs)
                except pymongo.errors.AutoReconnect as e:
                    last_error = e
                    msg = 'Got AutoReconnect error!'
                    msg += ' Error message:\n{}\n'.format(e)
                    if attempt + 1 >= DB_RETRIES:
                        # Final attempt failed: sleeping now would only delay
                        # the error.
                        msg += 'No retries left.'
                        log.critical(msg)
                        break
                    delay = attempt ** 3
                    msg += 'Sleeping {} second(s) before next retry.'.format(delay)
                    log.critical(msg)
                    time.sleep(delay)
            raise AutoReconnectMaxRetries(
                'Gave up after {} attempts; last error: {}'.format(DB_RETRIES, last_error))

        return retry_wrapper

    def __str__(self):
        return "GetAttrWrapper for obj: {}".format(str(self._obj))

    def __repr__(self):
        return self.__str__()


class AutoReconnectRetryWrapper(object):
    """Proxy around a database object that adds AutoReconnect retries.

    - Public attributes of the wrapped object (names listed by ``dir()`` that
      do not start with ``_``) are returned unwrapped.
    - Any other attribute, and every item fetched with ``obj[key]``, is
      returned inside a ``GetAttrWrapper``, so method calls on it are retried.
    - ``collection_names`` is deliberately blocked; use
      ``collection_names_reliable`` instead.
    """

    def __init__(self, obj):
        self._obj = obj

    def __getitem__(self, key):
        return GetAttrWrapper(self._obj[key])

    def __getattr__(self, name):
        # Only invoked for names not found on the proxy itself.
        if name == '_obj':
            # ``_obj`` can only be missing on a half-built instance (e.g. during
            # copy/unpickle); without this guard ``self._obj`` would recurse.
            raise AttributeError(name)
        # The cheap prefix test goes first so dir() is only built for
        # public-looking names.
        if not name.startswith('_') and name in dir(self._obj):
            return getattr(self._obj, name)
        return GetAttrWrapper(getattr(self._obj, name))

    def collection_names(self, *args, **kwargs):
        raise NotImplementedError('Direct usage of collection_names forbidden.'
                                  ' Use collection_names_reliable instead.')

    def collection_names_reliable(self, *args, **kwargs):
        """Call the wrapped object's ``collection_names`` with AutoReconnect retries.

        Makes up to ``DB_RETRIES`` attempts and sleeps ``attempt ** 3``
        seconds between them.

        Raises:
            AutoReconnectMaxRetries: if every attempt raised AutoReconnect.
        """
        last_error = None
        for attempt in range(DB_RETRIES):
            try:
                return self._obj.collection_names(*args, **kwargs)
            except pymongo.errors.AutoReconnect as e:
                last_error = e
                msg = 'Got AutoReconnect error!'
                msg += ' Error message:\n{}\n'.format(e)
                if attempt + 1 >= DB_RETRIES:
                    # Final attempt failed: no point sleeping before raising.
                    msg += 'No retries left.'
                    log.critical(msg)
                    break
                delay = attempt ** 3
                msg += 'Sleeping {} second(s) before next retry.'.format(delay)
                log.critical(msg)
                time.sleep(delay)
        raise AutoReconnectMaxRetries(
            'Gave up after {} attempts; last error: {}'.format(DB_RETRIES, last_error))

    def __str__(self):
        return "AutoReconnectRetryWrapper for obj: {}".format(str(self._obj))

    def __repr__(self):
        return self.__str__()


class MongoDatabaseConnection(object):
    """Lazily initialised handle to a single MongoDB database.

    The driver client is not created until the first ``get()`` call.  After
    that, the same (possibly wrapped) database handle is returned every time.

    Attributes:
        conn_name: Logical name of the connection.
        uri: MongoDB connection URI, possibly containing credentials.  When
            logging, use ``str(self)``, which masks them.
        driver: Client factory, called as ``driver(uri)``.
        wrapper: Optional callable applied to the database object; a falsy
            value disables wrapping.
        config: The config dict this connection was built from, or None.
    """

    def __init__(self, conn_name, db_name, uri, driver=pymongo.MongoClient, wrapper=AutoReconnectRetryWrapper, config=None):
        self.conn_name = conn_name
        self._db_name = db_name
        self.uri = uri
        self.driver = driver
        self.wrapper = wrapper
        # Created lazily by get(); None means "not connected yet".
        self._conn = None
        self.config = config

    def get(self):
        """Return the (possibly wrapped) database handle, connecting on first use."""
        # Compare against None explicitly: PyMongo 4 Database objects refuse
        # truth-value testing, so ``not self._conn`` breaks when no wrapper is
        # configured.  This also matches is_active().
        if self._conn is None:
            log.info('Initializing connection: {}'.format(str(self)))
            database = self.driver(self.uri)[self._db_name]
            self._conn = self.wrapper(database) if self.wrapper else database
        return self._conn

    def is_active(self):
        """Return True once get() has created the database handle."""
        return self._conn is not None

    @staticmethod
    def make_mongo_uri(conf):
        """Build a ``mongodb://`` URI from a connection config dict.

        Required keys: ``host`` and ``db``.

        Optional keys:
        - ``port`` (default 27017).
        - ``username`` and ``password``; credentials are emitted only when
          both are non-empty.
        - ``type``.  When it equals ``'replicaset'``, the query string also
          carries ``replicaset``, ``read_preference`` (default
          ``'primary'``), ``max_pool_size`` (default 4) and
          ``wait_queue_multiple`` (default 15).

        Raises:
            KeyError: if ``host`` or ``db`` is missing.
        """
        # Use .get(): configs without credentials must not raise KeyError,
        # since the credential-less template does not need these values.
        username = conf.get('username')
        password = conf.get('password')
        if username and password:
            # NOTE(review): credentials are inserted verbatim.  The MongoDB URI
            # format requires ':', '/', '@' and '%' in them to be
            # percent-escaped; confirm whether configs store them pre-escaped.
            uri_template = 'mongodb://{username}:{password}@{host}:{port}/{db}'
        else:
            uri_template = 'mongodb://{host}:{port}/{db}'
        if conf.get('type') == 'replicaset':
            uri_template += '?replicaSet={replicaset}' \
                            '&readPreference={read_preference}' \
                            '&maxPoolSize={max_pool_size}' \
                            '&waitQueueMultiple={wait_queue_multiple}'
        return uri_template.format(
            username=username,
            password=password,
            host=conf['host'],
            port=conf.get('port', 27017),
            db=conf['db'],
            replicaset=conf.get('replicaset'),
            read_preference=conf.get('read_preference', 'primary'),
            max_pool_size=conf.get('max_pool_size', 4),
            wait_queue_multiple=conf.get('wait_queue_multiple', 15)
        )

    @classmethod
    def from_config(cls, conn_name, conf, driver=pymongo.MongoClient, wrapper=AutoReconnectRetryWrapper):
        """Create a connection from a config dict (keys as in make_mongo_uri)."""
        uri = cls.make_mongo_uri(conf)
        return cls(conn_name, conf['db'], uri, driver, wrapper, conf)

    def __str__(self):
        """Describe the connection, masking any credentials in the URI."""
        if '@' in self.uri:
            proto_end = self.uri.find("://") + 3
            # Use the last '@' so an unescaped '@' inside the password is
            # still part of the masked span.
            host_begin = self.uri.rfind("@")
            clean_str = self.uri[:proto_end] + "****:****" + self.uri[host_begin:]
        else:
            clean_str = self.uri
        return "MongoDatabaseConnection({})".format(clean_str)

    def __repr__(self):
        return self.__str__()


def configure(metadb_conf, driver=pymongo.MongoClient, wrapper=AutoReconnectRetryWrapper):
    """Create connection objects for every database in ``DATABASES``.

    Connections are only described here.  No client is opened until a
    connection's ``get()`` is first called.  Replaced connections are dropped
    without explicitly closing their clients.

    Args:
        metadb_conf: Mapping from each name in ``DATABASES`` to that
            database's config dict, as accepted by
            ``MongoDatabaseConnection.from_config``.
        driver: Client factory passed to every connection.
        wrapper: Optional wrapper applied to each database handle; a falsy
            value disables wrapping.

    Raises:
        KeyError: if ``metadb_conf`` lacks a database section or a section
            lacks a required key.  In that case ``connections`` is left
            untouched, because every new connection is built before any
            existing entry is replaced.
    """
    def _name_of(obj):
        # Tolerate factories without __name__ (e.g. functools.partial).
        return getattr(obj, '__name__', repr(obj))

    new_connections = {}
    for db in DATABASES:
        if db in connections and connections[db].is_active():
            log.warning('Reconfiguring active connection: {}'.format(db))
        log.info('Configuring connection: database: {}, driver: {}, wrapper: {}'.format(
            db, _name_of(driver), None if not wrapper else _name_of(wrapper)))
        new_connections[db] = MongoDatabaseConnection.from_config(db, metadb_conf[db], driver, wrapper)
    # Swap in the complete set at once so a failure above never leaves a
    # mixture of old and new connections.
    connections.update(new_connections)


def cacus():
    """Return the database handle for ``CACUS``, connecting on first use.

    Raises KeyError if ``configure()`` has not populated ``connections`` yet.
    """
    connection = connections[CACUS]
    return connection.get()


def repos():
    """Return the database handle for ``REPOS``, connecting on first use.

    Raises KeyError if ``configure()`` has not populated ``connections`` yet.
    """
    connection = connections[REPOS]
    return connection.get()


def history():
    """Return the database handle for ``HISTORY``, connecting on first use.

    Raises KeyError if ``configure()`` has not populated ``connections`` yet.
    """
    connection = connections[HISTORY]
    return connection.get()


def history_config():
    """Return the config the ``HISTORY`` connection was created with.

    When set up via ``configure()`` this is ``metadb_conf[HISTORY]``.  Raises
    KeyError if ``configure()`` has not populated ``connections`` yet.
    """
    history_conn = connections[HISTORY]
    return history_conn.config
