from __future__ import annotations

import abc
import asyncio
import inspect
import logging
import sys
from datetime import timedelta
from logging import Logger
from types import TracebackType
from typing import (
    Any, AsyncContextManager, AsyncIterable, ClassVar, Dict, Generic, Iterable, Optional, Type, TypeVar, Union
)

import psycopg2.errors
from aiopg.sa.connection import SAConnection
from aiopg.sa.transaction import Transaction
from aiopg.utils import _TransactionContextManager
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.selectable import GenerativeSelect

from sendr_aiopg.query_builder import CRUDQueries
from sendr_aiopg.storage.entity import Entity, EntityProtocol  # NOQA
from sendr_aiopg.types import EngineUnion, OptException
from sendr_qlog import LoggerContext

default_logger = logging.getLogger(__name__)

T = TypeVar('T', bound="EntityProtocol")
T2 = TypeVar('T2', bound="StorageBase")


class QueryMixin:
    conn: SAConnection

    async def _query(self, query: Union[GenerativeSelect, Executable], *,
                     offset: Optional[int] = None,
                     limit: Optional[int] = None,
                     iterator: bool = False) -> AsyncIterable[Any]:
        if offset:
            query = query.offset(offset)
        if limit:
            query = query.limit(limit)

        iter_ = self.conn.execute(query)
        if not iterator:
            iter_ = [row async for row in iter_]
            for row in iter_:
                yield row
        else:
            async for row in iter_:
                yield row

    async def _query_one(self, query: Executable, raise_: OptException = None) -> Any:
        db_result = await self.conn.execute(query)
        row = await db_result.first()

        if row is None and raise_:
            raise raise_
        return row

    async def _query_scalar(self, query: Executable, raise_: OptException = None) -> Any:
        db_result = await self.conn.execute(query)
        result = await db_result.scalar()
        if result is None and raise_:
            raise raise_
        return result


class BaseGateway(QueryMixin, metaclass=abc.ABCMeta):
    """
    Базовый класс гейтвеев.
    Гейтвей предназначен для реализации действий в базе данных,
    не относящихся напрямую к CRUD, либо реализующих нетривиальную логику.
    Например:
    — агрегация данных
    — массовое изменение записей
    — запросы для сложной фильтрации, возвращающие идентификаторы записей

    Предпочтительно в реализацией использовать одну таблицу в одном классе гейтвея
    """

    def __init__(self, connection: SAConnection, logger: Logger = default_logger):
        self.conn = connection
        self._logger = logger

    @property
    @abc.abstractmethod
    def table(self):
        pass

    @property
    def c(self):
        return self.table.c


class BaseMapper(Generic[T], QueryMixin):
    """
    Базовый класс для мапперов.
    Маппер предназначен для реализации обращений к БД и маппинга записей
    в объекты модели данных приложения

    В качестве основных методов реализованы методы:
    — create - принимает объект модели и создает запись в базе
    — find - асинхронный генератор, возвращающий объекты, полученные из отобранных на основе параметров записей
    — get - на основании параметров возвращает один объект, соответствующий конкретной записи базы
    — save - принимает объект и обновляет соответствующую запись в базе
    — delete - принимает объект и удаляет соответствующую запись в базе

    Методы наследников могут отличаться. Допускается переопределять методы, подходящие к логике работы приложения.
    """

    _builder: ClassVar[CRUDQueries]
    model: ClassVar[Type[T]]

    def __init__(self, connection: SAConnection, logger: Logger = default_logger):
        self.conn = connection
        self._logger = logger


class BaseMapperCRUD(BaseMapper, Generic[T]):
    _unique_violation_error_mapping: ClassVar[Dict[str, Type[Exception]]] = dict()

    def _try_map_error(self, exception: psycopg2.errors.UniqueViolation) -> None:
        constraint_name = exception.diag.constraint_name
        if constraint_name in self._unique_violation_error_mapping:
            raise self._unique_violation_error_mapping[constraint_name]

    async def _query_one(self, query: Executable, raise_: OptException = None) -> Any:
        try:
            return await super()._query_one(query, raise_=raise_)
        except psycopg2.errors.UniqueViolation as e:
            self._try_map_error(e)
            raise

    async def create(self, item: T, *args: Any, **kwargs: Any) -> T:
        query, mapper = self._builder.insert(item, *args, **kwargs)
        return mapper(await self._query_one(query))

    async def find(self, *args: Any, **kwargs: Any) -> AsyncIterable[T]:
        query, mapper = self._builder.select(*args, **kwargs)
        async for row in self._query(query):
            yield mapper(row)

    async def get(self, *args: Any, for_update: bool = False) -> T:
        query, mapper = self._builder.select(id_values=args, for_update=for_update)
        exception = self.model.DoesNotExist
        return mapper(await self._query_one(query, raise_=exception))

    async def save(self, obj: T, ignore_fields: Optional[Iterable[str]] = None) -> T:
        query, mapper = self._builder.update(obj, ignore_fields=ignore_fields)
        return mapper(await self._query_one(query))

    async def delete(self, obj: T) -> None:
        query = self._builder.delete(obj)
        return await self._query_one(query)


class StorageAnnotatedMeta(type):
    """
    Превращает аннотации типов для мапперов в дескрипторы,
    которые лениво создают экземпляры мапперов на сторедже при первом обращении.
    """

    BASE_INTERFACES = (BaseGateway, BaseMapper)

    def __init__(cls, name, superclasses, attributes):
        super().__init__(name, superclasses, attributes)

        annotations = getattr(cls, '__annotations__', {})
        for name, a_cls in annotations.items():
            if inspect.isclass(a_cls) and issubclass(a_cls, cls.BASE_INTERFACES):
                setattr(cls, name, StorageMapper(mapper_cls=a_cls, name=name))


class StorageMapper:
    def __init__(self, mapper_cls: Union[Type[BaseMapper], Type[BaseGateway]], name: str = ''):
        self.mapper_cls = mapper_cls
        self.name = name

    def __set_name__(self, storage: StorageBase, name: str) -> None:
        self.name = name

    def __get__(self, storage: StorageBase, storage_cls: Type[StorageBase]) -> StorageBase:
        if storage is None:
            return self

        mapper = self.mapper_cls(storage.conn, storage.logger)
        setattr(storage, self.name, mapper)
        return getattr(storage, self.name)


class StorageBase(metaclass=StorageAnnotatedMeta):
    """
    Storage обеспечивает доступ к данным приложения в рамках одного соединения с БД.

    За составление/выполнение запросов отвечают отдельные мапперы, внедряемые в сторедж через аннотации типов.
    Например:
    >>> class MyStorage(StorageBase):
    >>>     entity: EntityMapper

    После чего можно будет извлекать и соханять объекты Entity используя `storage.entity` интерфейс.
    """

    def __init__(
        self,
        connection: SAConnection,
        transaction: Optional[Transaction] = None,
        logger: Logger = default_logger,
    ):
        self.conn = connection
        self.logger = logger

        self._transaction = transaction

    def __getitem__(self, item: str) -> Any:
        return getattr(self, item)

    def __hash__(self) -> int:
        return hash(self.__class__.__name__)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, type(self))

    async def commit(self) -> None:
        if self._transaction is None:
            raise RuntimeError('Cannot commit without transaction')
        await self._transaction.commit()


class StorageContextBase(Generic[T2]):
    """
    StorageContextBase - контекстный менеджер, позволяющий получать сразу
    наследника StorageBase с новым подключением и опциональным открытием
    транзакции.

    Для использования необходимо унаследовать новый класс и переопределить
    STORAGE_CLS
    """
    STORAGE_CLS: Type[T2]

    def __init__(self,
                 db_engine: EngineUnion, *,
                 transact: bool = False,
                 logger: Union[LoggerContext, Logger] = default_logger,
                 conn: Optional[SAConnection] = None,
                 transaction_timeout: Optional[timedelta] = None):
        self.db_engine = db_engine
        self.transact = transact
        self._logger = logger
        self._conn = conn
        self._conn_context: Optional[Union[Any, AsyncContextManager[Any]]] = None
        self._transaction_context: Optional[_TransactionContextManager] = None
        self._transaction: Optional[Transaction] = None
        self._transaction_timeout = transaction_timeout

    async def __aenter__(self, engine_name: Optional[str] = None) -> T2:
        if self._conn is None:
            self._conn_context = self.db_engine.acquire()  # type: ignore
            assert self._conn_context

            self._conn = await self._conn_context.__aenter__()
            assert self._conn

        if self.transact:
            try:
                self._transaction_context = self._conn.begin_nested()
                assert self._transaction_context
                self._transaction = await self._transaction_context.__aenter__()
                assert self._transaction
                if self._transaction_timeout:
                    total_milliseconds = int(self._transaction_timeout.total_seconds() * 1000)
                    await self._conn.execute(f"set idle_in_transaction_session_timeout = '{total_milliseconds}ms'")
            except (asyncio.CancelledError, Exception):
                if self._conn_context:
                    await self._conn_context.__aexit__(*sys.exc_info())
                    self._conn_context = None
                    self._conn = None
                raise

        return self.STORAGE_CLS(self._conn, transaction=self._transaction, logger=self._logger)

    async def __aexit__(self, exc_type: Type[Exception], exc: Exception, tb: TracebackType) -> None:
        try:
            if self._transaction_context is not None and self._transaction and self._transaction.is_active:
                await self._transaction_context.__aexit__(exc_type, exc, tb)
            self._transaction_context = None
        finally:
            if self._conn_context:
                await self._conn_context.__aexit__(exc_type, exc, tb)
                self._conn_context = None
                self._conn = None
