from typing import Any, AsyncIterable, Dict, List, Optional

from sqlalchemy import func

from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.beagle.beagle.core.entities.unit_unit import UnitUnit
from mail.beagle.beagle.storage.db.tables import unit_subscriptions as t_unit_subscriptions
from mail.beagle.beagle.storage.db.tables import unit_units as t_unit_units
from mail.beagle.beagle.storage.db.tables import units as t_units
from mail.beagle.beagle.storage.exceptions import UnitUnitNotFound
from mail.beagle.beagle.storage.mappers.base import BaseMapper
from mail.beagle.beagle.storage.mappers.unit import UnitDataMapper
from mail.beagle.beagle.storage.mappers.unit_subscription import UnitSubscriptionDataMapper


class UnitUnitDataMapper(SelectableDataMapper):
    """Maps rows selected from the ``unit_units`` table onto ``UnitUnit`` entities."""

    selectable = t_unit_units
    entity_class = UnitUnit


class UnitUnitDataDumper(TableDataDumper):
    """Dumps ``UnitUnit`` entities into values for the ``unit_units`` table."""

    table = t_unit_units
    entity_class = UnitUnit


class UnitUnitMapper(BaseMapper):
    """Storage mapper for ``UnitUnit`` rows (the ``unit_units`` table).

    Rows are identified by ``(org_id, unit_id, parent_unit_id)``; see
    ``_builder.id_fields``.
    """

    name = 'unit_unit'

    # Joins unit_units.(org_id, parent_unit_id) to unit_subscriptions.(org_id, unit_id);
    # used by `find` to filter on the parent unit's subscription mail list.
    _parent_unit_subscription_relation = RelationDescription(
        name='parent_unit_subscription',
        base=t_unit_units,
        related=t_unit_subscriptions,
        mapper_cls=UnitSubscriptionDataMapper,
        base_cols=('org_id', 'parent_unit_id'),
        related_cols=('org_id', 'unit_id'),
    )

    # Joins unit_units.(org_id, unit_id) to units.(org_id, unit_id); used by
    # `_map_related` to populate `UnitUnit.unit`.
    _unit_relation = RelationDescription(
        name='unit',
        base=t_unit_units,
        related=t_units,
        mapper_cls=UnitDataMapper,
        base_cols=('org_id', 'unit_id'),
        related_cols=('org_id', 'unit_id'),
    )

    _builder = CRUDQueries(
        base=t_unit_units,
        id_fields=('org_id', 'unit_id', 'parent_unit_id'),
        mapper_cls=UnitUnitDataMapper,
        dumper_cls=UnitUnitDataDumper,
        related=(_parent_unit_subscription_relation, _unit_relation),
    )

    @staticmethod
    def _map_related(row, mapper, rel_mappers):
        """Build a ``UnitUnit`` from ``row``.

        When ``rel_mappers`` is given (i.e. the query came from
        ``select_related``), also attach the joined unit as ``item.unit``.
        """
        item: UnitUnit = mapper(row)
        if rel_mappers:
            item.unit = rel_mappers['unit'](row)
        return item

    async def create(self, unit_unit: UnitUnit) -> UnitUnit:
        """Insert ``unit_unit`` and return the entity mapped from the inserted row.

        Sets ``created`` and ``updated`` on the passed entity to the DB-side ``now()``.
        """
        unit_unit.created = unit_unit.updated = func.now()
        query, mapper = self._builder.insert(unit_unit)
        return mapper(await self._query_one(query))

    async def delete(self, unit_unit: UnitUnit) -> None:
        """Delete the row for ``unit_unit``.

        The row is presumably matched on ``_builder.id_fields``; verify against
        CRUDQueries.delete. No ``raise_`` is passed here, unlike ``get`` and
        ``save``.
        """
        query = self._builder.delete(unit_unit)
        await self._query_one(query)

    async def find(self,
                   org_id: int,
                   unit_id: Optional[int] = None,
                   unit_ids: Optional[List[int]] = None,
                   parent_unit_id: Optional[int] = None,
                   order_by: Optional[str] = None,
                   parent_unit_subscription_mail_list_id: Optional[int] = None,
                   iterator: bool = False
                   ) -> AsyncIterable[UnitUnit]:
        """Yield ``UnitUnit`` rows of ``org_id`` matching the given filters.

        Each filter argument is applied only when it is not ``None``. When
        ``parent_unit_subscription_mail_list_id`` is given, the query is built
        with ``select_related``, so the subscription relation is available for
        filtering, and each yielded item also gets ``item.unit`` populated.

        :param order_by: single column name to order by, if any.
        :param iterator: forwarded to ``_query``.
        """
        filters = Filters()
        filters.add_not_none('unit_id', unit_id)
        filters.add_not_none('unit_id', unit_ids, lambda field: field.in_(unit_ids))
        filters.add_not_none('parent_unit_id', parent_unit_id)
        filters.add_not_none('parent_unit_subscription.mail_list_id', parent_unit_subscription_mail_list_id)

        kwargs: Dict[str, Any] = {
            'id_values': (org_id,),
            'filters': filters,
            'order': (order_by,) if order_by else None
        }

        rel_mappers = None
        # Must use the same not-None test as `add_not_none` above. A truthiness
        # check would skip the join for id 0 while still filtering on the
        # related `parent_unit_subscription` columns.
        if parent_unit_subscription_mail_list_id is not None:
            query, mapper, rel_mappers = self._builder.select_related(**kwargs)
        else:
            query, mapper = self._builder.select(**kwargs)

        async for row in self._query(query, iterator=iterator):
            yield self._map_related(row, mapper, rel_mappers)

    async def get(self, org_id: int, unit_id: int, parent_unit_id: int) -> UnitUnit:
        """Return the row identified by the full key.

        :raises UnitUnitNotFound: if no such row exists.
        """
        query, mapper = self._builder.select(id_values=(org_id, unit_id, parent_unit_id))
        return mapper(await self._query_one(query, raise_=UnitUnitNotFound))

    async def save(self, unit_unit: UnitUnit) -> UnitUnit:
        """Update the stored row from ``unit_unit`` and return the updated entity.

        Sets ``updated`` to the DB-side ``now()``. The ``created`` field is
        excluded from the update.

        :raises UnitUnitNotFound: if no row was updated.
        """
        unit_unit.updated = func.now()
        query, mapper = self._builder.update(unit_unit, ignore_fields=('created',))
        return mapper(await self._query_one(query, raise_=UnitUnitNotFound))
