# -*- coding: utf-8 -*-
from concurrent.futures import TimeoutError
from contextlib import contextmanager
import copy
import json
import logging
import time

from passport.backend.core.conf import settings
from passport.backend.core.lazy_loader import (
    lazy_loadable,
    LazyLoader,
)
from passport.backend.core.logging_utils.loggers import GraphiteLogger
from passport.backend.core.tvm import get_tvm_credentials_manager
from passport.backend.core.utils.decorators import cached_property
from passport.backend.core.ydb.exceptions import (
    BaseYdbError,
    YdbGenericError,
    YdbInstanceNotAvailable,
    YdbInvalidResponseError,
    YdbMissingKeyColumnsError,
    YdbMultipleResultFound,
    YdbNoResultFound,
    YdbPreconditionError,
    YdbSessionInvalidError,
    YdbTemporaryError,
    YdbUnknownKeyColumnsError,
)
import passport.backend.core.ydb_client as ydb
from passport.backend.core.ydb_client import (
    BaseRequestSettings,
    RetrySettings,
)
from passport.backend.utils.common import merge_dicts
import six
from six import iteritems
from six.moves import map


# Upper bound of the YDB session pool. Workers are synchronous: each worker
# serves one user request and issues one sequential YDB request at a time, so
# in theory 1 would be enough. YDB support advised 4 to be on the safe side.
YDB_MAX_SESSION_POOL_SIZE = 4

# The YDB client itself tries, in the background, to keep at least this many
# sessions alive.
YDB_MIN_SESSION_POOL_SIZE = 1

log = logging.getLogger('passport.ydb')

# Transaction modes re-exported from the YDB client under passport names.
YdbOnlineReadOnlyTxMode = ydb.OnlineReadOnly
YdbSerializableReadWrite = ydb.SerializableReadWrite


def ydb_str_exception(e):
    """Render an exception as a string suitable for passport errors and logs.

    For YDB client errors (``ydb.Error``) the result contains the exception
    class name and its ``status``, ``issues`` and ``message`` attributes
    (an empty string for any attribute that is missing). Any other exception
    is rendered with ``str()``.
    """
    if not isinstance(e, ydb.Error):
        return str(e)
    details = [getattr(e, attr_name, '') for attr_name in ('status', 'issues', 'message')]
    return 'error: {}; status: {}; issues: {}; message: {}'.format(e.__class__.__name__, *details)


class TvmCredentialsProvider(ydb.credentials.Credentials):
    """YDB credentials that authenticate with a TVM service ticket.

    The ticket for ``tvm_dst_alias`` is fetched from the TVM credentials
    manager every time authentication metadata is requested.
    """

    def __init__(self, tvm_dst_alias):
        self._tvm_dst_alias = tvm_dst_alias
        self._tvm_credentials_manager = get_tvm_credentials_manager()

    def _get_service_ticket(self):
        """Return the current service ticket for the destination alias as text."""
        manager = self._tvm_credentials_manager
        ticket = manager.get_ticket_by_alias(self._tvm_dst_alias)
        return six.text_type(ticket)

    def expired(self):
        # Always report expiry so that tvm_credentials_manager is consulted
        # on every request.
        return True

    def auth_metadata(self):
        """Return the ticket header as a single ``(name, value)`` metadata pair."""
        header_name = ydb.credentials.YDB_AUTH_TICKET_HEADER
        return [(header_name, self._get_service_ticket())]


class Ydb(object):
    """Wrapper around a YDB driver and session pool.

    Translates YDB client exceptions into passport ``BaseYdbError``
    subclasses, writes one graphite record per ``ydb_wrapper`` block and keeps
    a per-instance cache of prepared statements.

    :raises YdbInstanceNotAvailable: if ``enabled`` is false.
    """

    def __init__(
        self,
        endpoint,
        database,
        enabled,
        auth_token=None,
        use_tvm=False,
        graphite_logger=None,
        max_session_pool_size=None,
        min_session_pool_size=None,
        connection_timeout=None,
        get_session_timeout=None,
    ):
        if not enabled:
            raise YdbInstanceNotAvailable(
                '{} is disabled in current environment'.format(self.__class__.__name__),
            )
        self._endpoint = endpoint
        self._database = database
        self.graphite_logger = graphite_logger or GraphiteLogger(service='ydb')
        self.get_session_timeout = get_session_timeout or settings.YDB_GET_SESSION_TIMEOUT

        # statement key -> prepared statement. Cleared through the retry
        # settings' on_ydb_error_callback (see session_call).
        self._prepared_statements = {}

        auth_kwargs = {}
        if use_tvm:
            # TVM credentials take precedence over a static token.
            auth_kwargs.update(credentials=TvmCredentialsProvider(tvm_dst_alias='ydb'))
        elif auth_token:
            auth_kwargs.update(auth_token=auth_token)

        self.driver_config = ydb.DriverConfig(
            endpoint=self._endpoint,
            database=self._database,
            **auth_kwargs
        )

        self.driver = ydb.Driver(driver_config=self.driver_config)
        connection_timeout = connection_timeout or settings.YDB_CONNECTION_TIMEOUT
        try:
            self.driver.wait(connection_timeout)
        except TimeoutError:
            # Deliberately non-fatal: the object is built even if the driver is
            # not ready yet (presumably the client keeps connecting in the
            # background - TODO confirm). Log it instead of hiding it.
            log.warning(
                'YDB driver for %s (database %s) is not ready within timeout %s',
                self._endpoint,
                self._database,
                connection_timeout,
            )

        self._session_pool = ydb.SessionPool(
            driver=self.driver,
            size=max_session_pool_size or YDB_MAX_SESSION_POOL_SIZE,
            min_pool_size=min_session_pool_size or YDB_MIN_SESSION_POOL_SIZE,
        )

    def _reset_prepared_statements_cache(self, _):
        """Drop all cached prepared statements; the argument (the YDB error) is ignored."""
        self._prepared_statements = {}

    def _make_ydb_error_callback(self, extra_callback=None):
        """Build an ``on_ydb_error_callback`` that resets the prepared statements
        cache and then calls ``extra_callback`` (if any) with the same error."""
        def _on_ydb_error(e):
            self._reset_prepared_statements_cache(e)
            if extra_callback is not None:
                extra_callback(e)
        return _on_ydb_error

    @cached_property
    def session_pool(self):
        return self._session_pool

    @cached_property
    def endpoint(self):
        """Endpoint with leading/trailing slashes stripped."""
        return self._endpoint.strip('/')

    @cached_property
    def database(self):
        """Database path with leading/trailing slashes stripped."""
        return self._database.strip('/')

    @contextmanager
    def ydb_wrapper(self):
        """Translate YDB errors into passport errors and log the block to graphite.

        Exactly one graphite record (duration, response code, status) is
        written for the whole block, so retries performed inside it are not
        logged separately.
        """
        start_time = time.time()
        status = 'ok'
        response_code = 'success'
        try:
            yield
        except (
            TimeoutError,
            ydb.ConnectionError,
            ydb.DeadlineExceed,
            ydb.Overloaded,
            ydb.SessionPoolEmpty,
            ydb.Timeout,
        ) as e:
            status = e.__class__.__name__
            response_code = 'timeout'
            raise YdbTemporaryError(ydb_str_exception(e))
        except (ydb.SessionExpired, ydb.BadSession) as e:
            status = e.__class__.__name__
            response_code = 'timeout'
            raise YdbSessionInvalidError(ydb_str_exception(e))
        except ydb.PreconditionFailed as e:
            status = e.__class__.__name__
            response_code = 'failed'
            raise YdbPreconditionError(ydb_str_exception(e))
        except BaseYdbError as e:
            # Already a passport error: log it and re-raise unchanged.
            status = e.__class__.__name__
            response_code = 'failed'
            raise
        except Exception as e:
            status = e.__class__.__name__
            response_code = 'failed'
            raise YdbGenericError(ydb_str_exception(e), e)
        finally:
            self.graphite_logger.log(
                duration=time.time() - start_time,
                response=response_code,
                status=status,
            )

    def get_prepared_statement(self, statement_key):
        """Return the cached prepared statement for ``statement_key`` or ``None``."""
        return self._prepared_statements.get(statement_key)

    def prepare_statement(self, session, statement_key, statement):
        """Prepare ``statement`` on ``session``, cache it under ``statement_key`` and return it."""
        prepared_statement = session.prepare(statement)
        self._prepared_statements[statement_key] = prepared_statement
        return prepared_statement

    def session_call(self, f, retry_settings=None):
        """Run ``f(session)`` through the session pool with YDB retries.

        :param f: callable that takes a session as its first argument.
        :param retry_settings: optional YDB ``RetrySettings``. It is copied,
            not mutated; its own ``on_ydb_error_callback`` is still invoked
            after the prepared statements cache is reset.
        :return: the result of ``retry_operation_sync`` for ``f``.
        :raises BaseYdbError: any failure, translated by ``ydb_wrapper``.
        """
        if retry_settings:
            # Install the cache reset on a copy: the caller's object may be
            # shared between calls, and assigning on it repeatedly would either
            # clobber its callback or stack wrappers on every call.
            retry_settings = copy.copy(retry_settings)
            retry_settings.on_ydb_error_callback = self._make_ydb_error_callback(
                getattr(retry_settings, 'on_ydb_error_callback', None),
            )
        else:
            retry_settings = RetrySettings(
                on_ydb_error_callback=self._reset_prepared_statements_cache,
                max_retries=settings.YDB_RETRIES,
                get_session_client_timeout=self.get_session_timeout,
            )

        with self.ydb_wrapper():
            return self.session_pool.retry_operation_sync(
                f,
                retry_settings=retry_settings,
            )


class YdbKeyValue(object):
    """Key-value style access to a single YDB table.

    Rows are addressed by ``key_columns`` (column name -> YQL type) and one
    ``value_column`` of YQL type ``value_type`` is read and written. Prepared
    statements are cached in the underlying ``Ydb`` instance under keys built
    by the ``_build_*_statement_key`` helpers.

    :raises YdbInstanceNotAvailable: if ``enabled`` is false.
    """

    def __init__(
        self, endpoint, database, table_name, key_columns, value_column, enabled,
        value_type='Json',
        auth_token=None, use_tvm=False, retries=None, timeout=None,
        deadline=None,
        connection_timeout=None,
        get_session_timeout=None,
        graphite_logger=None,
        limit=None,
    ):
        if not enabled:
            raise YdbInstanceNotAvailable(
                '{} is disabled in current environment'.format(self.__class__.__name__),
            )
        self._ydb = Ydb(
            endpoint=endpoint,
            database=database,
            auth_token=auth_token,
            use_tvm=use_tvm,
            enabled=True,
            graphite_logger=graphite_logger,
            connection_timeout=connection_timeout,
            get_session_timeout=get_session_timeout,
        )

        self._table_name = table_name
        self._key_columns = key_columns
        self._value_column = value_column
        self._value_type = value_type
        # Optional LIMIT appended to read statements.
        self._limit = limit

        # timeout: time limit on the client side.
        timeout = timeout or settings.YDB_TIMEOUT
        # deadline: time limit on the server side; if the server does not
        # finish in time it simply gives up on the request.
        deadline = deadline or settings.YDB_DEADLINE

        self._query_settings = ydb.BaseRequestSettings()
        self._query_settings.with_timeout(timeout=timeout)
        self._query_settings.with_operation_timeout(timeout=deadline)

        self._retry_settings = ydb.RetrySettings(
            max_retries=retries or settings.YDB_RETRIES,
            get_session_client_timeout=get_session_timeout or settings.YDB_GET_SESSION_TIMEOUT,
        )

    def session_call(self, f):
        """Run ``f(session)`` via the underlying ``Ydb`` with this table's retry settings."""
        return self._ydb.session_call(f, retry_settings=self._retry_settings)

    @property
    def endpoint(self):
        return self._ydb.endpoint

    def _compose_read_statement(self, keys, with_declare_statement=True):
        """Build a SELECT of ``value_column`` filtered by equality on every key in ``keys``.

        :raises YdbUnknownKeyColumnsError: for a key that is not a key column.
        """
        declare_statement = []
        limit_statement = ''
        where_condition = []
        for key in sorted(keys):
            if key not in self._key_columns:
                raise YdbUnknownKeyColumnsError(key)
            key_type = self._key_columns[key]
            declare_statement.append('declare ${key} as {key_type};'.format(key=key, key_type=key_type))
            where_condition.append('{key} = ${key}'.format(key=key))

        declare_statement = '\n'.join(declare_statement)
        if self._limit is not None:
            limit_statement = 'LIMIT {limit}'.format(
                limit=self._limit,
            )
        where_condition = ' and '.join(where_condition)
        return (
            ('{declare_statement}\n' if with_declare_statement else '') +
            'select {value_column} from [{table_name}] where {where_condition}' +
            (' {limit_statement}' if limit_statement else '')
            + ';\n'
        ).format(
            table_name=self._table_name,
            declare_statement=declare_statement,
            limit_statement=limit_statement,
            where_condition=where_condition,
            value_column=self._value_column,
        )

    @property
    def name(self):
        """``endpoint/database/table`` identifier of this table."""
        return '%s/%s/%s' % (self._ydb.endpoint, self._ydb.database, self._table_name.strip('/'))

    def _compose_write_statement(self):
        """Build an UPSERT of all key columns plus ``value_column``."""
        declare_statement = []
        for key_column, key_type in sorted(iteritems(self._key_columns)):
            declare_statement.append('declare ${key_column} as {key_type};'.format(
                key_column=key_column,
                key_type=key_type,
            ))
        declare_statement = '\n'.join(declare_statement)

        key_columns_names = ', '.join(sorted(self._key_columns.keys()))
        expensive_key_columns_names = ', '.join(map(lambda x: '$' + x, sorted(self._key_columns.keys())))

        return (
            '{declare_statement}\n'
            'declare ${value_column} as {value_type};\n'
            'upsert into [{table_name}] ({key_columns_names}, {value_column})\n'
            'values ({expensive_key_columns_names}, ${value_column});\n'
        ).format(
            declare_statement=declare_statement,
            table_name=self._table_name,
            key_columns_names=key_columns_names,
            expensive_key_columns_names=expensive_key_columns_names,
            value_column=self._value_column,
            value_type=self._value_type,
        )

    def _compose_delete_statement(self, keys):
        """Build a DELETE filtered by equality on every key in ``keys``.

        :raises YdbUnknownKeyColumnsError: for a key that is not a key column.
        """
        declare_statement = []
        where_condition = []
        for key in sorted(keys):
            if key not in self._key_columns:
                raise YdbUnknownKeyColumnsError(key)
            key_type = self._key_columns[key]
            declare_statement.append('declare ${key} as {key_type};'.format(key=key, key_type=key_type))
            where_condition.append('{key} = ${key}'.format(key=key))

        declare_statement = '\n'.join(declare_statement)
        where_condition = ' and '.join(where_condition)

        return (
            '{declare_statement}\n' +
            'delete from [{table_name}] where {where_condition};\n'
        ).format(
            table_name=self._table_name,
            declare_statement=declare_statement,
            where_condition=where_condition,
        )

    def _postprocess_value(self, value):
        """Convert a raw column value according to ``value_type``.

        ``Json`` is decoded with ``json.loads``, integer types with ``int``,
        ``Float``/``Double`` with ``float``; other types are returned as is.

        :raises YdbInvalidResponseError: if the conversion fails.
        """
        try:
            if value is None:
                return None
            if self._value_type == 'Json':
                return json.loads(value)
            elif self._value_type in ('Uint64', 'Int64', 'Uint32', 'Int32', 'Uint16', 'Int16', 'Uint8', 'Int8'):
                return int(value)
            elif self._value_type in ('Float', 'Double'):
                return float(value)
            else:
                return value
        except (ValueError, TypeError):
            raise YdbInvalidResponseError(str(value))

    def get(self, keys):
        """Read rows matching ``keys`` (column -> value).

        :return: an iterator over post-processed values of the first result
            set. Values are converted lazily while iterating.
        """
        def _get(session):
            statement_key = self._build_read_statement_key(keys.keys())
            prepared_read_statement = self._ydb.get_prepared_statement(statement_key)
            if not prepared_read_statement:
                prepared_read_statement = self._ydb.prepare_statement(
                    session,
                    statement_key,
                    self._compose_read_statement(keys),
                )

            result = self._kikimr_get(session, keys, prepared_read_statement)

            if not result or not result[0].rows:
                return iter(())
            return map(lambda row: self._postprocess_value(getattr(row, self._value_column)), result[0].rows)

        return self.session_call(_get)

    @staticmethod
    def _build_read_statement_key(keys):
        return ','.join(['read'] + sorted(keys))

    def _kikimr_get(self, session, keys, prepared_read_statement):
        return (
            session.transaction(ydb.StaleReadOnly())
            .execute(
                prepared_read_statement,
                parameters={'$' + key_column: keys[key_column] for key_column in keys},
                commit_tx=True,
                settings=self._query_settings,
            )
        )

    def set(self, keys, value):
        """Upsert ``value`` for the row identified by ``keys``.

        ``keys`` must contain exactly the configured key columns.

        :raises YdbMissingKeyColumnsError: if some key columns are missing.
        :raises YdbUnknownKeyColumnsError: if ``keys`` has a non-key column.
        """
        def _set(session):
            # Compare as sets: on Python 2 ``dict.keys()`` returns lists whose
            # order depends on each dict's history, so a direct comparison can
            # reject a complete set of keys.
            missing_columns = set(self._key_columns).difference(keys)
            if missing_columns:
                raise YdbMissingKeyColumnsError(missing_columns)
            for key in sorted(keys):
                if key not in self._key_columns:
                    raise YdbUnknownKeyColumnsError(key)

            statement_key = self._build_write_statement_key()
            if not self._ydb.get_prepared_statement(statement_key):
                self._ydb.prepare_statement(
                    session,
                    statement_key,
                    self._compose_write_statement(),
                )

            return self._kikimr_set(session, keys, value)

        return self.session_call(_set)

    @staticmethod
    def _build_write_statement_key():
        return 'write'

    def _kikimr_set(self, session, keys, value):
        prepared_write_statement = self._ydb.get_prepared_statement(self._build_write_statement_key())
        return session.transaction(ydb.SerializableReadWrite()).execute(
            prepared_write_statement,
            parameters=(
                merge_dicts(
                    {'$' + key_column: keys[key_column] for key_column in keys},
                    {'$' + self._value_column: value}
                )
            ),
            commit_tx=True,
            settings=self._query_settings,
        )

    def delete(self, keys):
        """Delete rows matching ``keys`` (column -> value)."""
        def _delete(session):
            statement_key = self._build_delete_statement_key(keys.keys())
            prepared_delete_statement = self._ydb.get_prepared_statement(statement_key)
            if not prepared_delete_statement:
                prepared_delete_statement = self._ydb.prepare_statement(
                    session,
                    statement_key,
                    self._compose_delete_statement(keys),
                )

            self._kikimr_delete(session, keys, prepared_delete_statement)

        return self.session_call(_delete)

    @staticmethod
    def _build_delete_statement_key(keys):
        return ','.join(['delete'] + sorted(keys))

    def _kikimr_delete(self, session, keys, prepared_delete_statement):
        return (
            session.transaction(ydb.SerializableReadWrite())
            .execute(
                prepared_delete_statement,
                parameters={'$' + key_column: keys[key_column] for key_column in keys},
                commit_tx=True,
                settings=self._query_settings,
            )
        )

    def first(self, keys, default_value=None):
        """Return the first value from ``get(keys)`` or ``default_value`` if there is none."""
        return next(self.get(keys), default_value)


@lazy_loadable()
class YdbProfile(YdbKeyValue):
    """Key-value access to the profile table (``settings.YDB_PROFILE_TABLE`` by default).

    Rows are keyed by ``uid``, ``inverted_event_timestamp``, ``unique_id`` and
    ``updated_at`` (all ``Uint64``), and values live in the ``value`` column.
    Falsy arguments fall back to the matching ``settings`` value, except
    ``use_tvm`` and ``enabled``, which fall back only when they are ``None``.
    """

    def __init__(
        self, endpoint=None, database=None, table_name=None,
        auth_token=None, use_tvm=None, retries=None, timeout=None,
        connection_timeout=None, get_session_timeout=None, limit=None,
        enabled=None,
    ):
        endpoint = endpoint or settings.YDB_ENDPOINT
        database = database or settings.YDB_DATABASE
        table_name = table_name or settings.YDB_PROFILE_TABLE
        if enabled is None:
            enabled = settings.YDB_PROFILE_ENABLED
        if use_tvm is None:
            use_tvm = settings.YDB_USE_TVM

        profile_key_columns = dict(
            uid='Uint64',
            inverted_event_timestamp='Uint64',
            unique_id='Uint64',
            updated_at='Uint64',
        )
        super(YdbProfile, self).__init__(
            endpoint=endpoint,
            database=database,
            table_name=table_name,
            enabled=enabled,
            key_columns=profile_key_columns,
            value_column='value',
            use_tvm=use_tvm,
            auth_token=auth_token or settings.YDB_TOKEN,
            retries=retries or settings.YDB_RETRIES,
            timeout=timeout or settings.YDB_TIMEOUT,
            connection_timeout=connection_timeout or settings.YDB_CONNECTION_TIMEOUT,
            get_session_timeout=get_session_timeout or settings.YDB_GET_SESSION_TIMEOUT,
            limit=limit or settings.YDB_READ_LIMIT,
        )

    def get_profile(self, uid):
        """Return ``self.get`` for the rows of ``uid``."""
        uid_filter = dict(uid=uid)
        return self.get(uid_filter)


def get_ydb_profile():
    """Return the ``YdbProfile`` instance managed by ``LazyLoader``."""
    instance = LazyLoader.get_instance('YdbProfile')
    return instance


@lazy_loadable()
class YdbDriveSession(YdbKeyValue):
    """Key-value access to the ``drive_session`` table, keyed by ``drive_device_id`` (``Utf8``).

    Falsy arguments fall back to the matching ``settings`` value, except
    ``use_tvm`` and ``enabled``, which fall back only when they are ``None``.
    """

    def __init__(
        self,
        endpoint=None,
        database=None,
        auth_token=None,
        use_tvm=None,
        retries=None,
        timeout=None,
        connection_timeout=None,
        get_session_timeout=None,
        enabled=None,
    ):
        endpoint = endpoint or settings.YDB_ENDPOINT
        database = database or settings.YDB_DRIVE_DATABASE
        if enabled is None:
            enabled = settings.YDB_DRIVE_ENABLED
        if use_tvm is None:
            use_tvm = settings.YDB_USE_TVM

        super(YdbDriveSession, self).__init__(
            key_columns={'drive_device_id': 'Utf8'},
            value_column='value',
            endpoint=endpoint,
            database=database,
            enabled=enabled,
            table_name='drive_session',
            use_tvm=use_tvm,
            auth_token=auth_token or settings.YDB_TOKEN,
            retries=retries or settings.YDB_RETRIES,
            timeout=timeout or settings.YDB_TIMEOUT,
            connection_timeout=connection_timeout or settings.YDB_CONNECTION_TIMEOUT,
            get_session_timeout=get_session_timeout or settings.YDB_GET_SESSION_TIMEOUT,
        )


def get_ydb_drive_session():
    """Return the ``YdbDriveSession`` instance managed by ``LazyLoader``."""
    instance = LazyLoader.get_instance('YdbDriveSession')
    return instance


@lazy_loadable()
class YdbSupportCode(Ydb):
    """``Ydb`` for the support code database (``settings.YDB_SUPPORT_CODE_DATABASE`` by default).

    Falsy ``endpoint``/``database``/``auth_token`` fall back to ``settings``;
    ``use_tvm`` and ``enabled`` fall back only when they are ``None``.
    """

    def __init__(
        self,
        endpoint=None,
        database=None,
        auth_token=None,
        use_tvm=None,
        graphite_logger=None,
        enabled=None,
    ):
        endpoint = endpoint or settings.YDB_ENDPOINT
        database = database or settings.YDB_SUPPORT_CODE_DATABASE
        if enabled is None:
            enabled = settings.YDB_SUPPORT_CODE_ENABLED
        if use_tvm is None:
            use_tvm = settings.YDB_USE_TVM
        auth_token = auth_token or settings.YDB_TOKEN

        super(YdbSupportCode, self).__init__(
            endpoint=endpoint,
            database=database,
            enabled=enabled,
            use_tvm=use_tvm,
            auth_token=auth_token,
            graphite_logger=graphite_logger,
        )


def get_ydb_support_code():
    """Return the ``YdbSupportCode`` instance managed by ``LazyLoader``."""
    instance = LazyLoader.get_instance('YdbSupportCode')
    return instance


@lazy_loadable()
class YdbFamilyInvite(Ydb):
    """``Ydb`` for the family invite database (``settings.YDB_FAMILY_INVITE_DATABASE`` by default).

    Falsy ``endpoint``/``database``/``auth_token`` fall back to ``settings``;
    ``use_tvm`` and ``enabled`` fall back only when they are ``None``.
    """

    def __init__(
        self,
        endpoint=None,
        database=None,
        auth_token=None,
        use_tvm=None,
        graphite_logger=None,
        enabled=None,
    ):
        endpoint = endpoint or settings.YDB_ENDPOINT
        database = database or settings.YDB_FAMILY_INVITE_DATABASE
        if enabled is None:
            enabled = settings.YDB_FAMILY_INVITE_ENABLED
        if use_tvm is None:
            use_tvm = settings.YDB_USE_TVM
        auth_token = auth_token or settings.YDB_TOKEN

        super(YdbFamilyInvite, self).__init__(
            endpoint=endpoint,
            database=database,
            enabled=enabled,
            use_tvm=use_tvm,
            auth_token=auth_token,
            graphite_logger=graphite_logger,
        )


def get_ydb_family_invite():
    """Return the ``YdbFamilyInvite`` instance managed by ``LazyLoader``."""
    instance = LazyLoader.get_instance('YdbFamilyInvite')
    return instance


class YdbQueryExecutor(object):
    """Executes ``YdbQuery``-like objects in a single transaction through a ``Ydb`` instance."""

    def __init__(self, ydb, retry_settings=None):
        # ``ydb`` here is a ``Ydb`` wrapper instance; it shadows the client
        # module name only inside this method.
        self._ydb = ydb
        self.retry_settings = retry_settings

    def execute_queries(self, queries, tx_mode=None, retry_settings=None, query_settings=None):
        """Run ``queries`` in one transaction and return their parsed result sets.

        :param queries: sequence of query objects (``len()`` is taken and it is
            re-iterated on every retry).
        :param tx_mode: transaction mode passed to ``session.transaction``.
        :param retry_settings: overrides ``self.retry_settings`` for this call.
        :param query_settings: overrides the ``settings`` of every query.
        :return: list of ``YdbResultSet`` in execution order; each one is
            parsed by the query whose execution produced it.
        """
        def _execute(session):
            # (query, raw result set) pairs, so that every result set is later
            # parsed by the query that produced it.
            query_result_sets = []
            with session.transaction(tx_mode=tx_mode) as tx:
                last_index = len(queries) - 1
                for i, query in enumerate(queries):
                    raw_statement = query.get_raw_statement()
                    # A prepared query belongs to a session, and the session
                    # may change between retries (e.g. when it expires), so
                    # the lookup is repeated on every attempt.
                    prepared_statement = self._ydb.get_prepared_statement(raw_statement)
                    if not prepared_statement:
                        prepared_statement = self._ydb.prepare_statement(session, raw_statement, raw_statement)

                    parameters = query.get_parameters()
                    # Passing commit_tx with the last query instead of calling
                    # tx.commit() explicitly is a recommendation from the YDB
                    # team: the server then knows this is the last query of the
                    # transaction and may skip taking locks or do other
                    # optimizations. With commit_tx always False it cannot, and
                    # waits for an explicit tx.commit(), which also takes time.
                    # Some tx_modes forbid commit altogether.
                    result_sets = tx.execute(
                        prepared_statement,
                        parameters=parameters,
                        settings=query_settings or query.settings,
                        commit_tx=i >= last_index,
                    )
                    query_result_sets.extend((query, result_set) for result_set in result_sets)
            return [
                YdbResultSet.from_ydb_result_set(owner_query, result_set)
                for owner_query, result_set in query_result_sets
            ]

        return self._ydb.session_call(_execute, retry_settings=retry_settings or self.retry_settings)

    def _prepare_statement(self, raw_statement):
        """Prepare ``raw_statement`` on a pooled session and cache it under its own text."""
        return self._ydb.session_call(
            lambda session: self._ydb.prepare_statement(session, raw_statement, raw_statement),
            retry_settings=self.retry_settings,
        )


class YdbResultSet(six.Iterator):
    """Iterator over the rows of one YDB result set.

    Every row is passed through ``query.parse_query_result`` before being
    returned.
    """

    def __init__(self, query, length, iterator):
        self._query = query
        # Total number of rows, fixed at construction time; it does not
        # shrink while the iterator is consumed.
        self._length = length
        self._iterator = iterator

    def __len__(self):
        return self._length

    def __iter__(self):
        return self

    def __next__(self):
        value = next(self._iterator)
        return self._query.parse_query_result(value)

    @classmethod
    def from_ydb_result_set(cls, query, result_set):
        """Build an instance (of ``cls``, so subclasses work) from a raw result set with ``.rows``."""
        return cls(query, len(result_set.rows), iter(result_set.rows))

    def one(self):
        """Consume the iterator and return its only remaining row.

        :raises YdbNoResultFound: if no rows remain.
        :raises YdbMultipleResultFound: if more than one row remains.
        """
        try:
            value = next(self)
        except StopIteration:
            raise YdbNoResultFound()
        try:
            next(self)
            raise YdbMultipleResultFound()
        except StopIteration:
            return value


class YdbQuery(object):
    """A YQL statement together with its parameters and request settings.

    When ``query_settings`` is ``None`` the query gets request settings built
    from ``settings.YDB_TIMEOUT`` (client-side timeout) and
    ``settings.YDB_DEADLINE`` (server-side operation timeout).
    """

    def __init__(
        self,
        raw_statement=None,
        parameters=None,
        query_settings=None,
    ):
        self.raw_statement = raw_statement
        self.parameters = parameters

        if query_settings is None:
            # Keep the default settings instead of discarding them and ending
            # up with ``settings = None``.
            query_settings = BaseRequestSettings()
            query_settings.with_timeout(settings.YDB_TIMEOUT)
            query_settings.with_operation_timeout(settings.YDB_DEADLINE)

        self.settings = query_settings

    def get_raw_statement(self):
        return self.raw_statement

    def get_parameters(self):
        return self.parameters

    def parse_query_result(self, result):
        """Convert one result row; the base implementation returns it unchanged."""
        return result
