from typing import Any, AsyncIterable, Iterable, Optional, Tuple

import psycopg2
from sqlalchemy import and_, exists, func, not_, select, update

from sendr_taskqueue.worker.storage import TaskState, get_task_mapper

from mail.beagle.beagle.core.entities.enums import TaskType
from mail.beagle.beagle.core.entities.task import Task
from mail.beagle.beagle.storage.db.tables import metadata
from mail.beagle.beagle.storage.db.tables import organizations as t_organizations
from mail.beagle.beagle.storage.db.tables import tasks as t_tasks
from mail.beagle.beagle.storage.exceptions import OrganizationNotFound

# Generic task mapper class built by sendr_taskqueue's factory, parameterized with
# this service's table metadata, TaskType enum, Task entity and tasks table.
# Annotated as Any presumably so type checkers accept subclassing a class that is
# created at runtime (TaskMapper below) — TODO confirm.
BaseMapper = get_task_mapper(metadata, TaskType, Task, t_tasks)  # type: Any


class TaskMapper(BaseMapper):
    """Task storage mapper built on top of the generic sendr_taskqueue mapper.

    Adds organization-aware helpers: per-type pending counts, soft deletion of
    duplicate pending tasks, and fetching work while skipping organizations
    that already have a task of the same type in PROCESSING state.
    """

    model = Task

    async def count_pending_by_type(self) -> AsyncIterable[Tuple[TaskType, int]]:
        """Yield ``(task_type, count)`` pairs for PENDING tasks, grouped by type.

        Task types with no pending tasks produce no group and are not yielded.
        """
        query = (
            select([t_tasks.c.task_type, func.count()]).
            select_from(t_tasks).
            where(t_tasks.c.state == TaskState.PENDING).
            group_by(t_tasks.c.task_type)
        )
        async for row in self._query(query):
            yield row[0], row[1]

    async def create(self, *args: Any, **kwargs: Any) -> Task:
        """Create a task through the base mapper.

        All arguments are forwarded unchanged to the base ``create``.

        Raises:
            OrganizationNotFound: the insert violated a foreign-key constraint.
                Presumably this is the ``org_id`` reference to organizations —
                TODO confirm the tasks table has no other foreign keys.
        """
        try:
            return await super().create(*args, **kwargs)
        except psycopg2.errors.ForeignKeyViolation as exc:
            # Translate the driver error into a domain error, keeping the
            # original exception as the explicit cause for debugging.
            raise OrganizationNotFound from exc

    async def delete_duplicates_by_org(self, task: Task) -> None:
        """Soft-delete other PENDING tasks that duplicate ``task``.

        A duplicate has the same ``task_type``, ``action_name`` and ``org_id``
        but a different ``task_id``. ``task`` itself is left untouched. Matching
        rows get ``state = DELETED``; no rows are physically removed.
        """
        where_clause = and_(
            t_tasks.c.state == TaskState.PENDING,
            t_tasks.c.task_type == task.task_type,
            t_tasks.c.action_name == task.action_name,
            t_tasks.c.org_id == task.org_id,
            t_tasks.c.task_id != task.task_id
        )

        query = update(t_tasks).where(where_clause).values(state=TaskState.DELETED)
        await self.conn.execute(query)

    async def get_for_work_by_org(self, task_types: Iterable[TaskType],
                                  action_names: Optional[Iterable[str]] = None) -> Task:
        """Lock and return the next runnable PENDING task.

        A task is eligible when all of the following hold:

        * its ``run_at`` is not in the future;
        * its type is in ``task_types``;
        * its action is in ``action_names``. ``None`` or an empty collection
          disables this filter.
        * its organization has no task of the same type in PROCESSING state.

        Tasks are taken in ``(run_at, task_id)`` order, one at a time.

        The chosen ``tasks`` row and its joined ``organizations`` row are locked
        with SKIP LOCKED. ``key_share=True`` without ``read=True`` makes SQLAlchemy
        render ``FOR NO KEY UPDATE`` on PostgreSQL. That lock mode conflicts with
        itself, so a concurrent worker skips tasks whose organization row another
        transaction has already locked.

        Raises:
            Task.DoesNotExist: presumably raised by ``_query_one`` when no
                eligible row is found — TODO confirm helper semantics.
        """
        # Second reference to the tasks table for the correlated NOT EXISTS.
        t_tasks_self = t_tasks.alias('tasks_self')

        from_clause = t_tasks.join(t_organizations, t_organizations.c.org_id == t_tasks.c.org_id)

        where_clause = and_(
            t_tasks.c.state == TaskState.PENDING,
            t_tasks.c.run_at <= func.now(),
            t_tasks.c.task_type.in_(task_types),
            # Skip organizations already processing a task of this type.
            not_(exists(
                select((1,)).select_from(t_tasks_self).where(
                    and_(
                        t_tasks_self.c.task_type == t_tasks.c.task_type,
                        t_tasks_self.c.org_id == t_tasks.c.org_id,
                        t_tasks_self.c.state == TaskState.PROCESSING,
                    )
                )
            ))
        )
        if action_names:
            where_clause = and_(where_clause, t_tasks.c.action_name.in_(action_names))

        mapper = self.mapper()
        query = (select(mapper.columns).
                 select_from(from_clause).
                 where(where_clause).
                 with_for_update(of=(t_tasks, t_organizations), skip_locked=True, key_share=True).
                 order_by(t_tasks.c.run_at, t_tasks.c.task_id).
                 limit(1))

        row = await self._query_one(query, raise_=self.model.DoesNotExist)
        return mapper(row)

    async def get_size(self) -> int:
        """Return the number of tasks currently in PENDING state."""
        query = select([func.count()]).select_from(t_tasks).where(t_tasks.c.state == TaskState.PENDING)
        db_result = await self.conn.execute(query)
        # COUNT(*) without GROUP BY always yields exactly one row.
        return await db_result.scalar()
