from typing import Any, Iterable, Optional, Type

import sqlalchemy as sa
from sqlalchemy import func

from sendr_aiopg import BaseMapperCRUD, CRUDQueries
from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import Filters
from sendr_taskqueue.worker.base.entites import BaseTaskType
from sendr_taskqueue.worker.storage.db.entities import Task, TaskState
from sendr_taskqueue.worker.storage.db.tables import get_tasks_table


def get_task_mapper(metadata: sa.MetaData,
                    task_type_cls: Type[BaseTaskType],
                    task_cls: Type[Task] = Task,
                    t_tasks: Optional[sa.Table] = None) -> Type[BaseMapperCRUD[Task]]:
    """Build a CRUD mapper class bound to a tasks table and a task entity class.

    Args:
        metadata: SQLAlchemy metadata handed to ``get_tasks_table`` when no
            table is supplied.
        task_type_cls: Task type class handed to ``get_tasks_table`` when no
            table is supplied.
        task_cls: Entity class used as the mapper's model and as the entity
            class of the row mapper and dumper.
        t_tasks: Ready-made tasks table. When ``None`` the table is obtained
            from ``get_tasks_table(metadata, task_type_cls)``.

    Returns:
        The ``TaskMapper`` class (not an instance).
    """
    if t_tasks is None:
        t_tasks = get_tasks_table(metadata, task_type_cls)

    # Row -> entity direction, bound to the tasks table.
    class TaskDataMapper(SelectableDataMapper):
        entity_class = task_cls
        selectable = t_tasks

    # Entity -> row direction, bound to the same table.
    class TaskDataDumper(TableDataDumper):
        entity_class = task_cls
        table = t_tasks

    class TaskMapper(BaseMapperCRUD[Task]):
        model = task_cls
        mapper = TaskDataMapper
        dumper = TaskDataDumper

        def __init__(self, *args, **kwargs):
            """Initialise the base mapper and a query builder keyed by ``task_id``."""
            super().__init__(*args, **kwargs)

            self._builder = CRUDQueries(
                t_tasks,
                id_fields=('task_id',),
                mapper_cls=self.mapper,
                dumper_cls=self.dumper,
            )

        async def create(self, *args: Any, **kwargs: Any) -> Task:
            """Create a task.

            Accepts either a single ready ``model`` instance as the only
            positional argument, or the positional/keyword arguments used to
            construct one. Note that keyword arguments are not used when a
            ready instance is passed.

            Returns:
                The result of the base ``create`` call.
            """
            if len(args) == 1 and isinstance(args[0], self.model):
                task = args[0]
            else:
                task = self.model(*args, **kwargs)

            # task_id is excluded from the written fields
            # (presumably so the database assigns it -- confirm against the table definition).
            return await super().create(task, ignore_fields=('task_id',))  # type: ignore

        async def save(self, task: Task) -> Task:  # type: ignore
            """Save ``task``, setting its ``updated`` field to SQL ``now()`` first.

            The passed entity is mutated: ``task.updated`` is replaced with a
            SQL function expression before delegating to the base ``save``.
            """
            task.updated = func.now()
            return await super().save(task)

        async def get_for_work(self,
                               task_types: Iterable[BaseTaskType],
                               task_states: Iterable[TaskState],
                               action_names: Optional[Iterable[str]] = None) -> Task:
            """Select one task that is due to run, ordered by ``run_at``.

            The query filters on ``task_type IN task_types``,
            ``state IN task_states`` and ``run_at <= now()``; when
            ``action_names`` is not ``None`` it also filters on
            ``action_name IN action_names``. It is built with ``limit=1``,
            ``for_update=True`` and ``skip_locked=True``.

            Args:
                task_types: Task types to pick from.
                task_states: Task states to pick from.
                action_names: Optional action names to restrict the search to.

            Returns:
                The mapped task row.

            Raises:
                ``self.model.DoesNotExist`` is passed to ``_query_one`` as
                ``raise_`` -- presumably raised when no row matches (confirm
                against ``_query_one``).
            """
            # Materialize every iterable up front: the annotations allow
            # one-shot iterators, and the filter callables below may need a
            # concrete sequence for IN (...).
            task_types = list(task_types)
            task_states = list(task_states)
            if action_names is not None:
                action_names = list(action_names)

            filters = Filters()
            filters.add_not_none('action_name', action_names, lambda field: field.in_(action_names))
            filters['task_type'] = lambda field: field.in_(task_types)
            filters['state'] = lambda field: field.in_(task_states)
            filters['run_at'] = lambda field: field <= func.now()

            query, mapper = self._builder.select(filters=filters,
                                                 limit=1,
                                                 order=('run_at',),
                                                 for_update=True,
                                                 skip_locked=True)

            return mapper(await self._query_one(query, raise_=self.model.DoesNotExist))

    return TaskMapper
