# -*- coding: utf-8 -*-
import MySQLdb
from passport.backend.core.conf import settings
from passport.backend.core.lazy_loader import (
    lazy_loadable,
    LazyLoader,
)
from passport.backend.perimeter.auth_api.common.exceptions import DBError
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL
import sqlalchemy.exc
from sqlalchemy.sql.elements import ClauseElement


DB_EXCEPTIONS = (
    sqlalchemy.exc.DatabaseError,
    sqlalchemy.exc.InterfaceError,
    MySQLdb.InterfaceError,
    MySQLdb.DatabaseError,
)


FETCH_ONE = 'one'
FETCH_MANY = 'many'
FETCH_NONE = None


@lazy_loadable()
class DBConnection(object):
    def __init__(self, host=None, port=None, username=None,
                 password=None, driver=None, database=None,
                 retries=None, connect_args=None):
        self.host = host or settings.DB_HOST
        self.port = port or settings.DB_PORT
        self.username = username or settings.DB_USER
        self.password = password or settings.DB_PASSWORD
        self.driver = driver or settings.DB_DRIVER
        self.database = database or settings.DB_DATABASE
        self.retries = retries or settings.DB_RETRIES
        self.connect_args = connect_args or settings.DB_CONNECT_ARGS
        self._connection = None
        self._engine = None

    @property
    def dsn(self):
        return URL(
            drivername=self.driver,
            username=self.username,
            password=self.password,
            host=self.host,
            port=self.port,
            database=self.database,
        )

    def connect(self):
        """
        Устанавливаем соединение с БД.
        TODO: добавить поддержку ретраев.
        :return:
        """
        if not self._engine:
            self._engine = create_engine(
                self.dsn,
                connect_args=self.connect_args,
            )

        try:
            self._connection = self._engine.connect()
        except DB_EXCEPTIONS as exc:
            raise DBError(str(exc)) from exc

    def initialize_schema(self, metadata):
        """
        Заводим в БД известные нам таблицы.
        """
        metadata.create_all(self._engine)

    def _execute(self, query, args=None, fetch=FETCH_ONE):
        """
        Внутренняя функция выполнения скомпилированного запроса. Если
        установлен параметр many, то возвращаем список со всеми
        возвращенными строками. В противном случае возвращаем ровно
        содержимое единственной.
        """

        if not isinstance(query, ClauseElement):
            raise ValueError('Raw queries are not supported.')  # pragma: no cover

        result = self._connection.execute(query.execution_options(autocommit=True), args or [])
        if fetch == FETCH_MANY:
            rows = result.fetchall()
        elif fetch == FETCH_ONE:
            rows = result.fetchone()
        else:
            rows = None
        result.close()
        return rows

    def execute(self, query, args=None, many=False):
        """
        Публичная функция выполнения скомпилированного запроса.
        Пытается N раз выполнить переданный вопрос, после чего
        громко падает с ошибкой. N указывается параметром retries
        при создании объекта.
        """
        if not self._connection:
            self.connect()

        for i in range(self.retries):
            try:
                return self._execute(
                    query,
                    args=args,
                    fetch=FETCH_MANY if many else FETCH_ONE,
                )
            except DB_EXCEPTIONS as exc:
                if i >= self.retries - 1:
                    raise DBError(str(exc)) from exc

    def disconnect(self):
        """
        Закрытие соединения с БД.
        """
        if self._connection:
            self._connection.close()
            self._connection = None

    def __repr__(self):
        fields = [
            'driver',
            'host',
            'port',
            'username',
            'password',
            'database',
            'retries',
        ]

        return '%s(%s)' % (
            self.__class__.__name__,
            ', '.join(
                [
                    '%s=%s' % (field, repr(getattr(self, field)))
                    for field in fields
                ],
            ),
        )

    def __eq__(self, other):
        return all([
            self.dsn == other.dsn,
            self.retries == other.retries,
        ])


def get_db_connection():
    return LazyLoader.get_instance('DBConnection')
