from datetime import timedelta
from itertools import chain, islice
from typing import Sequence

from mail.python.theatre.profiling.typing import Metrics
from mail.python.theatre.roles import Director
from mail.python.theatre.stages.db_stats.types import DbSignal, Db, TierRunPolicy
from .db_stat_poller import DbStatPoller


class FatRolesPoller(Director):
    """
    Polls per-role aggregates from pg_stat_statements and exposes the heaviest roles as metrics.

    Two queries are run, each grouped by role name (``pg_roles.rolname``) and ordered by the
    aggregated value in descending order:

    * ``TOTAL_TIME_Q`` -- sum of ``total_time`` over all statements executed by a role;
    * ``DISK_READ_Q`` -- sum of ``shared_blks_read`` of a role, converted to kilobytes.

    Roles listed in ``rolname_blacklist`` are excluded from both queries. Each query is wrapped in
    its own ``DbStatPoller``, and the pollers are passed to ``Director`` as this director's tasks.

    NOTE(review): an earlier version of this docstring mentioned a ``wakeup_at`` method for forcing
    an early poll. No such method is defined in this class; confirm whether ``Director`` provides it.
    """
    # NOTE(review): newer pg_stat_statements versions (PostgreSQL 13+) replaced ``total_time`` with
    # ``total_exec_time`` / ``total_plan_time``; confirm the extension version this runs against.
    TOTAL_TIME_Q = '''
        -- poll_fat_roles_total_time
        SELECT rolname as role_total_time, sum(total_time) as total_time
          FROM pg_stat_statements s
          JOIN pg_roles r ON (s.userid = r.oid)
          WHERE NOT EXISTS (
            SELECT 1 WHERE rolname = ANY( %(rolname_blacklist)s::text[] )
         )
      GROUP BY 1
      ORDER BY 2 DESC;
    '''

    # Blocks are converted to kB by multiplying by 8, i.e. this assumes PostgreSQL's default
    # 8 kB block_size.
    DISK_READ_Q = '''
        -- poll_fat_roles_disk_read
        SELECT rolname as role_disk_read_kb, round(sum(shared_blks_read) * 8) as disk_read_kb
          FROM pg_stat_statements s
          JOIN pg_roles r ON (s.userid = r.oid)
          WHERE NOT EXISTS (
            SELECT 1 WHERE rolname = ANY( %(rolname_blacklist)s::text[] )
         )
      GROUP BY 1
      ORDER BY 2 DESC;
    '''

    def __init__(
        self,
        db: Db,
        run_every: timedelta,
        rolname_blacklist: Sequence[str] = tuple(),
        top_roles_signal_cnt: int = 10,
    ):
        """
        :param db: database to poll; its ``hosts`` count scales the number of kept metrics.
        :param run_every: polling interval passed to each ``DbSignal``.
        :param rolname_blacklist: role names excluded from both queries.
        :param top_roles_signal_cnt: number of top roles to report; multiplied by ``len(db.hosts)``
            to get the per-poller cap applied in ``stats``.
        """
        # Presumably every host returns its own rows, so the cap scales with the host count -- confirm.
        self._signal_cnt = top_roles_signal_cnt * len(db.hosts)

        # Materialize once: a one-shot iterable would otherwise be exhausted by the first poller,
        # silently leaving the second poller without a blacklist.
        blacklist = tuple(rolname_blacklist)

        self.pollers = [
            DbStatPoller(
                db=db,
                db_signal=DbSignal(
                    query=query,
                    row_signal=True,
                    run_every=run_every,
                    sigopt_suffix='deee',
                    tier_run_policy=TierRunPolicy.All,
                    # A fresh list per signal, so the pollers never share a mutable argument.
                    query_args=dict(
                        rolname_blacklist=list(blacklist),
                    ),
                ),
            )
            # Order matters for ``stats``: total-time metrics first, then disk-read metrics.
            for query in (self.TOTAL_TIME_Q, self.DISK_READ_Q)
        ]
        super().__init__(tasks=self.pollers)

    @property
    def stats(self) -> Metrics:
        """
        Return the metrics of all pollers concatenated, keeping at most ``_signal_cnt`` leading
        entries from each poller (pollers are taken in ``self.pollers`` order).
        """
        return list(chain.from_iterable(islice(poller.stats, self._signal_cnt) for poller in self.pollers))
