from typing import AsyncIterable, Iterable, Optional

from sqlalchemy import collate

from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.payments.payments.core.entities.service import Service
from mail.payments.payments.storage.db.tables import service_clients as t_service_clients
from mail.payments.payments.storage.db.tables import service_merchants as t_service_merchants
from mail.payments.payments.storage.db.tables import services as t_services
from mail.payments.payments.storage.exceptions import ServiceNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.storage.mappers.service.serialization import (
    ServiceClientDataMapper, ServiceDataDumper, ServiceDataMapper, ServiceMerchantDataMapper
)
from mail.payments.payments.utils.datetime import utcnow
from mail.payments.payments.utils.helpers import any_not_none


class ServiceMapper(BaseMapper):
    """Storage mapper for `Service` entities backed by the `services` table.

    Besides plain CRUD on `services`, it can join the related
    `service_clients` and `service_merchants` tables (both joined on
    `service_id`) to look services up by client or merchant attributes.
    """

    name = 'service'

    # services JOIN service_clients ON services.service_id = service_clients.service_id
    _service_client_relation = RelationDescription(
        name='service_client',
        base=t_services,
        related=t_service_clients,
        mapper_cls=ServiceClientDataMapper,
        base_cols=('service_id',),
        related_cols=('service_id',),
    )
    # services JOIN service_merchants ON services.service_id = service_merchants.service_id
    _service_merchant_relation = RelationDescription(
        name='service_merchant',
        base=t_services,
        related=t_service_merchants,
        mapper_cls=ServiceMerchantDataMapper,
        base_cols=('service_id',),
        related_cols=('service_id',),
    )
    # Arguments shared by all builders below; they differ only in which relations they join.
    _builder_kwargs = dict(
        base=t_services,
        id_fields=('service_id',),
        mapper_cls=ServiceDataMapper,
        dumper_cls=ServiceDataDumper,
    )
    # Builder joining both client and merchant relations.
    _builder = CRUDQueries(
        **_builder_kwargs,
        related=(_service_client_relation, _service_merchant_relation),
    )
    # Builder joining only the service_clients relation.
    _builder_client = CRUDQueries(
        **_builder_kwargs,
        related=(_service_client_relation,),
    )
    # Builder joining only the service_merchants relation.
    _builder_merchant = CRUDQueries(
        **_builder_kwargs,
        related=(_service_merchant_relation,),
    )

    @staticmethod
    def _map_related(row, mapper, rel_mappers):
        """Build a `Service` from a joined row and attach whichever related entities were joined.

        :param row: result row of a `select_related` query.
        :param mapper: callable mapping the row to a `Service`.
        :param rel_mappers: dict of relation name -> callable mapping the row to the related entity;
            only the keys 'service_client' and 'service_merchant' are looked at.
        :return: the `Service`, with `service_client` / `service_merchant` set when present in `rel_mappers`.
        """
        service = mapper(row)
        if 'service_client' in rel_mappers:
            service.service_client = rel_mappers['service_client'](row)
        if 'service_merchant' in rel_mappers:
            service.service_merchant = rel_mappers['service_merchant'](row)
        return service

    async def create(self, obj: Service) -> Service:
        """Insert `obj` into `services` and return the stored row as a `Service`.

        `service_id`, `created` and `updated` are excluded from the INSERT
        (presumably filled in by the database — verify against the table definition).
        """
        query, mapper = self._builder.insert(obj, ignore_fields=('service_id', 'created', 'updated'))
        return mapper(await self._query_one(query))

    async def get(self, service_id: int, for_update: bool = False) -> Service:
        """Fetch a service by primary key.

        :param service_id: id of the service to load.
        :param for_update: passed to the query builder to lock the selected row (SELECT ... FOR UPDATE).
        :raises ServiceNotFound: when no row matches (via `_query_one(raise_=...)`).
        """
        query, mapper = self._builder.select(id_values=(service_id,), for_update=for_update)
        return mapper(await self._query_one(query, raise_=ServiceNotFound))

    async def get_by_related(self,
                             service_client_id: Optional[int] = None,
                             service_client_tvm_id: Optional[int] = None,
                             service_merchant_id: Optional[int] = None,
                             ) -> Service:
        """Fetch a service through its client and/or merchant relation.

        Only the relations needed by the given filters are joined, and the
        matching related entities are attached to the returned `Service`.

        :param service_client_id: filter on `service_client.service_client_id`.
        :param service_client_tvm_id: filter on `service_client.tvm_id`.
        :param service_merchant_id: filter on `service_merchant.service_merchant_id`; additionally
            requires `service_merchant.deleted` to be False.
        :raises RuntimeError: when none of the filters is provided.
        :raises ServiceNotFound: when no row matches (via `_query_one(raise_=...)`).
        """
        load_client = any_not_none(service_client_id, service_client_tvm_id)
        load_merchant = service_merchant_id is not None
        # Pick the builder that joins exactly the relations the filters refer to.
        if load_client and load_merchant:
            builder = self._builder
        elif load_client:
            builder = self._builder_client
        elif load_merchant:
            builder = self._builder_merchant
        else:
            raise RuntimeError('No filter condition provided')

        filters = Filters()
        filters.add_not_none('service_client.tvm_id', service_client_tvm_id)
        filters.add_not_none('service_client.service_client_id', service_client_id)
        if service_merchant_id is not None:
            filters['service_merchant.service_merchant_id'] = service_merchant_id
            filters['service_merchant.deleted'] = False

        query, mapper, rel_mappers = builder.select_related(filters=filters)

        row = await self._query_one(query, raise_=ServiceNotFound)
        return self._map_related(row, mapper, rel_mappers)

    async def find(
        self,
        hidden: Optional[bool] = None,
        service_ids: Optional[Iterable[int]] = None,
    ) -> AsyncIterable[Service]:
        """Yield services matching the given optional filters.

        :param hidden: when not None, only services whose `hidden` equals this value.
        :param service_ids: when not None, only services whose `service_id` is in this collection
            (an empty collection matches nothing).
        """
        # Materialize once so that a one-shot iterator (e.g. a generator) is not
        # consumed lazily inside the filter condition.
        ids = None if service_ids is None else list(service_ids)
        filters = Filters()
        filters.add_not_none('hidden', hidden)
        filters.add_not_none('service_id', ids, lambda column: column.in_(ids))
        query, mapper = self._builder.select(filters=filters)
        async for row in self._query(query):
            yield mapper(row)

    async def save(self, obj: Service) -> Service:
        """Persist changes of `obj`, refreshing its `updated` timestamp, and return the stored row.

        Note: mutates `obj.updated` in place before issuing the UPDATE.
        """
        obj.updated = utcnow()
        # One-element tuple: without the trailing comma this would be the bare string 'service_id'.
        query, mapper = self._builder.update(obj, ignore_fields=('service_id',))
        return mapper(await self._query_one(query))

    async def find_by_service_merchants(self, service_merchant_uid: int) -> AsyncIterable[Service]:
        """Yield services linked to at least one service merchant with the given `uid`.

        Only the `services` columns are selected and grouped on, so a service
        reachable through several merchant rows should be yielded once
        (NOTE(review): assumed purpose of the GROUP BY — confirm). Results are
        ordered by service name using the 'C.UTF-8' collation.

        NOTE(review): unlike `get_by_related`, this does not filter on
        `service_merchant.deleted` — confirm that soft-deleted merchants should count.
        """
        filters = Filters()
        filters['service_merchant.uid'] = service_merchant_uid
        query, mapper, _ = self._builder_merchant.select_related(filters=filters)
        query = query \
            .with_only_columns(mapper.columns) \
            .group_by(*mapper.columns) \
            .order_by(collate(t_services.c.name, 'C.UTF-8'))
        async for row in self._query(query):
            yield mapper(row)
