# -*- coding: utf-8 -*-
import contextlib
import logging

from concurrent.futures import TimeoutError
import ydb

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class TooManyFunctionCallsException(Exception):
    """Raised when retries of a single callee exceed the configured call limit."""


# Default number of sessions in each ydb.SessionPool built by YdbSessionManager.
SESSION_POOL_SIZE = 1


class YdbSessionManager(object):
    """Own a lazily created ``ydb.SessionPool`` with an extra retry layer.

    Every pool created by this manager has its ``retry_operation_sync`` replaced
    by a wrapper (see ``_retry_operation_sync``). The wrapper:

    * caps how deeply retries for one callee may nest;
    * optionally recreates the pool after session errors;
    * retries through the current pool after transient errors.
    """

    def __init__(self, driver_config, pool_size=SESSION_POOL_SIZE, max_calls=10,
                 recreate_ydb_session_pool=lambda: True):
        """
        :param driver_config: ``ydb.DriverConfig`` used to build a new driver
            every time a session pool is created.
        :param pool_size: ``size`` passed to each ``ydb.SessionPool``.
        :param max_calls: maximum number of nested ``retry_operation_sync``
            calls for one callee before ``TooManyFunctionCallsException``.
        :param recreate_ydb_session_pool: zero-argument predicate consulted on
            session errors. When it returns a truthy value, a new session pool
            is created before retrying.
        """
        self._driver_config = driver_config  # type: ydb.DriverConfig
        # NOTE(review): not referenced by any code visible here; presumably meant
        # to be passed to retry_operation_sync -- confirm before removing.
        self._retry_settings = ydb.RetrySettings()
        self._session_pool_size = pool_size
        self._session_pool = None
        self._max_calls = max_calls
        self.recreate_ydb_session_pool = recreate_ydb_session_pool

    @contextlib.contextmanager
    def get_session_pool(self):
        """Yield the current session pool, creating it on first use."""
        if self._session_pool is None:
            self._create_new_session_pool()

        yield self._session_pool

    @property
    def database(self):
        """Database name taken from the driver config."""
        return self._driver_config.database

    def _create_new_session_pool(self):
        """Build a fresh driver and session pool and install the retry wrapper.

        If the driver is not ready within 5 seconds, the timeout is logged and
        the pool is created anyway (best effort).
        """
        logger.info('Create ydb session pool (size=%s)', self._session_pool_size)
        ydb_driver = ydb.Driver(self._driver_config)
        try:
            ydb_driver.wait(5)
        except TimeoutError as e:
            # Format the exception itself: Python 3 exceptions have no
            # ``.message`` attribute, and accessing it would raise
            # AttributeError here instead of logging the timeout.
            logger.error(
                'Connect failed to YDB: %s. Last reported errors by discovery: %s',
                e,
                ydb_driver.discovery_debug_details()
            )
        # NOTE(review): any previously created pool/driver is replaced without
        # being stopped; consider stopping it once no other caller can use it.
        self._session_pool = ydb.SessionPool(ydb_driver, size=self._session_pool_size)
        self._session_pool.retry_operation_sync = self._retry_operation_sync(
            self._session_pool.retry_operation_sync
        )

    def _retry_operation_sync(self, orig_retry_operation_sync):
        """
        Wrap a session pool's ``retry_operation_sync`` with extra error handling.

        Error handling recommendations:
        https://ydb.yandex-team.ru/docs/api/error_handling

        :param orig_retry_operation_sync: the pool's original
            ``retry_operation_sync`` method.
        :return: a replacement taking ``(callee, *args, **kwargs)``. It returns
            the callee's result, including when that result comes from a retry.
        :raises TooManyFunctionCallsException: when retries of one callee nest
            more than ``max_calls`` levels deep.
        """

        def _attempt(callee, *args, **kwargs):
            # Run one attempt through the original method. On known YDB errors,
            # retry via the current (possibly recreated) pool. Its
            # retry_operation_sync is also wrapped, so the retry re-enters
            # ``wrapper`` and counts toward ``max_calls``.
            try:
                return orig_retry_operation_sync(callee, *args, **kwargs)
            except (
                ydb.issues.BadSession,
                ydb.issues.SessionBusy,
                ydb.issues.SessionExpired,
            ) as e:
                if self.recreate_ydb_session_pool():
                    logger.warning('Creating new ydb.SessionPool after exception: %r', e)
                    self._create_new_session_pool()

                with self.get_session_pool() as session_pool:
                    # Propagate the retry's result to the caller.
                    return session_pool.retry_operation_sync(callee, *args, **kwargs)
            except (
                ydb.issues.Timeout,
                ydb.issues.Aborted,
                ydb.issues.Unavailable,
                ydb.issues.Overloaded,
                ydb.issues.Undetermined,
            ) as e:
                logger.warning('Retrying YDB request after exception: %r', e)
                with self.get_session_pool() as session_pool:
                    return session_pool.retry_operation_sync(callee, *args, **kwargs)
            except Exception as e:
                logger.exception('Unhandled YDB exception %r', e)
                raise

        def wrapper(callee, *args, **kwargs):
            # The counter is stored on the callee so nested retries share it.
            # NOTE: this requires ``callee`` to accept attribute assignment
            # (plain functions do, bound methods do not).
            calls_count = (getattr(callee, 'calls_count', None) or 0) + 1
            if calls_count > self._max_calls:
                raise TooManyFunctionCallsException(
                    '%r called %d times.' % (callee, self._max_calls)
                )
            callee.calls_count = calls_count
            try:
                return _attempt(callee, *args, **kwargs)
            finally:
                # Unwind on exit so the counter measures only the current retry
                # chain. Otherwise a reused callee would start failing after
                # ``max_calls`` independent, successful calls.
                callee.calls_count -= 1

        return wrapper
