import asyncio
import logging
import time
from collections import defaultdict
from itertools import chain

from typing import Callable, Any, Coroutine, Dict, Union
from dataclasses import dataclass

from mail.python.theatre.roles.base import Servant
from mail.python.theatre.logging.request_id import set_request_id
from mail.shiva.stages.api.settings.shard_worker import ShardWorkerSettings
from mail.shiva.stages.api.props.shard.task import finish_task, HuskydbEngine
from mail.python.theatre.profiling.hist import Hist
from mail.python.theatre.profiling.typing import Metrics, HistMetric

log = logging.getLogger(__name__)


@dataclass
class Task:
    """A unit of work for ShardWorker: an async ``job`` plus the ``params`` it runs with.

    ShardWorker awaits ``job(params, stats)``, where ``stats`` is the TaskStats
    instance for ``params.task_name``; it also reads ``params.shard_id``.
    """
    # Invoked with two arguments (params, stats); the previous
    # Callable[[Any], Coroutine] hint understated the arity and omitted None.
    job: Union[Callable[[Any, 'TaskStats'], Coroutine], None] = None
    # Opaque task parameters; must at least expose task_name and shard_id.
    params: Any = None


class PoisonPill:
    """Shutdown sentinel: the class object itself is enqueued and compared by identity."""


class TaskStats:
    """Duration histograms and counters for a single task type.

    ``get()`` flattens everything into one metrics list, in this order:
    the success/failure duration histograms, the plain counters, and then any
    ad-hoc histograms registered through ``put_in_hist``. Every exported
    histogram gets an extra trailing ``(terminating_bound, 0)`` bucket.
    """

    def __init__(self, name: str):
        minute, hour = 60, 60 * 60
        self._buckets = [0, 1, 10, minute, 5 * minute, 10 * minute, 20 * minute, 30 * minute,
                         hour, 2 * hour, 10 * hour]
        # Sentinel upper bound appended on export: twice the largest bucket.
        self._terminating_bucket_bound = max(self._buckets) * 2
        self._name = name
        self._success_hist_s = Hist(buckets=self._buckets, name=f'{name}_success_s', suffix='ahhh')
        self._fail_hist_s = Hist(buckets=self._buckets, name=f'{name}_fail_s', suffix='ahhh')
        # Counter name -> accumulated value; missing counters start at 0.
        self._plain_meters = defaultdict(int)
        # Histogram name -> Hist, created lazily by put_in_hist().
        self._hist_meters = {}

    def ok(self, sec: float):
        """Record the duration in seconds of a successfully finished task."""
        self._success_hist_s.update(sec)

    def failed(self, sec: float):
        """Record the duration in seconds of a failed task."""
        self._fail_hist_s.update(sec)

    def _task_meter_key(self, name: str) -> str:
        """Build the per-task counter key ``<task name>_<name>_ammm``."""
        return f'{self._name}_{name}_ammm'

    def increase_task_meter(self, name: str, value: Union[float, int] = 1):
        """Add ``value`` to the counter ``name``, namespaced by this task's name."""
        self._plain_meters[self._task_meter_key(name)] += value

    def increase_global_meter(self, name: str, value: Union[float, int] = 1):
        """Add ``value`` to the counter ``name``, used verbatim without a task prefix."""
        self._plain_meters[name] += value

    def decrease_task_meter(self, name: str, value: Union[float, int] = 1):
        """Subtract ``value`` from the task-namespaced counter ``name``."""
        self._plain_meters[self._task_meter_key(name)] -= value

    def put_in_hist(self, name: str, value: Union[float, int] = 1):
        """Feed ``value`` into the histogram ``name``, creating it with the shared buckets on first use."""
        hist = self._hist_meters.get(name)
        if hist is None:
            hist = self._hist_meters[name] = Hist(buckets=self._buckets, name=name)
        hist.update(value)

    def get(self) -> Metrics:
        """Return all metrics of this task type as one flat list."""
        collected = [
            self._add_terminating_bucket(hist.get())
            for hist in (self._success_hist_s, self._fail_hist_s)
            if hist
        ]
        collected.extend(self._plain_meters.items())
        collected.extend(self._add_terminating_bucket(hist.get()) for hist in self._hist_meters.values())
        return collected

    def _add_terminating_bucket(self, hist_metric: HistMetric) -> HistMetric:
        """Append a zero-count ``(terminating_bound, 0)`` bucket to ``hist_metric`` and return it."""
        # NOTE(review): appends in place to hist_metric[1]. If Hist.get() ever returns its
        # internal bucket list, repeated get() calls would accumulate extra buckets —
        # confirm Hist.get() builds a fresh list.
        bucket_list = hist_metric[1]
        bucket_list.append((self._terminating_bucket_bound, 0))
        return hist_metric


class ShardWorker(Servant):
    """Queue-backed worker that runs shard tasks and records per-task metrics.

    Tasks are submitted with ``put_nowait`` and consumed by ``_job``, which is
    registered as the Servant routine (presumably run by ``settings.servant_count``
    concurrent coroutines — see ``Servant``). Each processed task is reported via
    ``finish_task``: plainly on success, or with status ``'error'`` and the
    exception text on failure.
    """

    def __init__(self, settings: ShardWorkerSettings, huskydb: HuskydbEngine, worker_name: str):
        super().__init__(
            routine=self._job,
            servant_count=settings.servant_count
        )
        self._queue = asyncio.Queue(settings.maxsize)
        self.worker_name = worker_name
        self._huskydb = huskydb
        # Per task-name statistics, created lazily on first submission in put_nowait().
        self._stats: Dict[str, TaskStats] = {}

    async def stop(self):
        """Enqueue one PoisonPill per servant so each consumer loop exits, then stop the base Servant."""
        for _ in range(self._servant_count):
            await self._queue.put(PoisonPill)
        await super().stop()

    async def _process_task(self, item: Task):
        """Run one task, record its duration and report the outcome via ``finish_task``.

        On success the task is finished normally and its duration goes to the success
        histogram. On any ``Exception`` the duration goes to the failure histogram and
        the task is finished with status ``'error'`` and ``str(exc)``. Exceptions raised
        by that error-path ``finish_task`` call propagate to the caller.
        """
        task_name = item.params.task_name
        # NOTE(review): a falsy shard_id (e.g. 0) yields no shard in the request id —
        # confirm 0 is never a valid shard.
        shard = f'{task_name}_{item.params.shard_id}' if item.params.shard_id else None
        with set_request_id(prefix=task_name, value=shard):
            stats = self._stats[task_name]
            # Start the clock before ``try`` so ``start`` is always bound in the except branch.
            start = time.monotonic()
            try:
                log.info('ShardWorker start <%s> task', task_name)
                await item.job(item.params, stats)

                log.info('ShardWorker finish <%s> task', task_name)
                await finish_task(self._huskydb, item.params)
                stats.ok(time.monotonic() - start)
            except Exception as e:
                # One record carrying both the message and the traceback.
                log.exception('ShardWorker failed with <%s> task', task_name)
                stats.failed(time.monotonic() - start)
                await finish_task(self._huskydb, item.params, 'error', str(e))

    async def _job(self):
        """Consume tasks until a PoisonPill arrives, or until stopped with an empty queue."""
        while not self._stopped or self._queue.qsize():
            item = await self._queue.get()
            if item is PoisonPill:
                self._queue.task_done()
                return
            try:
                await self._process_task(item)
            finally:
                # Always balance the queue and the 'enqueued' meter, even if
                # _process_task raised (e.g. the error-path finish_task failed).
                # Otherwise a Queue.join() waiter would hang and the meter would drift up.
                self._queue.task_done()
                self._stats[item.params.task_name].decrease_task_meter('enqueued')
            # Yield to the event loop between tasks.
            await asyncio.sleep(0)

    def put_nowait(self, item: Task):
        """Enqueue ``item`` without waiting and count it in the task's ``'enqueued'`` meter.

        Raises:
            asyncio.QueueFull: if the queue already holds ``settings.maxsize`` items;
                the meter is not touched in that case.
        """
        task_name = item.params.task_name
        if task_name not in self._stats:
            self._stats[task_name] = TaskStats(task_name)
        # Enqueue first: if put_nowait raises QueueFull, the item never reaches _job,
        # so incrementing beforehand would leave the meter permanently inflated.
        # There is no await between these two lines, so _job cannot interleave.
        self._queue.put_nowait(item)
        self._stats[task_name].increase_task_meter('enqueued')

    def metrics(self) -> Metrics:
        """Return the concatenated metrics of every task type submitted so far."""
        return list(chain.from_iterable(m.get() for m in self._stats.values()))
