from typing import Optional, Type

from sqlalchemy import update

from sendr_aiopg import BaseMapper, CRUDQueries
from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_utils import utcnow

from mail.ohio.ohio.core.entities.customer import Customer
from mail.ohio.ohio.storage.db.tables import customers as t_customers
from mail.ohio.ohio.storage.exceptions import CustomerNotFoundStorageError


class CustomerDataMapper(SelectableDataMapper):
    """Maps rows selected from the customers table onto ``Customer`` entities."""

    selectable = t_customers  # source of the rows being mapped
    entity_class = Customer   # entity type produced for each row


class CustomerDataDumper(TableDataDumper):
    """Dumps ``Customer`` entities into values for the customers table."""

    table = t_customers       # destination table for dumped values
    entity_class = Customer   # entity type accepted for dumping


class CustomerMapper(BaseMapper[Customer]):
    """Storage mapper for ``Customer`` rows in the customers table.

    Provides insert and select-by-uid, plus per-customer serial counters
    that are stored as numeric columns on the customer row itself.
    """

    _builder = CRUDQueries(
        base=t_customers,
        id_fields=('customer_uid',),
        mapper_cls=CustomerDataMapper,
        dumper_cls=CustomerDataDumper,
    )

    def _acquire_serial_query(self, customer_uid: int, column: str):
        """Build an UPDATE statement that increments ``column`` by one.

        The statement targets the row whose ``customer_uid`` matches. It also
        sets ``updated`` to the current time and returns every column of
        the updated row via ``RETURNING``.

        :param customer_uid: uid of the customer row to update.
        :param column: name of the counter column to increment.
        :returns: a SQLAlchemy ``UPDATE ... RETURNING`` statement.
        """
        counter = t_customers.c[column]
        values = {
            column: counter + 1,
            'updated': utcnow(),
        }
        return (
            update(t_customers).
            where(t_customers.c.customer_uid == customer_uid).
            values(**values).
            returning(*t_customers.c)
        )

    async def _acquire_serial(self,
                              customer_uid: int,
                              column: str,
                              raise_: Optional[Type[Exception]] = None,
                              ) -> int:
        """Increment ``column`` and return its value from before the increment.

        :param customer_uid: uid of the customer whose counter is advanced.
        :param column: name of the counter column.
        :param raise_: exception class forwarded to ``_query_one``. It is
            presumably raised there when no row matches; verify against
            sendr_aiopg.
        :returns: the counter value that was current before this call.
        :raises CustomerNotFoundStorageError: if no row was returned and
            ``_query_one`` did not raise on its own.
        """
        query = self._acquire_serial_query(customer_uid, column)
        row = await self._query_one(query, raise_=raise_)
        if row is None:
            # Without this guard a missing customer surfaced as a TypeError
            # from subscripting None.
            raise CustomerNotFoundStorageError
        # RETURNING yields the post-increment value; subtract 1 to hand out
        # the value that was current before this call.
        return row[column] - 1

    async def acquire_next_order_id(self, customer_uid: int) -> int:
        """Reserve and return the customer's next order id.

        :raises CustomerNotFoundStorageError: if the customer does not exist.
        """
        return await self._acquire_serial(
            customer_uid,
            'next_order_id',
            raise_=CustomerNotFoundStorageError,
        )

    async def create(self, obj: Customer) -> Customer:
        """Insert ``obj`` and return the stored customer as read back.

        Sets ``created`` and ``updated`` on ``obj`` itself, so the caller's
        entity is modified, both to the same current timestamp.
        """
        obj.created = obj.updated = utcnow()
        query, mapper = self._builder.insert(obj)
        return mapper(await self._query_one(query))

    async def get(self, customer_uid: int, for_update: bool = False) -> Customer:
        """Load a customer by uid.

        :param customer_uid: uid of the customer to fetch.
        :param for_update: forwarded to the query builder's ``select``
            (presumably adds ``FOR UPDATE`` row locking).
        :raises CustomerNotFoundStorageError: if no such customer exists.
        """
        query, mapper = self._builder.select(id_values=(customer_uid,), for_update=for_update)
        return mapper(await self._query_one(query, raise_=CustomerNotFoundStorageError))


class CustomerSerialMixin(BaseMapper):
    """Mixin that gives a mapper access to per-customer serial counters.

    On construction it builds a ``CustomerMapper`` that shares this mapper's
    connection and logger. Counter requests are delegated to it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._customer_mapper = CustomerMapper(
            connection=self.conn,
            logger=self._logger,
        )

    async def acquire_next_order_id(self, customer_uid: int) -> int:
        """Delegate to ``CustomerMapper.acquire_next_order_id``."""
        delegate = self._customer_mapper
        order_id = await delegate.acquire_next_order_id(customer_uid)
        return order_id
