from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable, ClassVar, Optional, Type, TypeVar, Union

from sendr_aiopg import StorageContextBase
from sendr_aiopg.action_context import BaseActionDBContext
from sendr_aiopg.engine.lazy import Preset
from sendr_aiopg.storage import StorageBase
from sendr_aiopg.types import EngineWithUsing
from sendr_core import BaseAction

# Type of the storage object handled by BaseDBAction (context.storage); must be a StorageBase subclass.
StorageType = TypeVar('StorageType', bound=StorageBase)


class BaseDBAction(BaseAction[BaseActionDBContext[StorageType]]):
    """Base action that runs its body with a database storage bound to its context.

    ``_run`` opens a storage through ``storage_setter`` and places it in
    ``self.context.storage`` for the duration of the action. The storage is
    opened inside a transaction when ``transact`` is set. When allowed, it
    reuses the connection of an already-bound storage. On exit the previously
    bound storage is restored.

    Subclasses must define ``storage_context_cls``.
    """

    # Forwarded as ``transact`` to ``storage_setter`` by ``_run``.
    transact: ClassVar[bool] = False
    # Storage context factory. It is called with the engine positionally plus the
    # ``transact``, ``conn`` and ``logger`` keyword arguments, and is then used as
    # an async context manager that yields the storage. Only annotated here, so
    # subclasses must assign it.
    storage_context_cls: ClassVar[
        Union[
            Type[StorageContextBase[StorageType]],
            Callable[..., StorageContextBase[StorageType]]
        ]
    ]
    # Engine name passed to ``db_engine.using``. Ignored when ``allow_replica_read`` is set.
    db_engine_name: ClassVar[Optional[str]] = None
    # When True, the storage is opened with ``Preset.ACTUAL_LOCAL`` instead of
    # ``db_engine_name``, and any already-bound connection is considered reusable.
    # Could be used with pg-pinger from v1.4.0
    allow_replica_read: ClassVar[bool] = False
    # When False, the connection of an already-bound storage is never reused.
    allow_connection_reuse: ClassVar[bool] = True

    @property
    def storage(self) -> StorageType:
        """Return the storage currently bound to the context (asserts one is bound)."""
        assert self.context.storage is not None
        return self.context.storage

    @property
    def db_engine(self) -> EngineWithUsing:
        """Return the engine held by the action context."""
        return self.context.db_engine

    def can_reuse_connection(self) -> bool:
        """Tell whether the connection of the currently bound storage may be reused.

        Returns False when no storage is bound or reuse is disabled. Returns True
        when replica reads are allowed, since any connection is acceptable then.
        Otherwise the connection is reusable only if it is a master connection;
        this is checked only when the connection exposes ``conn_preset``.
        """
        if self.context.storage is None:
            return False
        if not self.allow_connection_reuse:
            return False
        if self.allow_replica_read:
            return True  # any connection is reusable
        if hasattr(self.storage.conn, 'conn_preset'):
            return self.storage.conn.conn_preset == Preset.MASTER
        return True

    @asynccontextmanager
    async def storage_setter(
        self,
        transact: bool = False,
        reuse_connection: bool = False,
    ) -> AsyncGenerator[StorageType, None]:
        """Open a storage and bind it to ``self.context.storage`` for the block.

        Args:
            transact: forwarded to ``storage_context_cls`` as ``transact``.
            reuse_connection: if True and ``can_reuse_connection()`` allows it,
                the connection of the currently bound storage is passed to the
                new storage context. Otherwise ``conn=None`` is passed.

        Yields:
            The storage produced by the storage context.

        The previously bound storage (possibly ``None``) is restored on exit,
        including when the block raises.
        """
        # ``storage_context_cls`` is only annotated on this class, so the attribute is
        # absent unless a subclass sets it. A default is required here; otherwise
        # getattr raises AttributeError and the assertion message is never shown.
        assert getattr(self, 'storage_context_cls', None), 'storage_context_cls must be defined'
        conn = None
        if reuse_connection and self.can_reuse_connection():
            conn = self.storage.conn
        engine_param = self.db_engine_name
        if self.allow_replica_read:
            engine_param = Preset.ACTUAL_LOCAL.value
        ctx = self.storage_context_cls(
            self.db_engine.using(engine_param),
            transact=transact,
            conn=conn,
            logger=self.logger,
        )
        prev_storage = self.context.storage
        try:
            async with ctx as storage:
                self.context.storage = storage
                yield storage
        finally:
            # Restore the outer storage so nested actions do not leak theirs.
            self.context.storage = prev_storage

    async def _run(self):
        """Run the parent action's ``_run`` with a storage bound to the context."""
        async with self.storage_setter(transact=self.transact, reuse_connection=True):
            return await super()._run()
