from datetime import timedelta
from typing import List, Optional, Tuple

from aiopg.sa import Engine
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert

from mail.python.theatre.roles import Cron
from mail.python.theatre.stages.bucket_holder.interactions.db.meta import BucketsMetainfo
from mail.python.theatre.stages.bucket_holder.interactions.db.db_engine import DbEngine
from mail.husky.stages.worker.settings.buckets_provider import BucketsProviderSettings
from mail.python.theatre.stages.bucket_holder.typing import make_bucket_id


class BucketsProvider(Cron):
    """Cron job that keeps the buckets table in sync with the active shards.

    On each run it reads every distinct ``(shard_id, husky_cluster)`` pair that
    still has rows in ``transfer.users_in_dogsleds`` with status ``pending`` or
    ``in_progress``. It then upserts one bucket per pair (refreshing
    ``last_updated`` on already existing buckets). Finally it deletes buckets
    whose ``last_updated`` is older than
    ``settings.no_tasks_buckets_deletion_delay``.
    """

    # Key of the transaction-level advisory lock that lets only one session at
    # a time (per database server) perform the refresh.
    ADVISORY_LOCK_KEY = 42

    # The query below is more complicated than it should be because Postgres
    # lacks a "loose index scan" (a.k.a. "index skip scan" / "jump scan").
    # There was work aiming to add it in PostgreSQL 13:
    # https://github.com/jesperpedersen/postgres/tree/indexskipscan
    # The officially documented workaround is a recursive CTE:
    # https://wiki.postgresql.org/wiki/Loose_indexscan
    # The anchor picks the smallest (shard_id, task_group_id) pair among
    # pending/in-progress rows. Each recursive step picks the next larger pair.
    # A NULL task_group_id is ordered as -1. The distinct pairs are then
    # LEFT JOINed with transfer.task_group, so husky_cluster is NULL when no
    # matching task group exists.
    GET_SHARDS_Q = '''
    WITH RECURSIVE buckets AS (
        (
            SELECT shard_id, task_group_id
              FROM transfer.users_in_dogsleds u
             WHERE status in ('pending', 'in_progress')
             ORDER BY shard_id, COALESCE(task_group_id, -1)
             LIMIT 1
        )
        UNION ALL
        (
            WITH t as (select shard_id, task_group_id from buckets)
            SELECT shard_id, task_group_id
              FROM transfer.users_in_dogsleds
             WHERE status in ('pending', 'in_progress')
               AND (shard_id,  COALESCE(task_group_id, -1)) > (select t.shard_id, COALESCE(t.task_group_id, -1) from t)
             ORDER BY shard_id,  COALESCE(task_group_id, -1)
             LIMIT 1
         )
    )
    SELECT DISTINCT shard_id, husky_cluster
    FROM buckets
    LEFT JOIN transfer.task_group tg ON (tg.id = buckets.task_group_id);
    '''

    def __init__(self, huskydb_pg: Engine, huskydb_ro: Engine, meta: BucketsMetainfo, settings: BucketsProviderSettings):
        """Register :meth:`update_shards_in_buckets` as the cron job.

        :param huskydb_pg: engine used for writes to the buckets table.
        :param huskydb_ro: engine used to take the advisory lock and read active shards.
        :param meta: table metadata; ``meta.t_buckets`` is the buckets table.
        :param settings: cron schedule (``settings.cron``) and the bucket
            deletion delay (``settings.no_tasks_buckets_deletion_delay``).
        """
        super().__init__(job=self.update_shards_in_buckets, randomize_start_time=True, **settings.cron.as_dict())
        self._pg = DbEngine(huskydb_pg)
        self._pg_ro = DbEngine(huskydb_ro)
        self._meta = meta
        self._settings = settings

    async def update_shards_in_buckets(self):
        """Upsert buckets for currently active shards and drop stale buckets.

        Does nothing if another session already holds the advisory lock.
        """
        async with self._pg_ro.acquire() as conn_ro:
            # The lock is transaction-scoped, and the writes below run inside
            # this transaction, so the lock is held until they finish.
            async with conn_ro.begin():
                # NOTE(review): advisory locks are local to one PostgreSQL server.
                # This one is taken through the read-only engine, so it only
                # serialises workers that reach the same server. Confirm
                # huskydb_ro does not balance across several replicas.
                lock_sql = f'select pg_try_advisory_xact_lock({self.ADVISORY_LOCK_KEY:d})'
                if not await conn_ro.scalar(lock_sql):
                    return

                active_pairs = [
                    (row['shard_id'], row['husky_cluster'])
                    async for row in conn_ro.execute(self.GET_SHARDS_Q)
                ]
                async with self._pg.acquire() as conn:
                    # Nothing to upsert when no shard has pending/in-progress work.
                    if active_pairs:
                        await conn.execute(self.update_buckets_sql(active_pairs))
                    # Always run the cleanup, so buckets of shards that went idle
                    # expire once the deletion delay has passed.
                    await conn.execute(
                        self.delete_old_buckets_sql(self._settings.no_tasks_buckets_deletion_delay)
                    )

    def update_buckets_sql(self, shard_id_husky_cluster_pairs: List[Tuple[int, Optional[str]]]):
        """Build an upsert with one bucket row per ``(shard_id, husky_cluster)`` pair.

        New bucket ids are inserted. For ids that already exist,
        ``last_updated`` is set to ``now()``.

        :param shard_id_husky_cluster_pairs: pairs to turn into bucket ids via
            ``make_bucket_id``. The caller only passes a non-empty list.
            ``husky_cluster`` may be ``None`` because ``GET_SHARDS_Q`` obtains
            it through a LEFT JOIN.
        :return: a PostgreSQL ``INSERT ... ON CONFLICT (id) DO UPDATE`` statement.
        """
        id_column = self._meta.t_buckets.c_id
        # Deduplicate ids while keeping first-seen order. PostgreSQL rejects an
        # INSERT ... ON CONFLICT DO UPDATE that would affect the same row twice,
        # which could happen if make_bucket_id maps two pairs to one id.
        bucket_ids = dict.fromkeys(
            make_bucket_id(shard_id=shard_id, husky_cluster=husky_cluster)
            for shard_id, husky_cluster in shard_id_husky_cluster_pairs
        )
        return (
            insert(self._meta.t_buckets)
            .values([{id_column: bucket_id} for bucket_id in bucket_ids])
            .on_conflict_do_update(
                index_elements=[id_column],
                set_=dict(last_updated=sa.func.now()),
            )
        )

    def delete_old_buckets_sql(self, deletion_delay: timedelta):
        """Build a DELETE for buckets whose ``last_updated`` is older than ``now() - deletion_delay``.

        Existing buckets get ``last_updated`` refreshed by every upsert. A bucket
        whose shard stops appearing in ``GET_SHARDS_Q`` therefore ages out after
        ``deletion_delay``.
        """
        cutoff = sa.func.now() - deletion_delay
        return self._meta.t_buckets.delete().where(sa.column('last_updated') < cutoff)
