from typing import AsyncIterable, Optional

import psycopg2
from sqlalchemy import func

from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.beagle.beagle.core.entities.enums import SubscriptionType
from mail.beagle.beagle.core.entities.user_subscription import UserSubscription
from mail.beagle.beagle.storage.db.tables import mail_lists as t_mail_lists
from mail.beagle.beagle.storage.db.tables import user_subscriptions as t_user_subscriptions
from mail.beagle.beagle.storage.exceptions import UserNotFound, UserSubscriptionAlreadyExists, UserSubscriptionNotFound
from mail.beagle.beagle.storage.mappers.base import BaseMapper
from mail.beagle.beagle.storage.mappers.mail_list import MailListDataMapper


class UserSubscriptionDataMapper(SelectableDataMapper):
    """Data mapper that reads ``user_subscriptions`` rows into ``UserSubscription`` entities."""

    selectable = t_user_subscriptions
    entity_class = UserSubscription


class UserSubscriptionDataDumper(TableDataDumper):
    """Table dumper for ``UserSubscription`` entities targeting the ``user_subscriptions`` table."""

    table = t_user_subscriptions
    entity_class = UserSubscription


class UserSubscriptionMapper(BaseMapper):
    """Storage mapper for the ``user_subscriptions`` table.

    Rows are identified by the composite key ``(org_id, mail_list_id, uid)``. Related selects
    join ``mail_lists`` on ``(org_id, mail_list_id)`` and attach the result as ``mail_list``.
    """

    name = 'user_subscription'

    # Join description: user_subscriptions -> mail_lists, matched on (org_id, mail_list_id).
    _mail_list_relation = RelationDescription(
        name='mail_list',
        base=t_user_subscriptions,
        related=t_mail_lists,
        base_cols=('org_id', 'mail_list_id'),
        related_cols=('org_id', 'mail_list_id'),
        mapper_cls=MailListDataMapper,
    )
    _builder = CRUDQueries(
        base=t_user_subscriptions,
        id_fields=('org_id', 'mail_list_id', 'uid'),
        mapper_cls=UserSubscriptionDataMapper,
        dumper_cls=UserSubscriptionDataDumper,
        related=(_mail_list_relation,),
    )

    @staticmethod
    def _map_related(row, mapper, rel_mappers):
        """Build a ``UserSubscription`` from ``row``.

        When ``rel_mappers`` is non-empty, also attach the joined mail list as ``mail_list``.
        """
        entity = mapper(row)
        if not rel_mappers:
            return entity
        entity.mail_list = rel_mappers['mail_list'](row)
        return entity

    async def create(self, user_subscription: UserSubscription) -> UserSubscription:
        """Insert ``user_subscription`` and return the stored entity.

        Side effect: ``created`` and ``updated`` on the passed-in entity are set to the SQL
        expression ``func.now()`` before the insert is built.

        Raises:
            UserSubscriptionAlreadyExists: on a unique-constraint violation.
            UserNotFound: on a foreign-key violation (see TODO below).
        """
        timestamp = func.now()
        user_subscription.created = timestamp
        user_subscription.updated = timestamp
        insert_query, to_entity = self._builder.insert(user_subscription)
        try:
            inserted_row = await self._query_one(insert_query)
        except psycopg2.errors.UniqueViolation:
            raise UserSubscriptionAlreadyExists
        except psycopg2.errors.ForeignKeyViolation:
            # TODO: extract violated key. Exception also can be thrown when mail list or org not found
            raise UserNotFound
        else:
            return to_entity(inserted_row)

    async def delete(self, user_subscription: UserSubscription) -> None:
        """Delete the row matching ``user_subscription``'s key.

        No ``raise_`` is passed, so whether a missing row is an error depends on
        ``BaseMapper._query_one``'s default behaviour (not visible here).
        """
        await self._query_one(self._builder.delete(user_subscription))

    async def get(self, org_id: int, mail_list_id: int, uid: int, for_update: bool = False) -> UserSubscription:
        """Fetch one subscription by its composite key.

        ``for_update`` is forwarded to the query builder's select.

        Raises:
            UserSubscriptionNotFound: when no row matches.
        """
        key = (org_id, mail_list_id, uid)
        select_query, to_entity = self._builder.select(id_values=key, for_update=for_update)
        found_row = await self._query_one(select_query, raise_=UserSubscriptionNotFound)
        return to_entity(found_row)

    async def find(
        self,
        org_id: int,
        mail_list_id: Optional[int] = None,
        uid: Optional[int] = None,
        subscription_type: Optional[SubscriptionType] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        order_by: Optional[str] = None,
        is_deleted: bool = False,
        iterator: bool = False,
    ) -> AsyncIterable[UserSubscription]:
        """Asynchronously yield subscriptions of ``org_id`` with their joined mail list attached.

        ``org_id`` is passed as the leading id value. ``mail_list_id``, ``uid`` and
        ``subscription_type`` are applied only when not ``None``. The joined mail list's
        ``is_deleted`` must always equal ``is_deleted``. ``limit``/``offset`` are forwarded as-is.
        A truthy ``order_by`` becomes a single-element ordering. ``iterator`` is forwarded to
        ``_query``.
        """
        filters = Filters()
        optional_conditions = (
            ('mail_list_id', mail_list_id),
            ('uid', uid),
            ('subscription_type', subscription_type),
        )
        for column, value in optional_conditions:
            filters.add_not_none(column, value)
        filters['mail_list.is_deleted'] = is_deleted

        ordering = (order_by,) if order_by else None
        query, row_mapper, related_mappers = self._builder.select_related(
            id_values=(org_id,),
            filters=filters,
            order=ordering,
            offset=offset,
            limit=limit,
        )
        async for record in self._query(query, iterator=iterator):
            yield self._map_related(record, row_mapper, related_mappers)

    async def save(self, user_subscription: UserSubscription) -> UserSubscription:
        """Persist changes to an existing subscription and return the stored entity.

        ``updated`` is set to ``func.now()``. The ``created`` field is excluded from the UPDATE.

        Raises:
            UserSubscriptionNotFound: when no row matches the entity's key.
        """
        user_subscription.updated = func.now()
        update_query, to_entity = self._builder.update(user_subscription, ignore_fields=('created',))
        updated_row = await self._query_one(update_query, raise_=UserSubscriptionNotFound)
        return to_entity(updated_row)
