# -*- coding: utf-8 -*-
import threading

from multiprocessing.pool import ThreadPool

import psycopg2
import psycopg2.extensions
import enum

from sqlalchemy import create_engine, event
from sqlalchemy.exc import DatabaseError, InternalError, ResourceClosedError, IntegrityError, OperationalError
from sqlalchemy.engine.base import Connection as _Connection
from contextlib import contextmanager

import mpfs.engine.process

from mpfs.common.util import use_context
from mpfs.config import settings
from mpfs.dao.shard_endpoint import ShardEndpoint, ShardType
from mpfs.engine.process import get_requests_postgres_log
from mpfs.metastorage.postgres.exceptions import (ReadOnlyDatabaseError, ConnectionClosedError,
                                                  DatabaseConstraintError, UniqueConstraintViolationError,
                                                  QueryCanceledError, EofDetectedError)
from mpfs.metastorage.postgres.services import Sharpei, MasterNotFoundError, SlavesNotFoundError, \
    SharpeiUserNotFoundError
from mpfs.metastorage.postgres.logging import before_cursor_execute, receive_after_execute, PostgresLoggingConnection


# Connection / pool tunables, read once at import time from ``settings.postgres``.

# How many times a failed ``engine.connect()`` is retried by PGQueryExecuter.
POSTGRES_RECONNECTION_ATTEMPTS = settings.postgres['reconnection_attempts']
# Passed to ``create_engine`` as ``pool_size`` and ``max_overflow`` respectively.
POSTGRES_POOL_SIZE = settings.postgres['pool_size']
POSTGRES_POOL_OVERFLOW_SIZE = settings.postgres['pool_overflow_size']
# With ``use_threads=True``, shard connections are opened in a thread pool only when at least
# this many of them still need a fresh connection.
POSTGRES_USE_THREADPOOL_FOR_CONNECTING_THRESHOLD = settings.postgres['use_threadpool_for_connecting_threshold']
# Driver connection parameters forwarded to psycopg2 through ``connect_args``
# (``connect_timeout`` and the TCP keepalive settings).
POSTGRES_CONNECTION_ARGS_CONNECTION_TIMEOUT_SEC = settings.postgres['connection_args']['connection_timeout_sec']
POSTGRES_CONNECTION_ARGS_KEEPALIVE = settings.postgres['connection_args']['keepalives']
POSTGRES_CONNECTION_ARGS_KEEPALIVE_IDLE = settings.postgres['connection_args']['keepalives_idle']
POSTGRES_CONNECTION_ARGS_KEEPALIVE_INTERVAL = settings.postgres['connection_args']['keepalives_interval']
POSTGRES_CONNECTION_ARGS_KEEPALIVE_COUNT = settings.postgres['connection_args']['keepalives_count']
# Replica selection: datacenters passed as ``filter_dc`` and the lag threshold passed as
# ``replication_lag_theshold`` to ``shard.get_random_slave`` (the "theshold" spelling mirrors
# the settings key and must be kept).
POSTGRES_SHARPEI_REPLICA_FORBIDDEN_DATACENTERS = settings.postgres['sharpei']['replica_forbidden_datacenters']
POSTGRES_SHARPEI_REPLICATION_LAG_THESHOLD = settings.postgres['sharpei']['replication_lag_theshold']


# Module-level default logger.
log = mpfs.engine.process.get_default_log()


class Connection(_Connection):
    """
    SQLAlchemy connection that translates low-level database errors into MPFS exceptions.

    :type _connection: sqlalchemy.engine.base.Connection
    """
    EOF_DETECTED_ERROR_MESSAGE = 'SSL SYSCALL error: EOF detected'

    # psycopg2 ``pgcode`` -> exception raised instead of the generic ``InternalError``.
    # 25006 is PostgreSQL's read_only_sql_transaction SQLSTATE.
    _pg_internal_code_map = {
        '25006': ReadOnlyDatabaseError,
    }

    def execute(self, *args, **kwargs):
        """
        Execute a statement, converting known driver errors into MPFS exceptions.

        :raises ReadOnlyDatabaseError: the server reported SQLSTATE 25006.
        :raises ConnectionClosedError: the connection is already closed.
        :raises UniqueConstraintViolationError: a unique constraint was violated.
        :raises DatabaseConstraintError: any other integrity violation.
        :raises QueryCanceledError: psycopg2 reported ``QueryCanceledError``.
        :raises EofDetectedError: the driver reported "SSL SYSCALL error: EOF detected".

        Any other error is re-raised unchanged.
        """
        try:
            return super(Connection, self).execute(*args, **kwargs)
        except InternalError as e:
            original_exception = e.orig
            if isinstance(original_exception, psycopg2.InternalError):
                exc_cls = self._pg_internal_code_map.get(original_exception.pgcode)
                if exc_cls is not None:
                    raise exc_cls()
            raise
        except ResourceClosedError:
            raise ConnectionClosedError()
        except IntegrityError as e:
            raise convert_integrity_error(e)
        except OperationalError as e:
            if isinstance(e.orig, psycopg2.extensions.QueryCanceledError):
                raise QueryCanceledError(e.message)
            if self._is_eof_error(e):
                raise EofDetectedError(e.message)
            raise
        except DatabaseError as e:
            # Must stay after its subclasses (InternalError, IntegrityError, OperationalError).
            if self._is_eof_error(e):
                raise EofDetectedError(e.message)
            raise

    @classmethod
    def _is_eof_error(cls, db_exc):
        """Return True if ``db_exc`` wraps a driver error with the SSL "EOF detected" message."""
        return cls._get_original_exception_message(db_exc) == cls.EOF_DETECTED_ERROR_MESSAGE

    @staticmethod
    def _get_original_exception_message(psycopg_exc):
        """
        Return the message of the DBAPI exception wrapped by ``psycopg_exc``.

        :param psycopg_exc: SQLAlchemy DBAPI error whose ``orig`` holds the driver exception.
        :return: the message stripped of surrounding whitespace when it is a string, the raw
            ``message`` value otherwise, or None when ``orig``/``message`` is unavailable.
        """
        # ``getattr`` with defaults (rather than plain attribute access) so that this helper,
        # which runs inside ``except`` handlers, can never replace the real database error
        # with an AttributeError.
        original_exception = getattr(psycopg_exc, 'orig', None)
        error_message = getattr(original_exception, 'message', None)
        if error_message and isinstance(error_message, basestring):
            error_message = error_message.strip()
        return error_message


def convert_integrity_error(exception):
    """
    Build the MPFS constraint exception matching a SQLAlchemy ``IntegrityError``.

    SQLSTATE 23505 (unique_violation,
    https://www.postgresql.org/docs/current/static/errcodes-appendix.html) maps to
    ``UniqueConstraintViolationError``; every other code maps to ``DatabaseConstraintError``.
    The result's ``message`` joins the original message, statement and parameters.

    :param exception: SQLAlchemy ``IntegrityError`` whose ``orig`` carries a ``pgcode``.
    :return: the converted exception instance (it is returned, not raised).
    """
    is_unique_violation = exception.orig.pgcode == '23505'
    error_cls = UniqueConstraintViolationError if is_unique_violation else DatabaseConstraintError
    result = error_cls()
    details = (exception.message, exception.statement, exception.params)
    result.message = '%s; %s; %s' % details
    return result


def _connection_factory(*args, **kwargs):
    """
    psycopg2 ``connection_factory`` producing a ``PostgresLoggingConnection``.

    All arguments are forwarded unchanged; the new connection is initialized with the
    per-request Postgres log before it is handed back.
    """
    conn = PostgresLoggingConnection(*args, **kwargs)
    postgres_log = get_requests_postgres_log()
    conn.initialize(postgres_log)
    return conn


@contextmanager
def manual_route(shard_name):
    """
    Route shard lookups to ``shard_name`` for the duration of the ``with`` block.

    Made exclusively for the migrator and not intended for use in regular code.
    If code needs to go to a specific shard, write a dedicated method in the DAO class.

    While active, ``PGQueryExecuter.get_shard_id`` returns ``shard_name`` instead of asking
    Sharpei. The route is stored on the ``PGQueryExecuter`` class, so it is shared by all
    instances. On exit the previously active route (if any) is restored, so nested
    ``manual_route`` blocks do not clobber the outer one.
    """
    previous_shard_name = PGQueryExecuter._manual_route_shard_name
    try:
        PGQueryExecuter.set_manual_route(shard_name)
        yield
    finally:
        if previous_shard_name is None:
            PGQueryExecuter.drop_manual_route()
        else:
            PGQueryExecuter.set_manual_route(previous_shard_name)


class ReadPreference(enum.Enum):
    """
    Which shard host to connect to; interpreted by ``PGQueryExecuter._get_host_by_role``.
    """
    primary = 'PRIMARY'  # master only
    primary_preferred = 'PRIMARY_PREFERRED'  # master, or a random replica if no master is found
    secondary = 'SECONDARY'  # a random replica only
    secondary_preferred = 'SECONDARY_PREFERRED'  # a random replica, or master if no replicas are found


class PGQueryExecuter(object):
    """
    Hands out PostgreSQL connections to MPFS shards.

    Shard locations are resolved through Sharpei. Engines (one per connection string) and the
    uid -> shard / shard_id -> shard mappings are cached on the class, so all instances share
    them; cache access is guarded by ``_cache_lock``.

    :type _sharpei: Sharpei
    :type _engines_cache: dict[str, sqlalchemy.engine.Engine]
    """

    # Engine cache: connection string -> sqlalchemy Engine.
    _engines_cache = {}
    # uid (as str) -> shard. Needed to remember the shard of a just-created user, because
    # Sharpei does not report a freshly created user right away.
    _users_cache = {}
    # shard_id -> shard, so that Sharpei is not queried too often.
    _sharpei_id_cache = {}
    _common_master_engine = None
    _common_slave_engine = None
    # When set, get_shard_id() returns this value instead of asking Sharpei (see manual_route).
    _manual_route_shard_name = None
    _cache_lock = threading.RLock()
    _sharpei = Sharpei()
    # Errors from ``engine.connect()`` that are worth retrying. Depending on the SQLAlchemy
    # version, a connect-time failure surfaces either as the raw psycopg2 error or wrapped in
    # ``sqlalchemy.exc.OperationalError``; catching only the former would silently disable
    # retries (and the skip-unavailable-shards fallback) for the latter.
    _CONNECT_ERRORS = (psycopg2.OperationalError, OperationalError)

    def _get_shard_by_shard_id(self, shard_id):
        """Return the shard with ``shard_id``, asking Sharpei only on a cache miss."""
        with self._cache_lock:
            shard = self._sharpei_id_cache.get(shard_id)
            if shard:
                return shard

        # The Sharpei request runs outside the lock, so other threads are not blocked by it.
        shard = self._sharpei.get_shard_by_id(shard_id)
        with self._cache_lock:
            self._sharpei_id_cache[shard_id] = shard
        return shard

    def _get_shard_by_uid(self, uid):
        """Return the shard of user ``uid``, asking Sharpei only on a cache miss."""
        uid = str(uid)  # uids arrive in different types; normalize to str for uniform cache keys
        with self._cache_lock:
            shard = self._users_cache.get(uid)
            if shard:
                return shard

        shard = self._sharpei.get_shard(uid)
        with self._cache_lock:
            self._users_cache[uid] = shard
            self._sharpei_id_cache[shard.get_id()] = shard
        return shard

    def _create_engine(self, connection_string):
        """
        Create an AUTOCOMMIT engine for ``connection_string``.

        The engine uses the error-translating ``Connection`` class, the logging connection
        factory and the configured pool/keepalive settings.
        """
        params = dict(
            connect_args={
                'connection_factory': _connection_factory,
                'connect_timeout': POSTGRES_CONNECTION_ARGS_CONNECTION_TIMEOUT_SEC,
                'keepalives': POSTGRES_CONNECTION_ARGS_KEEPALIVE,
                'keepalives_idle': POSTGRES_CONNECTION_ARGS_KEEPALIVE_IDLE,
                'keepalives_interval': POSTGRES_CONNECTION_ARGS_KEEPALIVE_INTERVAL,
                'keepalives_count': POSTGRES_CONNECTION_ARGS_KEEPALIVE_COUNT,
            },
            pool_size=POSTGRES_POOL_SIZE,
            max_overflow=POSTGRES_POOL_OVERFLOW_SIZE,
        )
        engine = create_engine(connection_string, isolation_level='AUTOCOMMIT', **params)
        engine._connection_cls = Connection
        self._setup_engine(engine)
        return engine

    def _get_engine(self, connection_string):
        """
        Return the cached engine for ``connection_string``, creating it on a cache miss.

        The engine is created outside the lock. If another thread cached an engine for the
        same string in the meantime, the freshly created one is disposed and the cached one
        is returned.
        """
        with self._cache_lock:
            engine = self._engines_cache.get(connection_string)
            if engine:
                return engine

        engine = self._create_engine(connection_string)
        with self._cache_lock:
            if connection_string in self._engines_cache:
                engine.dispose()
                return self._engines_cache[connection_string]
            self._engines_cache[connection_string] = engine
        return engine

    @classmethod
    @use_context(_cache_lock)
    def reset_cache(cls):
        """Dispose all cached engines and reset every class-level cache and the manual route."""
        for e in cls._engines_cache.values():
            e.dispose()

        cls._engines_cache.clear()
        cls.reset_shapei_cache()
        cls._common_master_engine = None
        cls._common_slave_engine = None
        cls._manual_route_shard_name = None

    @classmethod
    @use_context(_cache_lock)
    def reset_shapei_cache(cls):
        """Recreate the Sharpei client and clear the uid and shard_id caches."""
        cls._sharpei = Sharpei()
        cls._users_cache.clear()
        cls._sharpei_id_cache.clear()

    @classmethod
    def set_manual_route(cls, shard_name):
        """Make get_shard_id() return ``shard_name`` for every uid (class-wide)."""
        cls._manual_route_shard_name = shard_name

    @classmethod
    def drop_manual_route(cls):
        """Stop overriding get_shard_id(); shard lookups go through Sharpei again."""
        cls._manual_route_shard_name = None

    def get_connection_by_shard_id(self, shard_id, read_preference=ReadPreference.primary):
        """Open a connection to the host of shard ``shard_id`` chosen by ``read_preference``."""
        host = self._get_host_by_role(self._get_shard_by_shard_id(shard_id), read_preference)
        connection_string = host.get_connection_string()
        return self.get_connection_by_connection_string(connection_string)

    def get_connection_for_all_shards(
            self, skip_unavailable_shards=False, read_preference=ReadPreference.primary, use_threads=False):
        """
        Open one connection per shard known to Sharpei.

        :param skip_unavailable_shards: log and skip shards whose host cannot be resolved or
            connected to, instead of raising.
        :param read_preference: which host of each shard to connect to.
        :param use_threads: open the connections that need a fresh engine/connect in a thread
            pool when there are at least POSTGRES_USE_THREADPOOL_FOR_CONNECTING_THRESHOLD of them.
        :return: list of connections; their order is not guaranteed.
        """
        shards = self._sharpei.get_all_shards()
        connection_strings = []
        for shard in shards:
            try:
                host = self._get_host_by_role(shard, read_preference)
            except Exception:
                if skip_unavailable_shards:
                    log.error(
                        'Something went wrong while trying to get connection string for shard `%s`' % shard.get_id())
                    continue
                raise
            connection_strings.append(host.get_connection_string())

        # Split shards into those with an already cached engine and those needing one.
        connection_strings_to_connect = []
        connection_strings_to_engine_map = {}
        with self._cache_lock:
            for connection_string in connection_strings:
                engine = self._engines_cache.get(connection_string)
                if engine:
                    connection_strings_to_engine_map[connection_string] = engine
                else:
                    connection_strings_to_connect.append(connection_string)

        connections = []
        for connection_string, engine in connection_strings_to_engine_map.items():
            try:
                connection = self._connect_to_engine(engine, 1)  # single try because we are in main thread
            except self._CONNECT_ERRORS:
                # Fall back to the regular (retrying, optionally threaded) path below.
                connection_strings_to_connect.append(connection_string)
                continue
            connections.append(connection)

        if not connection_strings_to_connect:
            return connections

        if use_threads and len(connection_strings_to_connect) >= POSTGRES_USE_THREADPOOL_FOR_CONNECTING_THRESHOLD:
            connections.extend(
                self._get_connections_threaded(connection_strings_to_connect, skip_unavailable_shards))
        else:
            connections.extend(self._get_connections(connection_strings_to_connect, skip_unavailable_shards))

        return connections

    def _get_connections(self, connection_strings, skip_errors):
        """Connect to each string sequentially; log and skip failures when ``skip_errors``."""
        connections = []
        for connection_string in connection_strings:
            try:
                connections.append(self.get_connection_by_connection_string(connection_string))
            except Exception:
                if skip_errors:
                    log.error(
                        'Something went wrong while trying to connect %s' % connection_string
                    )
                    continue
                raise
        return connections

    def _get_connections_threaded(self, connection_strings, skip_errors):
        """
        Connect to all strings concurrently, one pool thread per string.

        Failures are logged and skipped when ``skip_errors``; otherwise the first failure seen
        is re-raised and the pool is terminated.
        """
        def get_connection(connection_string):
            # Errors are returned rather than raised so the consumer decides what to do.
            try:
                return connection_string, self.get_connection_by_connection_string(connection_string), None
            except Exception as e:
                return connection_string, None, e

        pool = ThreadPool(processes=len(connection_strings))
        connections = []
        try:
            it = pool.imap_unordered(get_connection, connection_strings)
            for connection_string, connection, err in it:
                if connection is None:
                    if skip_errors:
                        log.error(
                            'Something went wrong while trying to connect %s' % connection_string
                        )
                        continue
                    raise err
                connections.append(connection)
        finally:
            pool.terminate()
        return connections

    def get_connection_to_common(self, read_preference=ReadPreference.primary):
        """
        Return a connection to the common shard.

        The master connection is returned for ``primary`` and ``primary_preferred``, the
        replica connection for every other preference.

        NOTE(review): despite the ``*_engine`` names these attributes hold connections, and
        ``_setup_common_shard_connections`` assigns them on the instance, so ``reset_cache``
        (which resets the class attributes) does not clear them — confirm this is intended.
        """
        if self._common_master_engine is None or self._common_slave_engine is None:
            self._setup_common_shard_connections()
        if read_preference in (ReadPreference.primary, ReadPreference.primary_preferred):
            return self._common_master_engine
        else:
            return self._common_slave_engine

    def get_connection_by_connection_string(self, connection_string):
        """Connect to ``connection_string``, retrying up to POSTGRES_RECONNECTION_ATTEMPTS times."""
        engine = self._get_engine(connection_string)
        # One initial attempt plus the configured number of reconnection attempts.
        return self._connect_to_engine(engine, POSTGRES_RECONNECTION_ATTEMPTS + 1)

    def get_shard_id(self, uid):
        """Return the shard id of ``uid``, or the manual route when one is set."""
        if self._manual_route_shard_name:
            return self._manual_route_shard_name
        return self._get_shard_by_uid(uid).get_id()

    def is_user_in_postgres(self, uid):
        """Return whether Sharpei knows a shard for ``uid``."""
        try:
            return bool(self._get_shard_by_uid(uid))
        except SharpeiUserNotFoundError:
            return False

    @staticmethod
    def _get_host_by_role(shard, read_preference):
        """
        Pick the host of ``shard`` that matches ``read_preference``.

        Replicas are picked at random, excluding POSTGRES_SHARPEI_REPLICA_FORBIDDEN_DATACENTERS
        and passing POSTGRES_SHARPEI_REPLICATION_LAG_THESHOLD. The ``*_preferred`` variants
        fall back to the other role on MasterNotFoundError / SlavesNotFoundError; otherwise
        lookup errors propagate. Returns None for an unrecognized read preference.
        """
        def random_slave():
            return shard.get_random_slave(
                filter_dc=POSTGRES_SHARPEI_REPLICA_FORBIDDEN_DATACENTERS,
                replication_lag_theshold=POSTGRES_SHARPEI_REPLICATION_LAG_THESHOLD
            )

        if read_preference == ReadPreference.primary:
            return shard.get_master()
        elif read_preference == ReadPreference.primary_preferred:
            try:
                return shard.get_master()
            except MasterNotFoundError:
                return random_slave()
        elif read_preference == ReadPreference.secondary:
            return random_slave()
        elif read_preference == ReadPreference.secondary_preferred:
            try:
                return random_slave()
            except SlavesNotFoundError:
                return shard.get_master()
        return None

    @classmethod
    def _connect_to_engine(cls, engine, try_count):
        """
        Open a connection from ``engine``, retrying on connection errors.

        :param engine: sqlalchemy Engine to connect with.
        :param try_count: total number of connection attempts; at least one is always made.
        :return: an open connection.
        :raises: the last connection error once all attempts are exhausted.
        """
        attempts_left = max(try_count, 1)
        while True:
            try:
                return engine.connect()
            except cls._CONNECT_ERRORS:
                attempts_left -= 1
                if attempts_left <= 0:
                    raise

    def create_user(self, uid, shard_id=None):
        """Register ``uid`` in Sharpei (optionally on ``shard_id``) and cache its shard."""
        uid = int(uid)  # only numeric uids can be created
        shard = self._sharpei.create_user(uid, shard_id=shard_id)
        with self._cache_lock:
            self._users_cache[str(uid)] = shard  # uids are cached as strings
            self._sharpei_id_cache[shard.get_id()] = shard

    def get_all_shard_ids(self):
        """Return the ids of all shards known to Sharpei."""
        # prefer get_all_shard_endpoints
        return self._sharpei.get_all_shard_ids()

    def get_all_shard_endpoints(self):
        """Return a Postgres ``ShardEndpoint`` for every shard id."""
        return [ShardEndpoint(ShardType.POSTGRES, i) for i in self.get_all_shard_ids()]

    @staticmethod
    def _setup_engine(engine):
        """Attach the query-logging listeners to ``engine``."""
        event.listen(engine, 'before_execute', before_cursor_execute)
        event.listen(engine, 'after_execute', receive_after_execute)

    def _get_common_shard_connection_strings(self):
        """
        Return ``(master, replica)`` connection strings of the common shard.

        NOTE(review): unlike _get_host_by_role, the replica is picked without the
        forbidden-datacenter / replication-lag filters — confirm this is intended.
        """
        common_shard = self._sharpei.get_common_shard()
        master_connstring = common_shard.get_master().get_connection_string()
        slave_connstring = common_shard.get_random_slave().get_connection_string()
        return master_connstring, slave_connstring

    def _setup_common_shard_connections(self):
        """Open and store the master and replica connections to the common shard."""
        master_connstring, slave_connstring = self._get_common_shard_connection_strings()
        self._common_master_engine = self.get_connection_by_connection_string(master_connstring)
        self._common_slave_engine = self.get_connection_by_connection_string(slave_connstring)
