from typing import AsyncIterable, List, Optional

import psycopg2
from sqlalchemy import func

from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.beagle.beagle.core.entities.unit_subscription import UnitSubscription
from mail.beagle.beagle.storage.db.tables import mail_lists as t_mail_lists
from mail.beagle.beagle.storage.db.tables import unit_subscriptions as t_unit_subscriptions
from mail.beagle.beagle.storage.exceptions import UnitNotFound, UnitSubscriptionAlreadyExists, UnitSubscriptionNotFound
from mail.beagle.beagle.storage.mappers.base import BaseMapper
from mail.beagle.beagle.storage.mappers.mail_list import MailListDataMapper


class UnitSubscriptionDataMapper(SelectableDataMapper):
    """Maps rows selected from the ``unit_subscriptions`` table to ``UnitSubscription`` entities."""

    selectable = t_unit_subscriptions
    entity_class = UnitSubscription


class UnitSubscriptionDataDumper(TableDataDumper):
    """Dumps ``UnitSubscription`` entities into values for the ``unit_subscriptions`` table."""

    table = t_unit_subscriptions
    entity_class = UnitSubscription


class UnitSubscriptionMapper(BaseMapper):
    """Storage mapper for ``UnitSubscription`` rows.

    Rows are identified by the composite id ``(org_id, mail_list_id, unit_id)``
    (see ``id_fields`` below). ``find`` also joins the owning mail list so it can
    filter on it and attach it to the returned entities.
    """

    name = 'unit_subscription'
    # Join unit_subscriptions -> mail_lists on the composite (org_id, mail_list_id) key.
    _mail_list_relation = RelationDescription(
        name='mail_list',
        base=t_unit_subscriptions,
        related=t_mail_lists,
        mapper_cls=MailListDataMapper,
        base_cols=('org_id', 'mail_list_id'),
        related_cols=('org_id', 'mail_list_id'),
    )
    _builder = CRUDQueries(
        base=t_unit_subscriptions,
        id_fields=('org_id', 'mail_list_id', 'unit_id'),
        mapper_cls=UnitSubscriptionDataMapper,
        dumper_cls=UnitSubscriptionDataDumper,
        related=(_mail_list_relation,),
    )

    @staticmethod
    def _map_related(row, mapper, rel_mappers):
        """Build a ``UnitSubscription`` from ``row``.

        When related mappers are present, also attach the joined mail list.
        """
        unit_subscription = mapper(row)
        if rel_mappers:
            unit_subscription.mail_list = rel_mappers['mail_list'](row)
        return unit_subscription

    async def create(self, unit_subscription: UnitSubscription) -> UnitSubscription:
        """Insert ``unit_subscription`` and return the entity built from the inserted row.

        Side effect: ``created`` and ``updated`` on the passed entity are set to the
        SQL expression ``func.now()``, so the database supplies the timestamps.

        Raises:
            UnitSubscriptionAlreadyExists: the insert hit a unique-constraint violation.
            UnitNotFound: the insert hit a foreign-key violation (see TODO below).
        """
        unit_subscription.created = unit_subscription.updated = func.now()
        query, mapper = self._builder.insert(unit_subscription)
        try:
            return mapper(await self._query_one(query))
        except psycopg2.errors.UniqueViolation as exc:
            raise UnitSubscriptionAlreadyExists from exc
        except psycopg2.errors.ForeignKeyViolation as exc:
            # TODO: extract the violated key. A foreign-key violation can also be
            # caused by a missing mail list or org, not only a missing unit.
            raise UnitNotFound from exc

    async def delete(self, unit_subscription: UnitSubscription) -> None:
        """Delete the row matching ``unit_subscription``.

        The row is presumably matched by the builder's id fields
        (org_id, mail_list_id, unit_id).
        """
        query = self._builder.delete(unit_subscription)
        await self._query_one(query)

    async def find(self,
                   org_id: int,
                   mail_list_id: Optional[int] = None,
                   unit_id: Optional[int] = None,
                   unit_ids: Optional[List[int]] = None,
                   limit: Optional[int] = None,
                   offset: Optional[int] = None,
                   is_deleted: bool = False
                   ) -> AsyncIterable[UnitSubscription]:
        """Yield subscriptions of ``org_id`` joined with their mail list.

        Args:
            org_id: organization to search in (first id field).
            mail_list_id: if given, restrict to this mail list.
            unit_id: if given, restrict to this unit.
            unit_ids: if given, restrict to these units (SQL ``IN``); an empty
                list matches nothing. Mutually exclusive with ``unit_id``.
            limit / offset: passed through to the query builder for pagination.
            is_deleted: filter value for the joined mail list's ``is_deleted``
                column; by default only subscriptions of non-deleted lists match.

        Yields:
            ``UnitSubscription`` entities, with ``mail_list`` attached when the
            builder returns related mappers.

        Raises:
            ValueError: both ``unit_id`` and ``unit_ids`` were given.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and both arguments target the same 'unit_id' filter key.
        if unit_id is not None and unit_ids is not None:
            raise ValueError('unit_id and unit_ids are mutually exclusive')

        filters = Filters()
        filters.add_not_none('mail_list_id', mail_list_id)
        filters.add_not_none('unit_id', unit_id)
        filters.add_not_none('unit_id', unit_ids, lambda field: field.in_(unit_ids))
        filters['mail_list.is_deleted'] = is_deleted

        # Only org_id is fixed here; it is presumably matched against the first id field.
        query, mapper, rel_mappers = self._builder.select_related(
            id_values=(org_id,),
            filters=filters,
            offset=offset,
            limit=limit,
        )
        async for row in self._query(query):
            yield self._map_related(row, mapper, rel_mappers)

    async def get(self, org_id: int, mail_list_id: int, unit_id: int, for_update: bool = False) -> UnitSubscription:
        """Fetch a single subscription by its full composite id.

        Args:
            for_update: forwarded to the select builder (presumably adds ``FOR UPDATE``).

        Raises:
            UnitSubscriptionNotFound: passed as ``raise_`` to ``_query_one``;
                presumably raised when no row matches.
        """
        query, mapper = self._builder.select(
            id_values=(org_id, mail_list_id, unit_id),
            for_update=for_update,
        )
        return mapper(await self._query_one(query, raise_=UnitSubscriptionNotFound))

    async def save(self, unit_subscription: UnitSubscription) -> UnitSubscription:
        """Update the stored row from ``unit_subscription`` and return the updated entity.

        Side effect: ``updated`` on the passed entity is set to ``func.now()``.
        ``created`` is excluded from the UPDATE so the original creation time is kept.

        Raises:
            UnitSubscriptionNotFound: passed as ``raise_`` to ``_query_one``;
                presumably raised when no row was updated.
        """
        unit_subscription.updated = func.now()
        query, mapper = self._builder.update(unit_subscription, ignore_fields=('created',))
        return mapper(await self._query_one(query, raise_=UnitSubscriptionNotFound))
