from datetime import timedelta
from typing import Mapping, Any, Iterable
from aiohttp import web

from library.python.monlib.metric_registry import MetricRegistry
from library.python.monlib.encoder import dumps

from mail.pypg.pypg.query_conf import load, QueriesHolder

from mail.python.theatre.roles import Director
from mail.python.theatre.stages.db_stats.roles.db_tier_poller import DbTierPoller
from mail.python.theatre.stages.db_stats.types import DbSignal, TierRunPolicy
from mail.python.theatre.stages.db_stats.roles.db_stat_poller import DbStatPoller

from .dbsources import make_db_source
from .named_signal_mixin import named_signal
from .solomon_registry_mixin import solomon_registry


@solomon_registry
@named_signal
class Poller(DbStatPoller):
    """``DbStatPoller`` wrapped by the ``named_signal`` and ``solomon_registry`` decorators.

    The class adds no behaviour of its own; everything comes from the base
    class and the two decorators.

    NOTE(review): ``solomon_registry`` presumably supplies the
    ``dump_solomon_metrics(registry)`` method that
    ``DbStatDirector.solomon_handler`` calls on each stat poller -- confirm
    against ``solomon_registry_mixin``.
    """
    pass


class DbStatDirector(Director):
    """Director that owns the DB stat pollers and DB tier pollers.

    Both groups of pollers are handed to the base ``Director`` as its tasks.
    The stat pollers are also kept separately so that their metrics can be
    dumped on demand by :meth:`solomon_handler`.
    """

    def __init__(self, db_stat_pollers: Iterable[Poller], db_tier_pollers: Iterable[DbTierPoller]):
        """Store the pollers and register all of them as director tasks.

        :param db_stat_pollers: pollers whose metrics are exposed by ``solomon_handler``.
        :param db_tier_pollers: tier pollers; they are run as tasks but not dumped.
        """
        # Materialize the iterables: they may be one-shot generators, and the
        # stat pollers are iterated again on every metrics request.
        self.stat_pollers = list(db_stat_pollers)
        self.tier_pollers = list(db_tier_pollers)
        super().__init__(tasks=self.stat_pollers + self.tier_pollers)

    async def solomon_handler(self, request: web.Request) -> web.Response:
        """Dump metrics of all stat pollers into a fresh registry and return them.

        The response is spack-encoded when the ``Accept`` header is exactly
        ``application/x-solomon-spack``; in every other case, including a
        missing ``Accept`` header, the metrics are returned as JSON.

        :param request: incoming aiohttp request; only its headers are read.
        :return: response carrying the encoded metrics.
        """
        CONTENT_TYPE_SPACK = 'application/x-solomon-spack'
        CONTENT_TYPE_JSON = 'application/json'

        # A new registry per request: every poller writes its current values into it.
        registry = MetricRegistry()
        for poller in self.stat_pollers:
            poller.dump_solomon_metrics(registry)

        # Use .get(): a request without an Accept header must fall back to JSON
        # instead of raising KeyError (which would surface as HTTP 500).
        if request.headers.get('accept') == CONTENT_TYPE_SPACK:
            return web.Response(body=dumps(registry), content_type=CONTENT_TYPE_SPACK)

        return web.Response(body=dumps(registry, format='json'), content_type=CONTENT_TYPE_JSON)


def load_query_conf(filename: str) -> QueriesHolder:
    """Read the query configuration file and parse it with ``pypg``'s ``load``.

    The whole file is read, split on ``'\\n'`` and the resulting list of lines
    is passed to ``load``.

    :param filename: path to the query configuration file.
    :return: whatever ``load`` builds from the lines (annotated as ``QueriesHolder``).
    """
    with open(filename) as conf_file:
        content = conf_file.read()
    # Split on '\n' explicitly (not splitlines()) so a trailing newline still
    # yields a final empty string, exactly as the parser has always received it.
    lines = content.split('\n')
    return load(lines)


def make_db_stat_director(config: Mapping[str, Any]) -> DbStatDirector:
    """Build a ``DbStatDirector`` from the service configuration.

    For every entry of ``config['databases']`` a DB source is created, one
    ``Poller`` is created per entry of that database's ``'queries'`` list, and
    one ``DbTierPoller`` is created for the database itself.

    Keys read from ``config``:

    * ``'dbsignals'`` (optional) -- when present and truthy, its
      ``'query_conf'`` file is loaded and used to look up queries by signal name.
    * ``'databases'`` -- iterable of database configs, each with a ``'queries'`` list.
    * ``'primary_poll_period'`` (optional) -- ``run_every`` for tier pollers,
      defaults to one minute.

    Keys read from each signal: ``'name'``, ``'period'``, and the optional
    ``'query'``, ``'row_signal'`` (default ``True``) and ``'tier_run_policy'``
    (default ``'any_host'``).

    :param config: service configuration mapping.
    :return: director owning all created stat and tier pollers.
    """
    signals_conf = config.get('dbsignals')
    # Without a truthy 'dbsignals' section there is no query file to load; the
    # falsy value is kept as-is and every signal must then carry its own 'query'.
    dbqueries = load_query_conf(signals_conf['query_conf']) if signals_conf else signals_conf

    tier_poll_period = config.get('primary_poll_period', timedelta(minutes=1))

    stat_pollers = []
    tier_pollers = []

    for dbconfig in config['databases']:
        source = make_db_source(dbconfig)

        for signal in dbconfig['queries']:
            signal_name = signal['name']
            # An inline query wins; otherwise look the query up by signal name
            # in the loaded query configuration.
            query_text = signal.get('query') or getattr(dbqueries, signal_name).query
            db_signal = DbSignal(
                name=signal_name,
                query=query_text,
                run_every=signal['period'],
                row_signal=signal.get('row_signal', True),
                tier_run_policy=TierRunPolicy(signal.get('tier_run_policy', 'any_host')),
            )
            stat_pollers.append(Poller(db=source, db_signal=db_signal))

        tier_pollers.append(DbTierPoller(source, run_every=tier_poll_period))

    return DbStatDirector(db_stat_pollers=stat_pollers, db_tier_pollers=tier_pollers)
