from typing import Any, AsyncIterable, Iterable, Mapping, Optional, Tuple

from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert

from sendr_aiopg.query_builder import CRUDQueries, Filters
from sendr_taskqueue.worker.base.entites import BaseTaskParams
from sendr_utils import json_value

from mail.payments.payments.core.entities.enums import APICallbackSignMethod
from mail.payments.payments.core.entities.merchant import APICallbackParams
from mail.payments.payments.core.entities.task import Task, TaskState, TaskType
from mail.payments.payments.storage.db.tables import tasks as t_tasks
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class TaskDataMapper(SelectableDataMapper):
    """Selectable data mapper tying the ``Task`` entity to the ``tasks`` table.

    Passed as ``mapper_cls`` to the ``CRUDQueries`` builder of ``TaskMapper``.
    """

    selectable = t_tasks
    entity_class = Task


class TaskDataDumper(TableDataDumper):
    """Table data dumper tying the ``Task`` entity to the ``tasks`` table.

    Passed as ``dumper_cls`` to the ``CRUDQueries`` builder of ``TaskMapper``.
    """

    table = t_tasks
    entity_class = Task


class TaskMapper(BaseMapper):
    """Storage mapper for the ``tasks`` table.

    Converts rows to and from ``Task`` entities and implements the task-queue
    queries: counters of pending and failed tasks, and fetching a single task
    for a worker under a row lock.
    """

    name = 'task'
    model = Task
    _builder = CRUDQueries(
        t_tasks,
        id_fields=('task_id',),
        mapper_cls=TaskDataMapper,
        dumper_cls=TaskDataDumper,
    )

    @staticmethod
    def map(row: Mapping) -> Task:
        """Build a ``Task`` from a ``tasks`` row.

        If ``params`` contains a non-None ``callback_params`` value, it is
        converted into ``APICallbackParams``. ``sign_method`` defaults to
        ``APICallbackSignMethod.ASYMMETRIC`` when absent.

        The conversion is done on a shallow copy of ``params``. This leaves
        ``row`` untouched, so mapping the same row twice is safe.
        """
        params = row['params']
        callback_params = params.get('callback_params') if params is not None else None
        if callback_params is not None:
            # Copy instead of mutating the row in place. A second map() of the
            # same row would otherwise find an APICallbackParams object where a
            # dict is expected and fail on .get().
            params = dict(params)
            params['callback_params'] = APICallbackParams(
                sign_method=APICallbackSignMethod(
                    callback_params.get('sign_method', APICallbackSignMethod.ASYMMETRIC)
                ),
                secret=callback_params.get('secret'),
            )

        return Task(
            task_type=row['task_type'],
            task_id=row['task_id'],
            params=params,

            action_name=row['action_name'],
            state=row['state'],
            retries=row['retries'],
            details=row['details'],

            run_at=row['run_at'],
            created=row['created'],
            updated=row['updated'],
        )

    @staticmethod
    def unmap(obj: Task) -> dict:
        """Convert a ``Task`` into column values for INSERT/UPDATE.

        ``task_id``, ``created`` and ``updated`` are not included. The callers
        (``create``/``save``) set the timestamps themselves.

        ``BaseTaskParams`` params are serialized with ``asdict()``. Anything
        else goes through ``json_value``, with a falsy value replaced by ``{}``.
        """
        return {
            'task_type': obj.task_type,
            'params': (
                obj.params.asdict() if isinstance(obj.params, BaseTaskParams)
                else json_value(obj.params or {})
            ),
            'action_name': obj.action_name,
            'state': obj.state,
            'retries': obj.retries,
            'details': obj.details,
            'run_at': obj.run_at,
        }

    async def _count_pending_by(self, column: Any) -> AsyncIterable[Tuple[Any, int]]:
        """Yield ``(value, count)`` pairs of PENDING tasks grouped by ``column``."""
        query = (
            select([column, func.count()]).
            select_from(t_tasks).
            where(t_tasks.c.state == TaskState.PENDING).
            group_by(column)
        )
        async for row in self._query(query):
            yield row[0], row[1]

    async def count_pending_by_type(self) -> AsyncIterable[Tuple[TaskType, int]]:
        """Yield ``(task_type, count)`` for every task type that has PENDING tasks."""
        async for task_type, count in self._count_pending_by(t_tasks.c.task_type):
            yield task_type, count

    async def count_pending_by_retries(self) -> AsyncIterable[Tuple[int, int]]:
        """Yield ``(retries, count)`` for every retry count among PENDING tasks."""
        async for retries, count in self._count_pending_by(t_tasks.c.retries):
            yield retries, count

    async def count_failed_tasks(self, task_type: TaskType, action_name: Optional[str] = None) -> int:
        """Count FAILED tasks of ``task_type`` with the given ``action_name``.

        Note: ``action_name=None`` does NOT mean "any action". SQLAlchemy renders
        ``column == None`` as ``IS NULL``, so it counts only the tasks whose
        ``action_name`` is NULL.
        """
        query = (
            select([func.count()]).
            select_from(t_tasks).
            where(t_tasks.c.state == TaskState.FAILED).
            where(t_tasks.c.task_type == task_type).
            where(t_tasks.c.action_name == action_name)
        )
        row = await self._query_one(query)
        return row[0]

    async def create(self, *args: Any, **kwargs: Any) -> Task:
        """Insert a task and return it as stored.

        Accepts either a single ``Task`` instance, or the arguments for building
        one via ``Task(*args, **kwargs)``. ``created`` and ``updated`` are set to
        the database's ``now()``.
        """
        if len(args) == 1 and isinstance(args[0], self.model):
            obj = args[0]
        else:
            obj = self.model(*args, **kwargs)

        query = (
            insert(t_tasks).
            values(
                **self.unmap(obj),
                created=func.now(),
                updated=func.now(),
            ).
            returning(*t_tasks.c)
        )
        return self.map(await self._query_one(query))

    async def find(self, task_type: Optional[TaskType] = None) -> AsyncIterable[Task]:
        """Yield tasks, optionally filtered by ``task_type`` (None means no filter).

        NOTE(review): rows are mapped by ``TaskDataMapper`` (via ``_builder``),
        not by ``map()``. ``callback_params`` may therefore not be converted to
        ``APICallbackParams`` here — confirm this is intended.
        """
        filters = Filters()
        filters.add_not_none('task_type', task_type)

        query, mapper = self._builder.select(filters=filters)
        async for row in self._query(query):
            yield mapper(row)

    async def get(self, task_id: int, raise_: bool = True) -> Optional[Task]:
        """Fetch a task by id.

        :param raise_: when true, a missing task raises ``Task.DoesNotExist``.
            When false, a missing task yields ``None``.
        """
        query = (
            select([t_tasks]).
            where(t_tasks.c.task_id == task_id)
        )
        row = await self._query_one(query, raise_=raise_ and Task.DoesNotExist)
        if row is None:
            # Reached only when raise_ is falsy: _query_one did not raise, and
            # (presumably) returned no row. Previously map(None) crashed here
            # with a TypeError.
            return None
        return self.map(row)

    async def get_for_work(self,
                           task_types: Iterable[TaskType],
                           task_states: Iterable[TaskState],
                           action_names: Optional[Iterable[str]] = None) -> Task:
        """Lock and return one task that is ready to run.

        A candidate task must satisfy all of the following:
        - its type is in ``task_types``;
        - its state is in ``task_states``;
        - its ``run_at`` is not in the future;
        - if ``action_names`` is non-empty, its action is in ``action_names``.

        ``None`` and an empty ``action_names`` both mean "any action". The
        least recently updated candidate is picked. Rows locked by other
        transactions are skipped (``SKIP LOCKED``). On PostgreSQL the lock is
        rendered as ``FOR NO KEY UPDATE``.

        :raises Task.DoesNotExist: when no task matches.
        """
        # Materialize first: action_names may be a one-shot iterator.
        _action_names = list(action_names) if action_names is not None else []
        query = (
            select([t_tasks]).
            where(t_tasks.c.task_type.in_(task_types)).
            where(t_tasks.c.state.in_(task_states)).
            where(t_tasks.c.run_at <= func.now())
        )
        if _action_names:
            query = query.where(t_tasks.c.action_name.in_(_action_names))
        query = (
            query.
            order_by(t_tasks.c.updated).
            limit(1).
            with_for_update(skip_locked=True, key_share=True)
        )
        return self.map(await self._query_one(query, raise_=Task.DoesNotExist))

    async def save(self, obj: Task) -> Task:
        """Update the task identified by ``obj.task_id`` and return the stored row.

        ``updated`` is set to the database's ``now()``.

        :raises Task.DoesNotExist: when no row has that ``task_id``.
        """
        query = (
            update(t_tasks).
            where(t_tasks.c.task_id == obj.task_id).
            values(updated=func.now(), **self.unmap(obj)).
            returning(*t_tasks.c)
        )
        return self.map(await self._query_one(query, raise_=Task.DoesNotExist))
