import re
from datetime import timedelta
from itertools import chain, islice
from typing import Callable, Iterator, Optional, Sequence

from mail.python.theatre.profiling.typing import Metrics
from mail.python.theatre.roles import Director
from mail.python.theatre.stages.db_stats.types import DbSignal, Db, TierRunPolicy
from .db_stat_poller import DbStatPoller, default_signal_name_formatter


class QueryNameNormalizer:
    """
    Derives a short, signal-friendly name from SQL query text.

    If the query text begins (after optional whitespace) with a `-- identifier` comment line, that
    identifier is the name. Otherwise a name is synthesized from the leading
    `select ... from ... [join ...]` part of the query, cut before the first `where`.
    """
    # A `-- identifier` comment line; used with `match`, so it must be at the very start of the text.
    NAME_RE = re.compile(r'^\s*--\s*(\w+)\s*$', flags=re.MULTILINE)
    NONWORD_RE = re.compile(r'\W')
    UNDERSCORE_RE = re.compile(r'_+')
    # Upper bound on the number of query lines considered when synthesizing a name.
    MAX_MEANINGFUL_LINES = 30

    @staticmethod
    def meaningful(q: str) -> Iterator[str]:
        """
        Yields the leading lines of a multiline query that look like `select ... [from a] [join b ...] [join c ...]`.

        Lines are yielded up to and including the first one containing `from`. After that, only
        consecutive lines containing `join` are yielded; iteration stops at the first line without it.
        Both keywords are matched case-insensitively, as plain substrings.
        :param q: full query text
        :return: iterator over the retained lines
        """
        from_seen = False
        for line in q.split('\n'):
            lowered = line.lower()
            # Lowercase the `join` check too, so uppercase `JOIN` lines are kept just like `FROM` lines are detected.
            if from_seen and 'join' not in lowered:
                return
            yield line
            if 'from' in lowered:
                from_seen = True

    def pretty(self, query: str) -> str:
        """
        Builds a signal name from the leading part of the query (see `meaningful`), cut before the first `where`.

        Non-word characters are replaced with underscores, and runs of underscores are collapsed. The
        result is lowercased and stripped of leading/trailing underscores.
        :param query: full query text
        :return: normalized name (may be empty)
        """
        head = ' '.join(islice(self.meaningful(query), self.MAX_MEANINGFUL_LINES))
        where_pos = head.lower().find('where')
        # `find` returns -1 when there is no `where`; slicing with -1 would drop the last character.
        if where_pos != -1:
            head = head[:where_pos]
        return self.UNDERSCORE_RE.sub('_', self.NONWORD_RE.sub('_', head)).lower().strip('_')

    def exact_name(self, query: str) -> Optional[str]:
        """
        Returns the identifier from a leading `-- identifier` comment line.

        :param query: full query text
        :return: the identifier, or None if the query does not start with such a comment line
        """
        match = self.NAME_RE.match(query)
        return match.group(1) if match else None

    def __call__(self, query: str) -> str:
        """Names the query: the explicit comment name if present, otherwise a synthesized one."""
        return self.exact_name(query) or self.pretty(query)


class FatQueriesPoller(Director):
    """
    Polls the fattest queries from pg_stat_statements with two pollers. One ranks statements by
    `total_time`; the other ranks them by `shared_blks_read`, exported as kB. Each result row is
    emitted as its own signal, named from the query text via the query-name normalizer.

    The `wakeup_at` method allows notifying the timer about a needed early wakeup, forcing a poll at
    the given time. NOTE(review): `wakeup_at` is not defined in this class; presumably it is
    inherited from `Director` — verify.
    """
    # The leading `-- name` comment presumably lets QueryNameNormalizer.exact_name name this statement
    # when it appears in its own results.
    # NOTE(review): pg_stat_statements renamed `total_time` to `total_exec_time` (plus `total_plan_time`)
    # in PostgreSQL 13; this query fails on such servers — confirm the target server version.
    TOTAL_TIME_Q = '''
        -- poll_fat_queries_total_time
        SELECT query as query_total_time, total_time
          FROM pg_stat_statements s
          JOIN pg_roles r ON (s.userid = r.oid)
         WHERE NOT EXISTS (
            SELECT 1 WHERE rolname = ANY( %(rolname_blacklist)s::text[] ) OR rolname = 'auth-' || current_database()
         )
      ORDER BY total_time DESC
         LIMIT %(limit)s;
    '''

    # `* 8` converts blocks to kB assuming the default 8 kB block_size — verify for this deployment.
    DISK_READ_Q = '''
        -- poll_fat_queries_disk_read
        SELECT query as query_disk_read_kb, round(shared_blks_read * 8) as disk_read_kb
          FROM pg_stat_statements s
          JOIN pg_roles r ON (s.userid = r.oid)
         WHERE NOT EXISTS (
            SELECT 1 WHERE rolname = ANY( %(rolname_blacklist)s::text[] ) OR rolname = 'auth-' || current_database()
         )
      ORDER BY shared_blks_read DESC
         LIMIT %(limit)s;
    '''

    def __init__(
        self,
        db: Db,
        run_every: timedelta,
        rolname_blacklist: Sequence[str] = ('postgres', 'monitor'),
        top_queries_select_cnt: int = 30,
        top_queries_signal_cnt: int = 20,
        query_name_normalizer: Optional[Callable[[str], str]] = None,
    ):
        """
        :param db: database to poll; its host count scales the exported signal cap
        :param run_every: polling period shared by both pollers
        :param rolname_blacklist: role names whose statements are excluded (a role named
            `'auth-' || current_database()` is always excluded as well)
        :param top_queries_select_cnt: LIMIT applied to each SQL query
        :param top_queries_signal_cnt: cap on signals taken from each poller, multiplied by `len(db.hosts)`
        :param query_name_normalizer: maps query text to a name used in the signal name;
            defaults to `QueryNameNormalizer()`
        """
        self._query_name_normalizer = query_name_normalizer or QueryNameNormalizer()
        # Presumably every host returns its own rows (TierRunPolicy.All), hence the per-host scaling.
        self._signal_cnt = top_queries_signal_cnt * len(db.hosts)

        def query_to_signal_name_formatter(colname: str, query: str):
            return default_signal_name_formatter(colname, self._query_name_normalizer(query))

        self.pollers = [
            self._make_poller(
                db=db,
                query=query,
                run_every=run_every,
                rolname_blacklist=rolname_blacklist,
                limit=top_queries_select_cnt,
                signal_name_formatter=query_to_signal_name_formatter,
            )
            for query in (self.TOTAL_TIME_Q, self.DISK_READ_Q)
        ]
        super().__init__(tasks=self.pollers)

    @staticmethod
    def _make_poller(
        db: Db,
        query: str,
        run_every: timedelta,
        rolname_blacklist: Sequence[str],
        limit: int,
        signal_name_formatter: Callable,
    ) -> DbStatPoller:
        """Builds a row-signal DbStatPoller for one fat-query SQL statement (fresh query_args per call)."""
        return DbStatPoller(
            db=db,
            db_signal=DbSignal(
                query=query,
                row_signal=True,
                run_every=run_every,
                sigopt_suffix='deee',
                tier_run_policy=TierRunPolicy.All,
                query_args=dict(
                    rolname_blacklist=list(rolname_blacklist),
                    limit=limit,
                ),
            ),
            signal_name_formatter=signal_name_formatter,
        )

    @property
    def stats(self) -> Metrics:
        """Concatenates at most `_signal_cnt` metrics from each poller, in poller order (total time, then disk read)."""
        return list(chain.from_iterable(islice(poller.stats, self._signal_cnt) for poller in self.pollers))
