from datetime import timedelta
from itertools import chain
from typing import Optional, Sequence

import ujson as jsn
from aiohttp.web_response import Response

from mail.python.theatre.profiling.typing import Metrics
from mail.python.theatre.roles import Director as DirectorBase
from mail.python.theatre.stages.db_stats.roles.fat_roles_poller import FatRolesPoller
from mail.python.theatre.stages.db_stats.types import DbSignal, Db

from .db_stat_poller import DbStatPoller
from .db_tier_poller import DbTierPoller
from .fat_queries_poller import FatQueriesPoller


class FatPollersConfig:
    """Settings for the "fat" roles/queries pollers built by ``Director``.

    Attributes:
        fat_queries_poll_time: polling interval handed to the fat pollers as
            ``run_every``; 10 seconds unless overridden.
        fat_queries_name_normalizer: forwarded to ``FatQueriesPoller`` as
            ``query_name_normalizer``; ``None`` by default (presumably a
            callable when set -- confirm against ``FatQueriesPoller``).
    """

    def __init__(
        self,
        fat_queries_poll_time: timedelta = timedelta(seconds=10),
        fat_queries_name_normalizer=None,
    ):
        # Values are stored exactly as given; no validation happens here.
        self.fat_queries_poll_time, self.fat_queries_name_normalizer = (
            fat_queries_poll_time,
            fat_queries_name_normalizer,
        )


class Director(DirectorBase):
    """Director that schedules database pollers and exposes their metrics.

    Per database it creates one ``DbTierPoller``. Per (signal, database) pair it
    creates one ``DbStatPoller``. Unless fat polling is disabled, it also creates
    one ``FatRolesPoller`` and one ``FatQueriesPoller`` per database. All pollers
    are handed to the base ``Director`` as its ``tasks``.
    """

    # Sentinel meaning "fatPollersConfig was not passed". Each Director then builds
    # its own FatPollersConfig instead of sharing a single instance created when the
    # function was defined. ``None`` stays available as the explicit "disable" value.
    _DEFAULT_FAT_POLLERS_CONFIG = object()

    def __init__(
        self,
        dbs: Sequence[Db],
        signals: Sequence[DbSignal],
        db_host_status_poll_time: timedelta,
        fatPollersConfig: Optional[FatPollersConfig] = _DEFAULT_FAT_POLLERS_CONFIG,
    ):
        """Create all pollers and register them as tasks with the base director.

        :param dbs: databases to poll (copied into a list).
        :param signals: signals to collect; one ``DbStatPoller`` is created for
            every (signal, db) combination.
        :param db_host_status_poll_time: interval passed to each ``DbTierPoller``.
        :param fatPollersConfig: settings for the fat roles/queries pollers.
            When omitted, a default ``FatPollersConfig()`` is used. Pass ``None``
            to create no fat pollers at all.
        """
        if fatPollersConfig is Director._DEFAULT_FAT_POLLERS_CONFIG:
            fatPollersConfig = FatPollersConfig()

        self.dbs = list(dbs)
        self.signals = list(signals)

        self.db_tier_pollers = [
            DbTierPoller(db, db_host_status_poll_time)
            for db in self.dbs
        ]
        # Signal-major order: all dbs for the first signal, then the next signal.
        self.db_stat_pollers = [
            DbStatPoller(db=db, db_signal=signal)
            for signal in self.signals
            for db in self.dbs
        ]
        self.fat_roles_pollers, self.fat_queries_pollers = self._make_fat_pollers(
            self.dbs, fatPollersConfig
        )

        super().__init__(tasks=self.db_tier_pollers
                               + self.db_stat_pollers
                               + self.fat_roles_pollers
                               + self.fat_queries_pollers)

    @staticmethod
    def _make_fat_pollers(dbs, config):
        """Build (fat_roles_pollers, fat_queries_pollers); both empty when ``config`` is None."""
        if config is None:
            return [], []
        roles_pollers = [
            FatRolesPoller(
                db=db,
                # Roles pollers intentionally share the fat-queries interval;
                # the config has no separate setting for them.
                run_every=config.fat_queries_poll_time,
            )
            for db in dbs
        ]
        queries_pollers = [
            FatQueriesPoller(
                db=db,
                run_every=config.fat_queries_poll_time,
                query_name_normalizer=config.fat_queries_name_normalizer,
            )
            for db in dbs
        ]
        return roles_pollers, queries_pollers

    @property
    def stats(self) -> Metrics:
        """Concatenate ``stats`` from the stat, fat-roles and fat-queries pollers.

        Tier pollers are not included. Order: stat pollers first, then fat-roles,
        then fat-queries.
        """
        return list(chain.from_iterable(
            poller.stats for poller in chain(self.db_stat_pollers, self.fat_roles_pollers, self.fat_queries_pollers)
        ))

    async def unistat_handler(self, _):
        """Serve ``self.stats`` as a JSON (``application/json``) HTTP response.

        The single positional argument is ignored. It is presumably the aiohttp
        request; confirm against the route registration.
        """
        return Response(text=jsn.dumps(self.stats), content_type='application/json')
