import uuid
from datetime import datetime
from typing import Any, AsyncIterable, Dict, Optional

from sqlalchemy import delete

from sendr_aiopg.query_builder import CRUDQueries, Filters

from mail.payments.payments.core.entities.report import Report
from mail.payments.payments.storage.db.tables import reports as t_reports
from mail.payments.payments.storage.exceptions import ReportNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class ReportDataMapper(SelectableDataMapper):
    """Mapper for ``Report`` rows selected from ``t_reports``.

    Customises ``map_data`` so that the ``lower_dt`` and ``upper_dt`` entries
    of the data dict are converted from ISO-8601 strings into ``datetime``
    objects.
    """
    entity_class = Report
    selectable = t_reports

    def map_data(self, data: Optional[dict]) -> Optional[Dict[str, Any]]:
        """Return ``data`` with ``lower_dt``/``upper_dt`` parsed into datetimes.

        The input dict is left untouched; a new dict is returned.

        :param data: dict holding ISO-formatted ``lower_dt`` and ``upper_dt``
            strings plus any other keys, or a falsy value (``None`` / ``{}``).
        :return: new dict with both bounds as ``datetime`` objects followed by
            the remaining keys unchanged, or ``None`` when ``data`` is falsy.
        :raises KeyError: if ``lower_dt`` or ``upper_dt`` is missing.
        :raises ValueError: if a bound is not a valid ISO-format string.
        """
        if not data:
            return None

        # Pop from a shallow copy so the caller's dict keeps its keys.
        remaining = dict(data)
        lower_dt = datetime.fromisoformat(remaining.pop('lower_dt'))
        upper_dt = datetime.fromisoformat(remaining.pop('upper_dt'))

        return {
            'lower_dt': lower_dt,
            'upper_dt': upper_dt,
            **remaining,
        }


class ReportDataDumper(TableDataDumper):
    """Dumper for ``Report`` entities written to ``t_reports``.

    Customises ``dump_data`` so that the ``lower_dt`` and ``upper_dt`` entries
    of the data dict are converted into ISO-8601 strings via ``isoformat()``.
    """
    entity_class = Report
    selectable = t_reports

    def dump_data(self, data: Optional[dict]) -> Optional[Dict[str, Any]]:
        """Return ``data`` with ``lower_dt``/``upper_dt`` rendered as ISO strings.

        The input dict is left untouched; a new dict is returned, so the
        object that owns ``data`` does not lose its bounds as a side effect
        of being dumped.

        :param data: dict holding ``lower_dt`` and ``upper_dt`` values that
            expose ``isoformat()`` plus any other keys, or a falsy value
            (``None`` / ``{}``).
        :return: new dict with both bounds as strings followed by the
            remaining keys unchanged, or ``None`` when ``data`` is falsy.
        :raises KeyError: if ``lower_dt`` or ``upper_dt`` is missing.
        """
        if not data:
            return None

        # Pop from a shallow copy so the caller's dict keeps its keys.
        remaining = dict(data)
        lower_dt = remaining.pop('lower_dt').isoformat()
        upper_dt = remaining.pop('upper_dt').isoformat()

        return {
            'lower_dt': lower_dt,
            'upper_dt': upper_dt,
            **remaining,
        }


class ReportMapper(BaseMapper):
    """Storage mapper exposing read/create/update/delete operations on ``t_reports``."""

    name = 'report'
    _builder = CRUDQueries(
        base=t_reports,
        id_fields=('report_id',),
        dumper_cls=ReportDataDumper,
        mapper_cls=ReportDataMapper,
    )

    async def get(self, report_id: str, uid: Optional[int] = None, for_update: bool = False) -> Report:
        """Fetch a single report by ``report_id``.

        :param report_id: identifier of the report to select.
        :param uid: when not ``None``, added as an extra ``uid`` filter.
        :param for_update: forwarded to the query builder's ``select``
            (presumably requests a row lock -- confirm against ``CRUDQueries``).
        :return: the mapped ``Report``.

        ``ReportNotFound`` is passed to ``_query_one`` as ``raise_``; presumably
        it is raised when no row matches -- confirm against ``BaseMapper``.
        """
        conditions = Filters()
        conditions.add_not_none('uid', uid)
        query, to_entity = self._builder.select(
            id_values=(report_id,),
            filters=conditions,
            for_update=for_update,
        )
        row = await self._query_one(query, raise_=ReportNotFound)
        return to_entity(row)

    async def find(self, uid: Optional[int] = None, report_id: Optional[str] = None) -> AsyncIterable[Report]:
        """Yield reports matching the given optional filters.

        :param uid: when not ``None``, restricts results by ``uid``.
        :param report_id: when not ``None``, restricts results by ``report_id``.
        """
        conditions = Filters()
        conditions.add_not_none('uid', uid)
        conditions.add_not_none('report_id', report_id)
        query, to_entity = self._builder.select(filters=conditions)
        async for record in self._query(query):
            yield to_entity(record)

    async def create(self, report: Report) -> Report:
        """Insert ``report`` and return the entity mapped from the query result.

        A fresh UUID4 string is assigned to ``report.report_id`` before the
        insert, overwriting whatever value the passed object carried.
        """
        report.report_id = str(uuid.uuid4())
        query, to_entity = self._builder.insert(report)
        row = await self._query_one(query)
        return to_entity(row)

    async def save(self, obj: Report) -> Report:
        """Update the stored report from ``obj`` and return the mapped result.

        ``report_id``, ``uid``, ``created`` and ``data`` are passed to the
        builder as ``ignore_fields`` (presumably excluded from the update --
        confirm against ``CRUDQueries.update``).
        """
        skipped = ('report_id', 'uid', 'created', 'data')
        query, to_entity = self._builder.update(obj, ignore_fields=skipped)
        row = await self._query_one(query)
        return to_entity(row)

    async def delete_by_uid(self, uid: int) -> None:
        """Execute a DELETE on ``t_reports`` for rows whose ``uid`` equals ``uid``."""
        query = delete(t_reports).where(t_reports.c.uid == uid)
        await self.conn.execute(query)
