from datetime import timedelta
from typing import List

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert
from aiopg.sa import Engine

from .meta import BucketsMetainfo
from .db_engine import DbEngine

from mail.python.theatre.stages.bucket_holder.typing import BucketId


class BucketsQueries:
    """Database queries that distribute buckets between workers.

    Each worker upserts its row in the workers table (``t_workers``) and claims
    rows of the buckets table (``t_buckets``) by writing its name and a fresh
    heartbeat into them. A bucket whose heartbeat is NULL or older than the
    heartbeat deadline may be claimed by any worker.
    """

    def __init__(self, pg: Engine, worker_name: str, meta: BucketsMetainfo):
        """
        :param pg: aiopg SQLAlchemy engine; wrapped into ``DbEngine`` for
            queries and also bound to ``meta.metadata``.
        :param worker_name: name under which this worker registers itself and
            owns buckets; must not be None.
        :param meta: metadata describing the workers and buckets tables.
        """
        self._pg = DbEngine(pg)
        # NOTE(review): ``assert`` is stripped under ``python -O``; consider an
        # explicit ``raise`` if this check must also hold in optimized runs.
        assert worker_name is not None, 'Worker name must not be null'
        self._worker_name = worker_name
        self._meta = meta
        self._meta.metadata.bind = pg

    @staticmethod
    def is_alive(heartbeated: sa.Column, deadline_secs: int):
        """Return a SQL condition that is true when ``heartbeated`` lies less than
        ``deadline_secs`` seconds before the database ``now()``."""
        return sa.func.now() - heartbeated < timedelta(seconds=deadline_secs)

    async def acquire_buckets(self, limit: int, deadline_secs: int, bucket_ids: List[BucketId]) -> List[BucketId]:
        """Refresh this worker's heartbeat and (re)acquire up to ``limit`` buckets.

        Among the eligible buckets, those in ``bucket_ids`` are ranked first, so
        they are kept as long as there are at most ``limit`` of them and no
        concurrent transaction holds a row lock on them (``SKIP LOCKED``).
        Buckets owned by this worker that are not re-acquired get their
        heartbeat cleared (see ``acquire_buckets_sql``).

        :param limit: maximum number of buckets to acquire.
        :param deadline_secs: heartbeat age in seconds after which a bucket
            owned by another worker counts as free.
        :param bucket_ids: buckets this worker is currently processing.
        :return: ids of the buckets assigned to this worker by this call.
        """
        async with self._pg.acquire() as conn:
            # Register this worker, or refresh its heartbeat if it already exists.
            await conn.execute(
                insert(self._meta.t_workers)
                .values({
                    self._meta.t_workers.c_name: self._worker_name,
                    self._meta.t_workers.c_heartbeat: sa.func.now(),
                })
                .on_conflict_do_update(
                    index_elements=[self._meta.t_workers.c_name],
                    set_={self._meta.t_workers.c_heartbeat.name: sa.func.now()},
                )
            )
            # Awaited like every other ``execute`` in this class; the awaited
            # result is then consumed with ``async for``.
            changed_buckets = await conn.execute(
                self.acquire_buckets_sql(
                    current_bucket_ids=bucket_ids,
                    worker_name=self._worker_name,
                    heartbeat_deadline_secs=deadline_secs,
                    limit=limit,
                )
            )
            # Released buckets come back with ``acquired = false``; drop them.
            return [
                row['bucket_id'] async for row in changed_buckets
                if row['acquired']
            ]

    def acquire_buckets_sql(
        self,
        current_bucket_ids: List[BucketId],
        worker_name: str,
        heartbeat_deadline_secs: int,
        limit: int
    ):
        """Build the single statement that acquires and releases buckets.

        The statement is made of three CTEs:

        * ``ready_buckets``: up to ``limit`` candidate buckets, i.e. those
          owned by ``worker_name``, never heartbeated, or with a heartbeat
          older than ``heartbeat_deadline_secs``. The candidates are row-locked
          with ``SKIP LOCKED``, so rows locked by a concurrent transaction are
          skipped instead of waited on.
        * ``updated_buckets``: assigns the candidates to ``worker_name`` with
          a fresh heartbeat and returns them with ``acquired = true``.
        * ``released_buckets``: clears the heartbeat of every other bucket
          still owned by ``worker_name`` and returns them with
          ``acquired = false``. The owner name is kept, so this worker still
          ranks them ahead of unowned buckets on the next acquisition.

        :return: a selectable yielding ``(acquired, bucket_id)`` rows.
        """
        rb_cte = (
            self._meta.t_buckets
            .select()
            .where(sa.or_(
                self._meta.t_buckets.c_worker_name == worker_name,
                self._meta.t_buckets.c_heartbeat == sa.null(),
                sa.func.now() - self._meta.t_buckets.c_heartbeat > timedelta(seconds=heartbeat_deadline_secs)
            ))
            .order_by(
                # Take buckets already in process by this worker first,
                (self._meta.t_buckets.c_id == sa.any_(current_bucket_ids)).desc(),
                # then buckets that were taken by this worker previously,
                self._meta.t_buckets.c_worker_name.is_distinct_from(worker_name),
                # then buckets that weren't assigned to any worker,
                self._meta.t_buckets.c_worker_name.is_(None).desc(),
                # then any other buckets, in id order.
                self._meta.t_buckets.c_id
            )
            .limit(limit)
            .with_for_update(key_share=True, skip_locked=True)
            .cte('ready_buckets')
        )
        ub_cte = (
            self._meta.t_buckets
            .update()
            .values({
                self._meta.t_buckets.c_worker_name: worker_name,
                self._meta.t_buckets.c_heartbeat: sa.func.now()
            })
            .where(rb_cte.c.bucket_id == self._meta.t_buckets.c_id)
            .returning(
                sa.literal(True).label("acquired"),
                self._meta.t_buckets.c_id
            ).cte('updated_buckets')
        )
        rel_cte = (
            self._meta.t_buckets
            .update()
            .values({self._meta.t_buckets.c_heartbeat: sa.null()})
            .where(self._meta.t_buckets.c_worker_name == worker_name)
            # Exclude rows just re-acquired so no row is updated twice.
            .where(self._meta.t_buckets.c_id.notin_(sa.select([ub_cte.c.bucket_id])))
            .returning(
                sa.literal(False).label("acquired"),
                self._meta.t_buckets.c_id
            )
            .cte('released_buckets')
        )
        return ub_cte.select().union(rel_cte.select())

    async def bucket_count(self):
        """Return the total number of rows in the buckets table."""
        async with self._pg.acquire() as conn:
            return await conn.scalar(
                sa.select([sa.func.count()])
                .select_from(self._meta.t_buckets)
            )

    async def active_worker_count(self, deadline_secs: int):
        """Return the number of workers whose heartbeat is younger than
        ``deadline_secs`` seconds."""
        async with self._pg.acquire() as conn:
            return await conn.scalar(
                sa.select([sa.func.count()])
                .select_from(self._meta.t_workers)
                .where(self.is_alive(self._meta.t_workers.c_heartbeat, deadline_secs))
            )

    async def release_buckets(self, deadline_secs: int, time_to_restart_secs: int):
        """Let this worker's buckets expire soon, then unregister the worker.

        The heartbeat of every bucket owned by this worker is backdated to
        ``now() - (deadline_secs - time_to_restart_secs)``. When acquisition
        uses the same ``deadline_secs``, those buckets become free for other
        workers after ``time_to_restart_secs`` seconds. Until then only a
        worker with the same name can take them, presumably a restarted
        instance of this one. Finally the worker's row is deleted.

        :param deadline_secs: heartbeat deadline used when acquiring buckets.
        :param time_to_restart_secs: how long, in seconds, the buckets stay
            reserved for this worker name.
        """
        heartbeat_near_expire = sa.func.now() - timedelta(seconds=deadline_secs - time_to_restart_secs)
        async with self._pg.acquire() as conn:
            await conn.execute(
                self._meta.t_buckets
                .update()
                .values({self._meta.t_buckets.c_heartbeat: heartbeat_near_expire})
                .where(self._meta.t_buckets.c_worker_name == self._worker_name)
            )
            await conn.execute(
                self._meta.t_workers
                .delete()
                .where(self._meta.t_workers.c_name == self._worker_name)
            )
