import logging
from decimal import Decimal
from typing import Callable, Optional, Dict, Sequence, List, Any

import psycopg2

from mail.python.theatre.profiling.typing import Metrics, Metric
from mail.python.theatre.roles import Cron
from mail.python.theatre.stages.db_stats.async_qexec import async_qexec
from mail.python.theatre.stages.db_stats.types import DbSignal, Db, DbHost, TierRunPolicy

log = logging.getLogger(__name__)


def convert_value(val):
    """Normalize a value fetched from the database.

    ``Decimal`` instances are turned into ``float``; anything else is passed
    through untouched.
    """
    return float(val) if isinstance(val, Decimal) else val


def default_signal_name_formatter(column_name: str, value: str = None) -> str:
    """Build a signal name of the form ``db_<column_name>[_<value>]``.

    The ``_<value>`` suffix is appended only when ``value`` is truthy, so
    ``None``, ``''`` and ``0`` all yield the bare ``db_<column_name>`` name.
    """
    suffix = f'_{value}' if value else ''
    return f'db_{column_name}{suffix}'


class MetricWithTags:
    """A ``(name, value)`` metric together with extra per-metric tags."""

    def __init__(self, metric: Metric, tags: Dict[str, str] = None):
        # ``metric`` is unpacked as a two-item (name, value) pair by the properties below.
        self._metric = metric
        self._tags = tags if tags else {}

    @property
    def name(self):
        """First element of the wrapped metric pair."""
        metric_name, _unused_value = self._metric
        return metric_name

    @property
    def value(self):
        """Second element of the wrapped metric pair."""
        _unused_name, metric_value = self._metric
        return metric_value

    def get_tags(self, db: Db, db_host: DbHost):
        """Render the tags as a ``key=value`` string joined by ``;``.

        The base tags ``geo``, ``tier`` and ``ctype`` (plus ``prj`` when the
        database has a truthy one) come first. The metric's own tags follow;
        when a key collides, the metric's value wins but the key keeps its
        original position.
        """
        base = {'geo': db_host.geo, 'tier': db_host.tier, 'ctype': db.ctype}
        if db.prj:
            base['prj'] = db.prj
        merged = {**base, **self._tags}
        return ";".join(f'{key}={val}' for key, val in merged.items())


MetricsWithTags = List[MetricWithTags]


class DbStatPoller(Cron):
    """Run one ``DbSignal`` query against selected hosts of a ``Db`` and keep the results.

    The query job is handed to ``Cron`` together with ``db_signal.run_every``
    (presumably the scheduling period — confirm against ``Cron``). After each
    run, ``signals`` holds the collected metrics per host and ``stats`` exposes
    them as flat ``(tagged_name, value)`` pairs.
    """

    def __init__(
        self,
        db: Db,
        db_signal: DbSignal,
        # Transforms (column_name) for column-signal or (first_column_name, first_column_value) for row-signal to name of signal
        signal_name_formatter: Callable[[str, Optional[str]], str] = default_signal_name_formatter,
    ):
        self._db = db
        self._db_signal = db_signal
        self._signal_name_formatter = signal_name_formatter

        # Latest metrics per host; replaced wholesale at the end of every job run.
        self._signals: Dict[DbHost, MetricsWithTags] = {}

        super().__init__(job=self._job, run_every=db_signal.run_every)

    @staticmethod
    def tier_policy_hosts(hosts: Sequence[DbHost], policy: TierRunPolicy) -> List[DbHost]:
        """Select the alive hosts the signal should be queried on.

        Hosts with a truthy ``dead`` flag are always excluded. Then:

        * ``TierRunPolicy.All`` returns every remaining host.
        * ``TierRunPolicy.Master`` returns at most one host whose ``primary`` flag is set.
        * ``TierRunPolicy.AnyHost`` returns at most one host, preferring non-primary ones.

        Raises:
            ValueError: if ``policy`` is none of the above.
        """
        alive = [h for h in hosts if not h.dead]
        if policy == TierRunPolicy.All:
            return alive
        if policy == TierRunPolicy.Master:
            return [host for host in alive if host.primary][:1]
        if policy == TierRunPolicy.AnyHost:
            # Replicas first: False (non-primary) sorts before True (primary).
            return sorted(alive, key=lambda host: host.primary)[:1]
        raise ValueError(f'Unsupported TierRunPolicy value {policy}')

    @staticmethod
    async def qexec(dsn: str, query: str, query_args: Dict[str, Any]):
        """Run ``query`` on ``dsn`` and return ``(column_names, rows)``.

        Thin wrapper around ``async_qexec``, kept as a separate method so it
        can be overridden.
        """
        return await async_qexec(dsn, query, query_args)

    @property
    def stats(self) -> Metrics:
        """Flatten all host signals into ``(name, value)`` pairs.

        Each name is ``<tags>;<signal name>_<sigopt_suffix>``, where ``<tags>``
        comes from ``MetricWithTags.get_tags`` for the signal's host.
        """
        return [
            (f'{signal.get_tags(self._db, db_host)};{signal.name}_{self._db_signal.sigopt_suffix}', signal.value)
            for db_host, host_signals in self._signals.items()
            for signal in host_signals
        ]

    @property
    def signals(self) -> Dict[DbHost, MetricsWithTags]:
        """Metrics collected by the most recent job run, keyed by host."""
        return self._signals

    async def _job(self):
        """Query every eligible host and replace the stored signals with the results.

        When a host's query raises ``psycopg2.Error``, the failure is logged.
        The host is then dropped from the results, unless
        ``preserve_on_query_fail`` is set and metrics for that host were
        already stored; in that case those stored metrics are carried over.
        """
        new_signals: Dict[DbHost, MetricsWithTags] = {}
        hosts = self.tier_policy_hosts(self._db.hosts, self._db_signal.tier_run_policy)
        if not hosts:
            log.warning('No appropriate hosts found for signal %r, candidates: %r', self._db_signal, self._db.hosts)
        signal_label = self._db_signal.name or self._db_signal.query
        for host in hosts:
            log.debug('Metric queried: %s', signal_label)
            try:
                column_names, rows = await self.qexec(host.dsn, self._db_signal.query, self._db_signal.query_args)
            except psycopg2.Error:
                log.warning('Query for signal %s failed on host %r', signal_label, host, exc_info=True)
                # A host that never succeeded has nothing to preserve; indexing
                # self._signals directly would raise KeyError and abort the whole run.
                if self._db_signal.preserve_on_query_fail and host in self._signals:
                    new_signals[host] = self._signals[host]
                continue

            if self._db_signal.row_signal:
                host_signals = self._get_signals_from_row(column_names, rows, self._db_signal.with_tags)
            else:
                host_signals = self._get_signals_from_column(column_names, rows)
            log.debug('Metric collected: %d signals', len(host_signals))
            new_signals[host] = host_signals

        self._signals = new_signals

    def _get_signals_from_row(self, column_names, rows, with_tags) -> MetricsWithTags:
        """Build one metric per result row.

        Each metric is named from the first column's name and that row's first
        value, and takes its value from the row's last column. When
        ``with_tags`` is set, the columns between the first and the last become
        tags; only cells with truthy values are kept.
        """
        return [
            MetricWithTags(
                metric=(
                    self._signal_name_formatter(column_names[0], row[0]),
                    convert_value(row[-1])),
                tags={
                    tag: row[num]
                    for num, tag in enumerate(column_names[1:-1], start=1)
                    if row[num]
                } if with_tags else None
            )
            for row in rows
        ]

    def _get_signals_from_column(self, column_names, rows) -> MetricsWithTags:
        """Build one metric per (column, row) cell, named after the column.

        NOTE(review): with several rows, metrics for the same column share a
        name; presumably column signals are expected to return a single row.
        """
        return [
            MetricWithTags(metric=(
                self._signal_name_formatter(col_name),
                convert_value(row[col_num])
            ))
            for col_num, col_name in enumerate(column_names)
            for row in rows
        ]
