import logging
import time
from dataclasses import dataclass, astuple, asdict, fields
from datetime import timedelta
from typing import Any, Optional

import infra.yasm.yasmapi
from mail.husky.stages.worker.util.map_dataclass import MapDataclass

log = logging.getLogger(__name__)


@dataclass
class YasmSignals(MapDataclass):
    """Named slots for every yasm signal tracked for a cluster.

    The same shape carries different payloads. ``YasmAdaptor`` fills it with
    yasm query expressions (strings), while ``YasmSignalValues`` redeclares the
    fields with value types.

    Field order defines the order of ``astuple(...)`` output, which
    ``YasmAdaptor`` sends to ``GolovanRequest``. ``map`` is inherited from
    ``MapDataclass``; it presumably applies a function to every field —
    verify against ``MapDataclass``.
    """
    # (filesystem read + write bytes) / I/O limit, per tier.
    io_usage_primary: Any
    io_usage_replica: Any
    # network MB / guaranteed network MB, per tier.
    net_usage_primary: Any
    net_usage_replica: Any
    # Pooler average query time, per tier.
    avg_query_primary: Any
    avg_query_replica: Any
    # Replication lag, queried on the replica tier only.
    replication_lag: Any


@dataclass
class YasmSignalValues(YasmSignals):
    """Concrete signal values fetched from yasm.

    Redeclares the fields of ``YasmSignals`` with value types and a ``None``
    default. A ``None`` field means the value is unavailable. That happens
    when the adaptor is not in the cloud, when the request failed, or when
    yasm returned no value for that particular signal.
    """
    # (filesystem read + write bytes) / I/O limit.
    io_usage_primary: Optional[float] = None
    io_usage_replica: Optional[float] = None
    # network MB / guaranteed network MB.
    net_usage_primary: Optional[float] = None
    net_usage_replica: Optional[float] = None
    # Pooler average query time (converted from milliseconds).
    avg_query_primary: Optional[timedelta] = None
    avg_query_replica: Optional[timedelta] = None
    # Replication lag (converted from seconds).
    replication_lag: Optional[timedelta] = None

    @staticmethod
    def from_golovan(
        avg_query_primary: Optional[float],
        avg_query_replica: Optional[float],
        replication_lag: Optional[float],
        **kwargs
    ) -> 'YasmSignalValues':
        """Build values from raw Golovan numbers.

        Args:
            avg_query_primary: Average query time on primary, in milliseconds.
            avg_query_replica: Average query time on replicas, in milliseconds.
            replication_lag: Replication lag, in seconds.
            **kwargs: The remaining (float) fields, passed through unchanged.

        A duration given as ``None`` (signal missing from the response) stays
        ``None``. Previously it raised ``TypeError`` inside ``timedelta`` and
        discarded every other, valid, signal.
        """
        def to_timedelta(value, unit):
            # Keep "no data" as None instead of failing the whole conversion.
            return None if value is None else timedelta(**{unit: value})

        return YasmSignalValues(
            avg_query_primary=to_timedelta(avg_query_primary, 'milliseconds'),
            avg_query_replica=to_timedelta(avg_query_replica, 'milliseconds'),
            replication_lag=to_timedelta(replication_lag, 'seconds'),
            **kwargs
        )


# Sanity check: YasmSignalValues must only narrow the types of YasmSignals'
# fields, never define new ones (a typo in a redeclared field name would
# silently create an extra field). An explicit raise is used instead of
# ``assert`` so the check still runs under ``python -O``.
_base_signal_names = [f.name for f in fields(YasmSignals)]
_value_signal_names = [f.name for f in fields(YasmSignalValues)]
if _base_signal_names != _value_signal_names:
    raise TypeError(
        'YasmSignalValues fields must match YasmSignals fields: '
        f'{_value_signal_names} != {_base_signal_names}'
    )


class YasmAdaptor:
    @staticmethod
    def io_usage_signal(tier, cluster_id):
        return (
            'itype=mdbdom0;'
            f'tier={tier};'
            f'ctype={cluster_id}:'
            'div('
            'sum(portoinst-io_read_fs_bytes_tmmv, portoinst-io_write_fs_bytes_tmmv),'
            'portoinst-io_limit_bytes_tmmv'
            ')'
        )

    @staticmethod
    def net_usage_signal(tier, cluster_id):
        return (
            'itype=mdbdom0;'
            f'tier={tier};'
            f'ctype={cluster_id}:'
            'div(portoinst-net_mb_summ, portoinst-net_guarantee_mb_summ)'
        )

    @staticmethod
    def avg_query_signal(tier, cluster_id):
        return (
            'itype=mailpostgresql;'
            f'tier={tier};'
            f'ctype={cluster_id}:'
            'push-pooler-avg_query_time_vmmv'
        )

    @staticmethod
    def replication_lag_signal(cluster_id):
        return (
            'itype=mailpostgresql;'
            f'tier=replica;'
            f'ctype={cluster_id}:'
            'push-postgres-replication_lag_tmmx'
        )

    def __init__(self, cluster_id, is_in_cloud):
        self.cluster_id = cluster_id
        self.is_in_cloud = is_in_cloud
        self.signal_names = YasmSignals(
            io_usage_primary=self.io_usage_signal('primary', self.cluster_id),
            io_usage_replica=self.io_usage_signal('replica', self.cluster_id),
            net_usage_primary=self.net_usage_signal('primary', self.cluster_id),
            net_usage_replica=self.net_usage_signal('replica', self.cluster_id),
            avg_query_primary=self.avg_query_signal('primary', self.cluster_id),
            avg_query_replica=self.avg_query_signal('replica', self.cluster_id),
            replication_lag=self.replication_lag_signal(self.cluster_id),
        )

    async def get_signals(self) -> YasmSignalValues:
        if not self.is_in_cloud:
            return YasmSignalValues()

        time_to = int(time.time()) - 15
        time_from = time_to - 5
        try:
            golovan_req = infra.yasm.yasmapi.GolovanRequest('CON', 5, time_from, time_to, astuple(self.signal_names))
            _, metrics = next(iter(golovan_req))
            # XXX : Map signals to their local aliases
            return YasmSignalValues.from_golovan(
                **asdict(
                    YasmSignals.map(
                        lambda signal_str: metrics.get(signal_str),
                        self.signal_names,
                    )
                )
            )
        except StopIteration:
            log.error('got empty sequence from yasm')
        except Exception as e:
            log.exception(e)
        return YasmSignalValues()
