from datetime import timedelta
from typing import AsyncIterable, Iterable, Optional, Tuple

from sqlalchemy import Table, and_, exists, func, not_, select, text

from sendr_taskqueue.worker.storage.db.mappers.task import get_task_mapper

from mail.ipa.ipa.core.entities.enums import TaskState, TaskType
from mail.ipa.ipa.core.entities.task import Task
from mail.ipa.ipa.storage.db.tables import metadata
from mail.ipa.ipa.storage.db.tables import organizations as t_organizations
from mail.ipa.ipa.storage.db.tables import tasks as t_tasks
from mail.ipa.ipa.storage.db.tables import users as t_users

BaseTaskMapper = get_task_mapper(metadata=metadata,
                                 task_type_cls=TaskType,
                                 task_cls=Task,
                                 t_tasks=t_tasks
                                 )


class TaskMapper(BaseTaskMapper):  # type: ignore
    async def _get_entity_task_for_work(self,
                                        for_organization: bool,
                                        task_types: Iterable[TaskType],
                                        same_metatask_processing_limit: Optional[int] = None
                                        ) -> Task:
        if for_organization:
            entity_id_column = t_organizations.c.org_id
            t_entity = t_organizations
        else:
            entity_id_column = t_users.c.user_id
            t_entity = t_users

        t_tasks_self = t_tasks.alias('tasks_self')

        from_clause = t_tasks.join(t_entity, entity_id_column == t_tasks.c.entity_id)
        with_for_update_tables: Tuple[Table, ...] = (t_tasks, t_entity)

        # Это условие обрабатывает два случая
        # 1. Есть задача для этой же сущности, которое находится в processing
        # 2. Есть pending задача, у которой task_id меньше, чем у текущей. Такое может быть,
        # если run_at у этой задачи находится в будущем времени (т.е. задача ушла на ретрай)
        entity_has_precending_unfinished_task = exists(
            select((1,)).select_from(t_tasks_self).where(
                and_(
                    t_tasks_self.c.task_type.in_(task_types),
                    t_tasks_self.c.entity_id == t_tasks.c.entity_id,
                    t_tasks_self.c.state.in_(Task.NONTERMINAL_TASK_STATES),
                    t_tasks_self.c.task_id < t_tasks.c.task_id
                ),
            )
        )

        where_clause = and_(
            t_tasks.c.state == TaskState.PENDING,
            t_tasks.c.run_at <= func.now(),
            t_tasks.c.task_type.in_(task_types),
            not_(entity_has_precending_unfinished_task)
        )
        if for_organization:
            t_org_tasks = t_tasks.alias('org_tasks')
            t_subtasks = t_tasks.alias('subtasks')

            # Мета-задачу нельзя брать, если у этой же организации существует задача, которая
            # ещё не завершилась полностью
            # Что такое "завершилась полностью":
            # 1. Она находится в терминальном состоянии
            # 2. Все её дети находятся в терминальном состоянии
            precending_metatask_not_fully_finished = exists(
                select((1,)).select_from(
                    t_org_tasks.join(t_subtasks, t_org_tasks.c.task_id == t_subtasks.c.meta_task_id)
                ).where(
                    and_(
                        t_org_tasks.c.entity_id == t_tasks.c.entity_id,
                        t_org_tasks.c.task_id < t_tasks.c.task_id,
                        t_subtasks.c.state.in_(Task.NONTERMINAL_TASK_STATES),
                    )
                )
            )

            where_clause.append(
                not_(precending_metatask_not_fully_finished)
            )
        elif same_metatask_processing_limit is not None:
            # У нас есть потребность ограничить количество одновременно исполняемых задач per-metatask.
            # Потому что директория не справляется с наплывом запросов.
            # Достигается это так:
            # В момент взятия for_user задачи мы считаем, сколько у неё "братьев", которые находятся в processing.
            # "Брат" = потомок одной и той же metatask
            # Чтобы подсчёт был надёжным, нужно брать lock на метазадачу.
            # Иначе получится так: сразу много воркеров пошли за задачами, "подсчитали" что никакой брат не processing,
            # и все успешно взяли по задаче
            t_meta_task = t_tasks.alias('metatask')
            from_clause = from_clause.join(t_meta_task, t_meta_task.c.task_id == t_tasks.c.meta_task_id)
            with_for_update_tables += (t_meta_task, )

            t_sibling_tasks = t_tasks.alias('sibling_tasks')

            same_metatask_processing_count = select((func.count(),)).select_from(
                t_sibling_tasks,
            ).where(
                and_(
                    t_sibling_tasks.c.meta_task_id == t_tasks.c.meta_task_id,
                    t_sibling_tasks.c.state == TaskState.PROCESSING,
                )
            ).as_scalar()

            where_clause.append(same_metatask_processing_limit > same_metatask_processing_count)

        mapper = self.mapper()
        query = (select(mapper.columns).
                 select_from(from_clause).
                 where(where_clause).
                 with_for_update(of=with_for_update_tables, skip_locked=True, key_share=True).
                 order_by(t_tasks.c.task_id).
                 limit(1))

        row = await self._query_one(query, raise_=self.model.DoesNotExist)
        return mapper(row)

    async def get_org_task_for_work(self, task_types: Iterable[TaskType]) -> Task:
        return await self._get_entity_task_for_work(for_organization=True, task_types=task_types)

    async def get_user_task_for_work(self,
                                     task_types: Iterable[TaskType],
                                     same_metatask_processing_limit: Optional[int] = None) -> Task:
        return await self._get_entity_task_for_work(for_organization=False,
                                                    task_types=task_types,
                                                    same_metatask_processing_limit=same_metatask_processing_limit)

    async def get_org_tasks(self,
                            org_id: int,
                            task_types: Iterable[TaskType],
                            offset: int,
                            limit: int,
                            ) -> AsyncIterable[Task]:
        NONTERMINAL_SUBTASKS_LABEL: str = 'nonterminal_tasks'
        FAILED_SUBTASKS_LABEL: str = 'failed_tasks'

        t_subtasks = t_tasks.alias('subtasks')

        nonterminal_children = (select((func.count(),)).
                                select_from(t_subtasks).
                                where(and_(
                                    t_subtasks.c.meta_task_id == t_tasks.c.task_id,
                                    t_subtasks.c.state.in_(Task.NONTERMINAL_TASK_STATES),
                                )).
                                as_scalar())

        failed_children = (select((func.count(),)).
                           select_from(t_subtasks).
                           where(and_(
                               t_subtasks.c.meta_task_id == t_tasks.c.task_id,
                               t_subtasks.c.state == TaskState.FAILED,
                           )).
                           as_scalar())

        mapper = self.mapper()

        where_clause = and_(
            t_tasks.c.task_type.in_(task_types),
            t_tasks.c.entity_id == org_id,
        )

        columns = mapper.columns + (
            nonterminal_children.label(NONTERMINAL_SUBTASKS_LABEL),
            failed_children.label(FAILED_SUBTASKS_LABEL),
        )
        query = (select(columns).
                 select_from(t_tasks).
                 where(where_clause).
                 order_by(t_tasks.c.created.desc(), t_tasks.c.task_id.desc()).
                 offset(offset).
                 limit(limit))

        async for row in self._query(query):
            task = mapper(row)
            task.nonterminal_children = row[NONTERMINAL_SUBTASKS_LABEL]
            task.failed_children = row[FAILED_SUBTASKS_LABEL]
            yield task

    async def get_org_tasks_count(self, org_id: int, task_types: Iterable[TaskType]) -> int:
        where_clause = and_(
            t_tasks.c.task_type.in_(task_types),
            t_tasks.c.entity_id == org_id,
        )
        query = (select((func.count(),)).
                 select_from(t_tasks).
                 where(where_clause)
                 )
        return await self._query_scalar(query)

    async def count_pending_by_type(self) -> AsyncIterable[Tuple[TaskType, int]]:
        query = (
            select((t_tasks.c.task_type, func.count())).
            select_from(t_tasks).
            where(t_tasks.c.state == TaskState.PENDING).
            group_by(t_tasks.c.task_type)
        )
        async for row in self._query(query):
            yield row[0], row[1]

    async def clean_old_tasks(self, min_age: timedelta, limit: int = 1000) -> None:
        """
        Удаляем сначала субтаски, потом метатаски.
        Если попытаться удалить всё одновременно, то метатаски откажутся удаляться, потому что у них будут слишком
        "молодые" таски.
        """
        seconds: float = min_age.total_seconds()
        t_subtasks = t_tasks.alias('subtasks')
        subtasks = select((1,)).select_from(t_subtasks).where(t_subtasks.c.meta_task_id == t_tasks.c.task_id)

        for state in TaskState:
            if state not in Task.NONTERMINAL_TASK_STATES:
                delete_old_subtasks = (
                    t_tasks.delete().
                    where(
                        t_tasks.c.task_id.in_(
                            select((t_tasks.c.task_id,)).select_from(t_tasks).
                            where(
                                and_(
                                    t_tasks.c.created < func.now().op('-')(text(f"interval '{seconds} seconds'")),
                                    t_tasks.c.state == state,
                                    t_tasks.c.meta_task_id.isnot(None),  # Является ли задача субтаской
                                    not_(exists(subtasks))  # Является ли задача субтаской
                                )
                            ).
                            limit(limit)
                        )
                    )
                )
                await self.conn.execute(delete_old_subtasks)

        for state in TaskState:
            if state not in Task.NONTERMINAL_TASK_STATES:
                delete_old_metatasks_without_subtasks = (
                    t_tasks.delete().
                    where(
                        t_tasks.c.task_id.in_(
                            select((t_tasks.c.task_id,)).select_from(t_tasks).
                            where(
                                and_(
                                    t_tasks.c.created < func.now().op('-')(text(f"interval '{seconds} seconds'")),
                                    t_tasks.c.state == state,
                                    t_tasks.c.meta_task_id.is_(None),  # Является ли задача метатаской
                                    not_(exists(subtasks))  # Можно ли удалять метатаску
                                )
                            ).
                            limit(limit)
                        )
                    )
                )
                await self.conn.execute(delete_old_metatasks_without_subtasks)
