import logging
from dataclasses import asdict, fields
from datetime import timedelta

from mail.husky.stages.worker.interactions.yasm_adaptor import YasmSignals
from mail.python.theatre.roles import Cron
from .task_dispatcher import TaskDispatcher
from ..interactions.yasm_adaptor import YasmAdaptor
from ..logic.signal_history import SignalHistory
from ..settings.load_regulator import LoadRegulatorSettings

log = logging.getLogger(__name__)


class LoadRegulator(Cron):
    """Cron job that scales a shard's workers up or down based on YASM load signals.

    Every tick (:meth:`regulate`) pushes the latest YASM signal values into one
    :class:`SignalHistory` adviser per signal. A worker is removed when at least one
    adviser reports overload; a worker is added when every adviser reports underload.
    """

    def __init__(
            self,
            yasm_adaptor: YasmAdaptor,
            task_dispatcher: TaskDispatcher,
            settings: LoadRegulatorSettings,
            shard_id: int
    ):
        """
        :param yasm_adaptor: provides the current signals via ``get_signals()``.
        :param task_dispatcher: receives ``add_worker()`` / ``remove_worker()`` calls.
        :param settings: only ``settings.cron`` is used here, to configure the cron job.
        :param shard_id: shard identifier; used in log messages.
        """
        # One adviser per YASM signal. The two positional arguments are the thresholds
        # handed to SignalHistory (presumably lower/upper bounds — confirm in SignalHistory).
        self.suspend_data = YasmSignals(
            io_usage_primary=SignalHistory(0.5, 0.8),
            io_usage_replica=SignalHistory(0.4, 0.7),
            net_usage_primary=SignalHistory(0.5, 0.8),
            net_usage_replica=SignalHistory(0.4, 0.7),
            avg_query_primary=SignalHistory(timedelta(milliseconds=50), timedelta(milliseconds=150)),
            # avg_query on replicas is being summed because of the sigopt suffix,
            # which is why its thresholds are higher than the primary's.
            avg_query_replica=SignalHistory(timedelta(milliseconds=100), timedelta(milliseconds=300)),
            replication_lag=SignalHistory(timedelta(seconds=5), timedelta(seconds=50)),
        )

        self.yasm_adaptor = yasm_adaptor
        self.task_dispatcher = task_dispatcher
        self.shard_id = shard_id
        super().__init__(job=self.regulate, **settings.cron.as_dict())

    def _advisers(self) -> dict:
        """Return ``{field_name: adviser}`` for every field of ``self.suspend_data``.

        Unlike ``dataclasses.asdict`` this yields the live adviser objects rather than
        deep copies, so the signal histories are not copied on every tick.
        """
        return {
            field.name: getattr(self.suspend_data, field.name)
            for field in fields(self.suspend_data)
        }

    async def regulate(self):
        """Run one regulation step.

        Fetches fresh signals, appends each value to its adviser's history, then
        removes a worker if any adviser is overloaded and adds a worker if all
        advisers are underloaded.
        """
        signals = await self.yasm_adaptor.get_signals()
        YasmSignals.map(
            lambda adviser, signal: adviser.update_history(signal),
            self.suspend_data,
            signals
        )
        advisers = self._advisers()

        overloaded_signals = {
            signal: adviser for signal, adviser in advisers.items() if adviser.is_overloaded()
        }
        # Check the dict itself: the truthiness of adviser objects is irrelevant here.
        if overloaded_signals:
            log.warning(
                'Shard_id %s is overloaded by following signals: %s',
                self.shard_id,
                [f'{signal}={adviser.window[-1]}' for signal, adviser in overloaded_signals.items()],
            )
            self.task_dispatcher.remove_worker()

        # Scale up only when every single signal reports spare capacity.
        if all(adviser.is_underloaded() for adviser in advisers.values()):
            log.info('Shard_id %s is underloaded', self.shard_id)
            self.task_dispatcher.add_worker()
