from typing import Any, AsyncIterable, Iterable, Optional

import psycopg2
from sqlalchemy import and_, func, select
from sqlalchemy.sql.elements import BooleanClauseList

from sendr_aiopg import BaseMapperCRUD
from sendr_aiopg.query_builder import CRUDQueries

from mail.payments.payments.conf import settings
from mail.payments.payments.core.entities.enums import MerchantOAuthMode
from mail.payments.payments.core.entities.merchant_oauth import MerchantOAuth
from mail.payments.payments.storage.db.tables import merchant_oauths as t_merchant_oauths
from mail.payments.payments.storage.exceptions import MerchantOAuthAlreadyExistsStorageError
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper, create_interval


class MerchantOAuthDataMapper(SelectableDataMapper):
    """Maps rows selected from the ``merchant_oauths`` table onto ``MerchantOAuth`` entities."""

    selectable = t_merchant_oauths
    entity_class = MerchantOAuth


class MerchantOAuthDataDumper(TableDataDumper):
    """Dumps ``MerchantOAuth`` entities into values for the ``merchant_oauths`` table."""

    table = t_merchant_oauths
    entity_class = MerchantOAuth


class MerchantOAuthMapper(BaseMapperCRUD[MerchantOAuth]):
    """CRUD mapper for ``MerchantOAuth`` records stored in ``merchant_oauths``."""

    name = 'merchant_oauths'
    model = MerchantOAuth

    _builder = CRUDQueries(
        base=t_merchant_oauths,
        id_fields=('uid', 'mode'),  # TODO: PAYBACK-670 replace mode with shop_id
        mapper_cls=MerchantOAuthDataMapper,
        dumper_cls=MerchantOAuthDataDumper,
    )

    async def create(self, obj: MerchantOAuth, *args: Any, **kwargs: Any) -> MerchantOAuth:
        """Insert ``obj``, stamping ``created`` and ``updated`` with SQL ``now()``.

        Extra positional/keyword arguments are forwarded to the base ``create``.

        Raises:
            MerchantOAuthAlreadyExistsStorageError: if the insert hits a unique-constraint
                violation; the original ``UniqueViolation`` is kept as ``__cause__``.
        """
        obj.created = obj.updated = func.now()
        try:
            return await super().create(obj, *args, **kwargs)
        except psycopg2.errors.UniqueViolation as exc:
            raise MerchantOAuthAlreadyExistsStorageError from exc

    async def find_by_uid(self, uid: int) -> AsyncIterable[MerchantOAuth]:
        """Yield every ``MerchantOAuth`` whose ``uid`` column equals ``uid``."""
        query, mapper = self._builder.select(filters={'uid': uid})

        async for row in self._query(query):
            yield mapper(row)

    async def get_by_shop_id(self, uid: int, shop_id: int) -> MerchantOAuth:
        """Return the first ``MerchantOAuth`` of ``uid`` whose ``shop_id`` matches.

        Raises:
            MerchantOAuth.DoesNotExist: if no record of ``uid`` has that ``shop_id``.
        """
        # TODO: PAYBACK-670 rename to `get` - the primary key is now (uid, shop_id).
        # Rewrite the body as a plain get by key.
        found: Optional[MerchantOAuth] = None
        # Drain the iterator instead of returning mid-iteration: an abandoned async
        # generator stays suspended (holding the underlying query iterator) until it
        # is garbage-collected / finalized by the event loop.
        async for oauth in self.find_by_uid(uid=uid):
            if found is None and oauth.shop_id == shop_id:
                found = oauth
        if found is None:
            raise MerchantOAuth.DoesNotExist
        return found

    async def get(self, uid: int, mode: MerchantOAuthMode) -> MerchantOAuth:
        """Fetch a record by the current ``(uid, mode)`` key via the base ``get``."""
        # TODO: PAYBACK-670 remove. mode is no longer part of the key.
        return await super().get(uid, mode)

    async def save(self, obj: MerchantOAuth, ignore_fields: Optional[Iterable[str]] = None) -> MerchantOAuth:
        """Persist ``obj``, setting ``updated`` to SQL ``now()``.

        ``created`` is always added to the ignored fields so an update never rewrites it.
        The caller's ``ignore_fields`` iterable is copied, not mutated.
        """
        obj.updated = func.now()
        ignored = set(ignore_fields) if ignore_fields is not None else set()
        ignored.add('created')
        return await super().save(obj, ignore_fields=ignored)

    def _for_refresh_condition(self, safe: bool = True) -> BooleanClauseList:
        """Build the WHERE clause selecting tokens due for refresh.

        A row qualifies when ``poll`` is true and ``expires`` falls within
        ``MERCHANT_OAUTH_REFRESH_THRESHOLD`` of now. With ``safe=True``, the row must
        also not have been updated within ``MERCHANT_OAUTH_REFRESH_UPDATED_THRESHOLD``.
        """
        threshold = settings.MERCHANT_OAUTH_REFRESH_THRESHOLD
        updated_threshold = settings.MERCHANT_OAUTH_REFRESH_UPDATED_THRESHOLD

        conds = [t_merchant_oauths.c.expires <= func.now() + create_interval(threshold),
                 t_merchant_oauths.c.poll.is_(True)]
        if safe:
            conds.append(t_merchant_oauths.c.updated <= func.now() - create_interval(updated_threshold))

        return and_(*conds)

    async def get_for_refresh(self, safe: bool = True, for_update: bool = False) -> MerchantOAuth:
        """Return one record matching the refresh condition.

        Rows are ordered by ``updated`` then ``expires``. ``for_update`` and
        ``skip_locked=True`` are passed to the query builder.

        Raises:
            MerchantOAuth.DoesNotExist: if no row matches.
        """
        query, mapper = self._builder.select(for_update=for_update, order=('updated', 'expires'), skip_locked=True)
        query = query.where(self._for_refresh_condition(safe=safe)).limit(1)
        return mapper(await self._query_one(query, raise_=MerchantOAuth.DoesNotExist))

    async def count_for_refresh(self, safe: bool = True) -> int:
        """Return the number of rows matching the refresh condition."""
        query = (select([func.count()]).
                 select_from(t_merchant_oauths).
                 where(self._for_refresh_condition(safe=safe)))

        row = await self._query_one(query)
        return row[0]
