# -*- coding: utf-8 -*-

from collections import defaultdict
from threading import current_thread
from types import NoneType
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from psycopg2 import ProgrammingError, OperationalError

from mpfs.engine.process import get_default_log
from mpfs.config import settings
from mpfs.metastorage.postgres.queries import SqlTemplatedQuery
from mpfs.metastorage.postgres.query_executer import PGQueryExecuter, ReadPreference as PGReadPreference, \
    convert_integrity_error
from mpfs.metastorage.postgres.exceptions import EofDetectedError, SetAutocommitError, ConnectionBasedSessionClosed


POSTGRES_SPECIAL_UID_FOR_COMMON_SHARD = settings.postgres['common_uid']
POSTGRES_UID_STATISTICS_ENABLED = settings.postgres['uid_statistics']['enabled']
POSTGRES_UID_STATISTICS_ANONYMOUS_UID = settings.postgres['uid_statistics']['anonymous_uid']

default_log = get_default_log()


class Session(object):
    """Объект, позволяющий выполнять запросы в постгресе.
    Содержит в себе управление логикой получения коннекта к базе по идентификатору постгресового шарда или по uid'у.

    Концептуально, этот объект нужен для того, чтобы можно было его передавать при выполнении операций в DAO.
    При этом можно сделать что - можно создать объект сессии, начать транзакцию, передать его в разные методы типа
    `create` на одном шарде и потом закоммитить транзакцию.

    Для создания объекта предпочтительнее пользоваться методами класса `create_from_uid` и `create_from_shard_id`.

    Примеры использования:
    >>> session = Session.create_from_uid('1234567890', read_preference=PGReadPreference.primary_preferred)
    >>> with session.begin():
    >>>     FileDAO.create(file_item)
    >>>     FolderDAO.create(folder_item)
    """

    _shard_id_cache = defaultdict(dict)  # кеш сессий по shard_id
    _pg_query_executer = PGQueryExecuter()

    def __init__(self, shard_id=None, connection=None, read_preference=PGReadPreference.primary):
        if shard_id is None and connection is None:
            raise RuntimeError('shard id or connection must be specified to create session')

        self._shard_id = shard_id
        self._pg_connection = connection
        self._pg_transactions = []
        self._read_preference = read_preference
        self.ucache_hint_uid = None

    def set_ucache_hint_uid(self, ucache_hint_uid):
        if not isinstance(ucache_hint_uid, (NoneType, int, long)):
            raise TypeError('ucache_hint_uid must be integer, get %s(%s)' % (type(ucache_hint_uid), ucache_hint_uid))
        self.ucache_hint_uid = ucache_hint_uid

    @classmethod
    def create_from_uid(cls, uid, read_preference=PGReadPreference.primary):
        """Создает объект сессии по uid'у (то есть объект будет ассоциирован с шардом, на котором живет пользователь).
        Шард, на котором живет пользователь, получим из шарпея.
        """
        uid = int(uid)  # шарпей работает только с числовыми uid'ами
        shard_id = cls._pg_query_executer.get_shard_id(uid)
        session = cls.create_from_shard_id(shard_id, ucache_hint_uid=uid, read_preference=read_preference)
        return session

    @classmethod
    def create_from_shard_id(cls, shard_id, ucache_hint_uid=None, read_preference=PGReadPreference.primary):
        """Создает объект сессии по идентификатору шарда в шарпее.
        """
        key = cls._get_cache_key(shard_id)
        if key not in cls._shard_id_cache or read_preference not in cls._shard_id_cache[key]:
            cls._shard_id_cache[key][read_preference] = cls(shard_id=shard_id, read_preference=read_preference)
        session = cls._shard_id_cache[key][read_preference]
        session.set_ucache_hint_uid(ucache_hint_uid)
        return session

    @classmethod
    def create_from_shard_endpoint(cls, shard_endpoint, ucache_hint_uid=None, read_preference=PGReadPreference.primary):
        """Создает объект сессии по ShardEndpoint.
        """
        if not shard_endpoint.is_pg():
            raise ValueError('Except PG shard endpoint. Got %s' % shard_endpoint)
        return cls.create_from_shard_id(
            shard_endpoint.get_name(), ucache_hint_uid=ucache_hint_uid, read_preference=read_preference)

    @classmethod
    def create_for_all_shards(
            cls, skip_unavailable_shards=False, read_preference=PGReadPreference.primary, use_threads=False):
        connections = cls._pg_query_executer.get_connection_for_all_shards(
            skip_unavailable_shards,
            read_preference,
            use_threads=use_threads,
        )
        sessions = []
        for conn in connections:
            sessions.append(Session(connection=conn, read_preference=read_preference))
        return sessions

    @classmethod
    def create_common_shard(cls, read_preference=PGReadPreference.primary):
        """Создает объект сессии в общий шард.
        """
        return cls.create_from_uid(POSTGRES_SPECIAL_UID_FOR_COMMON_SHARD, read_preference=read_preference)

    @classmethod
    def clear_cache(cls):
        """Метод не потокобезопасный и предназначен для вызова из главного треда после обработки запроса и завершения
         вспомогательных тредов. В многопоточной среде используете его на свой страх и риск, скорее всего работать будет
         с рейсами и прочими радостями жизни.
        """
        for _, cache in cls._shard_id_cache.iteritems():
            for _, session in cache.iteritems():
                session.close(check_active_transactions=False)
        cls._shard_id_cache = defaultdict(dict)

    def detach_from_cache(self):
        key = self._get_cache_key(self._shard_id)
        read_preference = self._read_preference
        self._shard_id_cache[key].pop(read_preference)

    def close(self, check_active_transactions=True):
        if not check_active_transactions or not self._pg_transactions:
            self._close_connection()

    def _uid_comment(self, kwargs):
        if not POSTGRES_UID_STATISTICS_ENABLED:
            return ''

        if self.ucache_hint_uid is not None:
            uid = self.ucache_hint_uid
        else:
            uid = kwargs.get('uid', POSTGRES_UID_STATISTICS_ANONYMOUS_UID)
        return ' /* uid:%s */ ' % uid

    def execute(self, query, *args, **kwargs):
        if self._pg_transactions:
            if not self._pg_transactions[0].is_active:
                raise RuntimeError('Transaction is started but was rollbacked')

        if isinstance(query, basestring):
            query = text(self._uid_comment(kwargs) + query)
        elif isinstance(query, SqlTemplatedQuery):
            query = text(self._uid_comment(kwargs) + str(query))

        try:
            return self._conn.execute(query, *args, **kwargs)
        except EofDetectedError:
            if not self._pg_transactions:
                # костыль только для ретрая без транзакций
                self._close_connection()
                if not self._is_created_with_raw_connection():
                    return self._conn.execute(query, *args, **kwargs)
            raise

    def execute_queries(self, queries_with_params):
        assert isinstance(queries_with_params, (list, tuple))
        return [self.execute(q.query, **q.params) for q in queries_with_params]

    def execute_and_close(self, query, *args, **kwargs):
        try:
            self.execute(query, *args, **kwargs)
        finally:
            self.close()

    def abort_query_execution(self):
        if self._pg_connection is not None:
            db_api_connection = self._pg_connection.connection.connection
            db_api_connection.cancel()

    def begin(self):
        if not self._pg_transactions:
            try:
                self._set_dbapi_connection_autocommit(False)
            except ProgrammingError as exc:
                if 'autocommit' in exc.message:
                    default_log.info('Failed to set autocommit for new pg transaction (found not finished transaction):'
                                     ' isexecuting=%s; status=%s' % (
                        self._pg_connection.connection.connection.isexecuting(),
                        self._pg_connection.connection.connection.get_transaction_status()
                    ))
                    # Если транзакция еще открыта по какой-то причине, завершаем ее
                    self._pg_connection.connection.connection.rollback()
                    # И снова выставляем параметр autocommit, но у же в дефолтное значение т.к. будем райзить ошибку
                    # и транзакцию начинать не будем
                    self._set_dbapi_connection_autocommit(True)
                    raise SetAutocommitError()
                else:
                    raise

        transaction = self._conn.begin()
        self._pg_transactions.append(transaction)

        return self

    def rollback(self):
        if self._pg_transactions:
            transaction = self._pg_transactions.pop()
            transaction.rollback()
            self._set_dbapi_connection_autocommit(True)

    def commit(self):
        if self._pg_transactions:
            transaction = self._pg_transactions.pop()
            transaction.commit()

        if not self._pg_transactions:
            self._set_dbapi_connection_autocommit(True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if not self._pg_transactions:
            return

        if exc_type is None:
            try:
                self.commit()
            except IntegrityError as e:
                self.rollback()
                raise convert_integrity_error(e)
            except:
                self.rollback()
                raise
        else:
            self.rollback()
            raise

    def _is_created_with_raw_connection(self):
        return self._shard_id is None

    @property
    def _conn(self):
        if self._pg_connection is None:
            if self._is_created_with_raw_connection():
                raise ConnectionBasedSessionClosed('Session was created with connect and it\'s already closed')
            self._pg_connection = self._pg_query_executer.get_connection_by_shard_id(
                self._shard_id, self._read_preference)
        return self._pg_connection

    def _close_connection(self):
        while self._pg_transactions:
            self.rollback()

        if self._pg_connection is not None:
            self._pg_connection.close()
            self._pg_connection = None

    def _set_dbapi_connection_autocommit(self, enable):
        db_api_connection = self._conn.connection.connection
        db_api_connection.autocommit = enable

    @staticmethod
    def _get_cache_key(param):
        thread_id = current_thread().ident
        return '%s#%s' % (thread_id, param)
