from datetime import datetime, timedelta
from typing import Any, AsyncIterable, Callable, List, Mapping, Optional, Sequence, Tuple, Union

from sqlalchemy import and_, collate, func, select, text

from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription, SelectableDataMapper

from mail.ipa.ipa.core.entities.collector import Collector
from mail.ipa.ipa.core.entities.user import User
from mail.ipa.ipa.storage.db.tables import collectors as t_collectors
from mail.ipa.ipa.storage.db.tables import users as t_users
from mail.ipa.ipa.storage.exceptions import CollectorNotFound
from mail.ipa.ipa.storage.mappers.base import BaseMapper
from mail.ipa.ipa.storage.mappers.collector.serialization import CollectorDataDumper, CollectorDataMapper
from mail.ipa.ipa.storage.mappers.user.serialization import UserDataMapper


class CollectorMapper(BaseMapper):
    """Storage mapper for the ``collectors`` table.

    Queries are produced by a ``CRUDQueries`` builder over ``t_collectors``.
    It has an optional ``user`` relation joining ``t_users`` on ``user_id``.
    """

    name = 'collector'

    # collectors.user_id -> users.user_id; joined user columns are mapped with
    # UserDataMapper and exposed to callers under the relation name 'user'.
    _user_relation = RelationDescription(
        name='user',
        base=t_collectors,
        related=t_users,
        base_cols=('user_id',),
        related_cols=('user_id',),
        mapper_cls=UserDataMapper,
    )
    _builder = CRUDQueries(
        base=t_collectors,
        id_fields=('collector_id',),
        mapper_cls=CollectorDataMapper,
        dumper_cls=CollectorDataDumper,
        related=(_user_relation,),
    )

    @staticmethod
    def map_related(row: Mapping,
                    mapper: Callable[[Mapping], Collector],
                    rel_mappers: Optional[Mapping[str, SelectableDataMapper]] = None,
                    ) -> Collector:
        """Map ``row`` to a Collector, attaching its User when a user mapper is given.

        :param row: result row, possibly containing joined user columns.
        :param mapper: callable producing a Collector from ``row``.
        :param rel_mappers: related mappers keyed by relation name; when it holds
            ``'user'``, the User mapped from the same row is set on ``collector.user``.
        :return: the mapped Collector.
        """
        collector = mapper(row)
        if rel_mappers and 'user' in rel_mappers:
            user: User = rel_mappers['user'](row)
            # Type narrowing for the checker, not input validation.
            assert isinstance(user, User)
            collector.user = user
        return collector

    @staticmethod
    def _all_of(conditions: Sequence[Callable[[Any], Any]]) -> Optional[Callable[[Any], Any]]:
        """Combine per-column conditions into one callable that ANDs them.

        Returns ``None`` for an empty sequence, so the result can be fed
        straight into ``Filters.add_not_none``, which then skips the column.
        """
        if not conditions:
            return None
        return lambda field: and_(*(condition(field) for condition in conditions))

    async def create(self, collector: Collector) -> Collector:
        """Insert ``collector`` and return the entity mapped from the inserted row.

        The passed entity is mutated: ``checked_at``, ``created_at`` and
        ``modified_at`` are set to the SQL ``now()`` expression, so the database
        supplies the timestamps. The builder's id fields are excluded from the INSERT.
        """
        collector.checked_at = collector.created_at = collector.modified_at = func.now()
        query, mapper = self._builder.insert(collector, ignore_fields=self._builder.id_fields)
        return mapper(await self._query_one(query))

    async def get(self, collector_id: int, with_user: bool = False) -> Collector:
        """Fetch one collector by ``collector_id``.

        :param with_user: when true, join the owning user and populate ``collector.user``.
        :raises CollectorNotFound: handed to ``_query_one`` as ``raise_``;
            presumably raised when no row matches.
        """
        rel_mappers: Optional[Mapping] = None
        if with_user:
            query, mapper, rel_mappers = self._builder.select_related(id_values=(collector_id,))
        else:
            query, mapper = self._builder.select(id_values=(collector_id,))

        return self.map_related(await self._query_one(query, raise_=CollectorNotFound), mapper, rel_mappers)

    async def delete(self, collector: Collector) -> None:
        """Delete ``collector``'s row.

        :raises CollectorNotFound: handed to ``_query_one`` as ``raise_``;
            presumably raised when no row was affected/returned.
        """
        query = self._builder.delete(collector)
        await self._query_one(query, raise_=CollectorNotFound)

    async def find_statuses(self, org_id: int) -> AsyncIterable[str]:
        """Yield each distinct collector status among users of ``org_id``, sorted by status."""
        filters = Filters()
        filters['user.org_id'] = org_id

        query, _, _ = self._builder.select_related(filters=filters)
        # Keep the joined FROM/WHERE, but select and group by status only.
        query = (
            query.
            with_only_columns([t_collectors.c.status]).
            group_by(t_collectors.c.status).
            order_by('status')
        )

        async for row in self._query(query):
            yield row[0]

    async def find(self,
                   org_id: Optional[int] = None,
                   user_id: Optional[int] = None,
                   status: Optional[str] = None,
                   login: Optional[str] = None,
                   ok_status: Optional[bool] = None,
                   order_by: Optional[Union[str, Tuple[str, ...]]] = None,
                   desc: bool = False,
                   limit: Optional[int] = None,
                   offset: Optional[int] = None,
                   created_at_from: Optional[datetime] = None,
                   created_at_to: Optional[datetime] = None,
                   ) -> AsyncIterable[Collector]:
        """Yield collectors (with their users attached) matching every given filter.

        Arguments left as ``None`` are not filtered on:
          * ``org_id`` - organisation of the owning user;
          * ``login`` - case-insensitive substring of the user's login
            (LIKE wildcards in ``login`` are escaped);
          * ``created_at_from`` / ``created_at_to`` - half-open range
            ``created_at_from <= created_at < created_at_to``;
          * ``ok_status`` - whether ``status`` equals ``Collector.OK_STATUS``;
          * ``status`` - exact status value;
          * ``user_id`` - owning user id.

        ``order_by`` is a column name or a tuple of names. ``desc`` only prefixes
        a single name with ``-``; a tuple is passed through unchanged. ``limit``
        and ``offset`` page the result.
        """
        filters = Filters()
        filters.add_not_none('user.org_id', org_id)
        filters.add_not_none(
            'user.login',
            login,
            lambda f: func.lower(collate(f, "C.UTF-8")).contains(login.lower(), autoescape=True)  # type: ignore
        )

        # Filters is keyed by column name (it supports ``filters[key] = value``),
        # so registering two entries under the same key can let the second
        # replace the first. Every column therefore gets ONE callable that ANDs
        # all of its conditions, so both bounds / both status checks are applied.
        created_at_conditions: List[Callable[[Any], Any]] = []
        if created_at_from is not None:
            created_at_conditions.append(lambda field: field >= created_at_from)
        if created_at_to is not None:
            created_at_conditions.append(lambda field: field < created_at_to)
        filters.add_not_none('created_at', self._all_of(created_at_conditions))

        status_conditions: List[Callable[[Any], Any]] = []
        if ok_status is not None:
            status_conditions.append(lambda field: (field == Collector.OK_STATUS) == ok_status)
        if status is not None:
            status_conditions.append(lambda field: field == status)
        filters.add_not_none('status', self._all_of(status_conditions))

        filters.add_not_none('user_id', user_id)

        order: Optional[Tuple[str, ...]]
        if order_by is None:
            order = None
        elif isinstance(order_by, tuple):
            order = order_by
        else:
            order = (f'-{order_by}' if desc else order_by,)

        query, mapper, rel_mappers = self._builder.select_related(
            filters=filters,
            order=order,
            limit=limit,
            offset=offset,
        )

        async for row in self._query(query):
            yield self.map_related(row, mapper, rel_mappers)

    async def save(self, collector: Collector, update_modified_at: bool = True) -> Collector:
        """Update ``collector``'s row and return the entity mapped from the result.

        :param update_modified_at: when true, ``collector.modified_at`` is set to
            the SQL ``now()`` expression before the update (mutating the argument).
        :raises CollectorNotFound: handed to ``_query_one`` as ``raise_``;
            presumably raised when no row matches.
        """
        if update_modified_at:
            collector.modified_at = func.now()
        query, mapper = self._builder.update(collector)
        return mapper(await self._query_one(query, raise_=CollectorNotFound))

    async def remove_user_collectors(self, user_id: int) -> None:
        """Delete every collector row belonging to ``user_id``; no error if none exist."""
        query = (
            t_collectors.
            delete().
            where(t_collectors.c.user_id == user_id)
        )
        await self.conn.execute(query)

    async def get_for_work(self, delay: timedelta) -> Collector:
        """Pick one collector due for processing.

        Candidates are enabled, have a non-NULL ``pop_id``, and have a
        ``checked_at`` bounded above by ``now() - delay``. The one with the
        smallest ``checked_at`` is returned. The query is built with
        ``for_update=True, skip_locked=True``, presumably
        ``SELECT ... FOR UPDATE SKIP LOCKED``, so concurrent workers skip rows
        already locked by someone else.

        :raises CollectorNotFound: handed to ``_query_one`` as ``raise_``;
            presumably raised when no candidate exists.
        """
        filters = Filters()
        filters.add_not_none('enabled', True)
        filters.add_not_none('pop_id', lambda pop_id: pop_id.isnot(None))
        # The interval literal is formatted from delay.total_seconds() (a float
        # derived from a timedelta), not from external text.
        filters.add_range('checked_at', to_=func.now().op('-')(text(f"interval '{delay.total_seconds()} seconds'")))
        query, mapper = self._builder.select(order=('checked_at',),
                                             filters=filters,
                                             for_update=True,
                                             skip_locked=True,
                                             limit=1)

        return mapper(await self._query_one(query, raise_=CollectorNotFound))

    async def get_min_checked_at(self) -> Optional[datetime]:
        """Return the earliest ``checked_at`` among enabled collectors with a ``pop_id``.

        SQL ``MIN`` over no rows is NULL, so this yields ``None`` when no such
        collector exists.
        """
        query = (
            select([func.min(t_collectors.c.checked_at)]).
            where(t_collectors.c.pop_id.isnot(None)).
            where(t_collectors.c.enabled).
            select_from(t_collectors)
        )
        return await self._query_scalar(query)
