from datetime import datetime, timedelta
from typing import Any, AsyncIterable, Iterable, Optional, Type

import sqlalchemy as sa
from sqlalchemy import func, update

from sendr_aiopg import BaseMapperCRUD, CRUDQueries
from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import Filters
from sendr_taskqueue.worker.base.entites import BaseWorkerType
from sendr_taskqueue.worker.storage.db.entities import Worker, WorkerState
from sendr_taskqueue.worker.storage.db.tables import get_workers_table


def get_worker_mapper(metadata: sa.MetaData,
                      worker_type_cls: Type[BaseWorkerType],
                      worker_cls: Type[Worker] = Worker,
                      t_workers: Optional[sa.Table] = None) -> Type[BaseMapperCRUD[Worker]]:
    """Build a CRUD mapper class bound to a workers table.

    :param metadata: metadata used to create the workers table when
        ``t_workers`` is not supplied; its ``schema`` is also used to
        qualify the table name in ``delete_cleaned_up_workers``.
    :param worker_type_cls: worker type class forwarded to
        ``get_workers_table`` when the table has to be created.
    :param worker_cls: entity class the mapper reads into and dumps from.
    :param t_workers: ready-made workers table; when ``None`` a table is
        obtained from ``get_workers_table(metadata, worker_type_cls)``.
    :return: the ``WorkerMapper`` class (not an instance).
    """
    if t_workers is None:
        t_workers = get_workers_table(metadata, worker_type_cls)

    class WorkerDataMapper(SelectableDataMapper):
        # Maps rows selected from ``t_workers`` to ``worker_cls`` entities.
        entity_class = worker_cls
        selectable = t_workers

    class WorkerDataDumper(TableDataDumper):
        # Dumps ``worker_cls`` entities into ``t_workers`` columns.
        entity_class = worker_cls
        table = t_workers

    class WorkerMapper(BaseMapperCRUD[Worker]):
        """CRUD mapper for workers, keyed by ``worker_id``."""

        model = worker_cls
        _builder = CRUDQueries(
            t_workers,
            id_fields=('worker_id',),
            mapper_cls=WorkerDataMapper,
            dumper_cls=WorkerDataDumper,
        )

        async def create(self, item: Worker, *args: Any, **kwargs: Any) -> Worker:
            """Create ``item``, overriding its ``startup`` with SQL ``now()``.

            Note that ``item.startup`` is overwritten on the passed object
            before it is handed to the base ``create``.
            """
            item.startup = func.now()
            return await super().create(item, *args, **kwargs)  # type: ignore

        async def find(self, *,  # type: ignore
                       state: Optional[WorkerState] = None,
                       states: Optional[Iterable[WorkerState]] = None,
                       beat_before: Optional[datetime] = None,
                       beat_after: Optional[datetime] = None,
                       for_update: bool = False,
                       limit: Optional[int] = None,
                       host: Optional[str] = None,
                       ) -> AsyncIterable[Worker]:
            """Yield workers matching the given filters.

            Exactly one of ``state`` / ``states`` must be provided (truthy).

            :param state: match workers in this single state.
            :param states: match workers whose state is one of these.
            :param beat_before: keep workers with ``heartbeat <= beat_before``.
            :param beat_after: keep workers with ``heartbeat >= beat_after``.
            :param for_update: forwarded to the select builder; the query is
                always built with ``skip_locked=True``.
            :param limit: maximum number of rows, or ``None`` for no limit.
            :param host: match workers on this host.
            :raises AssertionError: when neither or both of ``state`` and
                ``states`` are given. Because this is an async generator the
                error surfaces on the first iteration step.
            """
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; without it the query would run
            # with no state filter at all.
            if not any((state, states)) or all((state, states)):
                raise AssertionError('must be define state or states')

            filters = Filters()
            filters.add_not_none('state', state)
            filters.add_not_none('state', states, lambda field: field.in_(states))
            filters.add_not_none('heartbeat', beat_after, lambda field: beat_after <= field)
            filters.add_not_none('heartbeat', beat_before, lambda field: field <= beat_before)
            filters.add_not_none('host', host)

            query, mapper = self._builder.select(
                filters=filters,
                limit=limit,
                for_update=for_update,
                skip_locked=True
            )

            async for row in self._query(query):
                yield mapper(row)

        async def heartbeat(self, worker_id: str) -> str:
            """Set ``heartbeat`` to SQL ``now()`` for a RUNNING worker.

            Only a row with the given ``worker_id`` whose state is
            ``WorkerState.RUNNING`` is updated.

            :return: the ``worker_id`` returned by the UPDATE statement.
            """
            # Narrows the Optional closure variable for type checkers; the
            # enclosing factory has already ensured the table is set.
            assert t_workers is not None

            query = (
                update(t_workers).
                where(t_workers.c.worker_id == worker_id).
                where(t_workers.c.state == WorkerState.RUNNING).
                values(heartbeat=func.now()).
                returning(t_workers.c.worker_id)
            )
            # NOTE(review): when no row matches (unknown id or worker not
            # RUNNING) the outcome depends on ``_query_one`` - presumably it
            # raises; confirm against sendr_aiopg.
            return (await self._query_one(query))['worker_id']

        async def delete_cleaned_up_workers(self, batch_size: int, period_offset: timedelta) -> int:
            """Delete a batch of stale, cleaned-up workers.

            Removes at most ``batch_size`` rows whose state is ``'cleanedup'``,
            whose ``task_id`` is null and whose ``heartbeat`` is earlier than
            the date (time part dropped) of ``utcnow() - period_offset``.

            :param batch_size: maximum number of rows to delete in this call.
            :param period_offset: how far back from the current UTC time the
                cutoff date lies.
            :return: number of rows deleted, ``0`` if the query yields no row.
            :raises ValueError: if ``batch_size`` cannot be converted to int.
            """
            before_date = (datetime.utcnow() - period_offset).date().isoformat()
            # Coerce so that only a numeric literal can reach the LIMIT clause
            # of the string-built statement below.
            row_limit = int(batch_size)
            # ``MetaData.schema`` is None unless configured; fall back to the
            # unqualified name instead of rendering the literal "None.workers".
            schema = metadata.schema
            table_name = f"{schema}.workers" if schema else "workers"
            # NOTE(review): the table name is hard-coded as ``workers`` rather
            # than taken from ``t_workers``; confirm that custom tables passed
            # to the factory use this name.
            result = await self.conn.execute(
                "WITH deleted_result AS "
                f"(DELETE FROM {table_name} w "
                f"WHERE ctid IN (SELECT ctid FROM {table_name} w "
                f"WHERE w.state = 'cleanedup' and w.heartbeat < '{before_date}' and w.task_id is null "
                f"LIMIT {row_limit}) RETURNING *) "
                "SELECT count(*) FROM deleted_result;"
            )
            # The statement yields a single count row; return its only column.
            async for row in result:
                return row[0]
            return 0

    return WorkerMapper
