import sys
import logging
import threading
import datetime as dt
import collections as cs

from sandbox import common
import sandbox.common.types.task as ctt
import sandbox.common.types.database as ctd
import sandbox.common.types.statistics as ctss

from sandbox.services import base
from sandbox.serviceq import client as qclient
from sandbox.yasandbox.database import mapping

# FIXME: remove this when ServiceQ is deployed as a binary.
# The ServiceQ server raises exceptions without the "sandbox" namespace (i.e. from "serviceq.errors"),
# while in binary servants those exceptions are only available via "sandbox.serviceq.errors".
# As a result, common.joint is unable to reconstruct the exceptions and throws generic ServerErrors,
# which go uncaught even though they could have become QRedirects instead.
# Aliasing the module makes "serviceq.errors" resolve to the namespaced module.
# NOTE(review): this lookup raises KeyError unless "sandbox.serviceq.errors" has already been imported
# (presumably as a side effect of the qclient import above) — confirm if imports are reordered.
sys.modules["serviceq.errors"] = sys.modules["sandbox.serviceq.errors"]

logger = logging.getLogger(__name__)


class DurationCounter(object):
    """
    Running tally of sample count and summed duration.

    A second pair of totals (``wait_mutex_amount`` / ``wait_mutex_duration``) accumulates
    only the samples that were flagged as waiting for a mutex.
    """

    def __init__(self):
        self.amount = self.duration = 0
        self.wait_mutex_amount = self.wait_mutex_duration = 0

    def add(self, duration, wait_mutex=False):
        """
        Record a single sample.

        :param duration: value added to the overall duration total
        :param wait_mutex: when true, the sample is also counted in the ``wait_mutex_*`` totals
        """
        self.amount, self.duration = self.amount + 1, self.duration + duration
        if not wait_mutex:
            return
        self.wait_mutex_amount, self.wait_mutex_duration = (
            self.wait_mutex_amount + 1,
            self.wait_mutex_duration + duration,
        )


class TasksStatistics(base.SingletonService):
    """
    Service process for pushing tasks statistics to MongoDB (SANDBOX-4756):

    - ENQUEUED statuses (queue statistics from ServiceQ)
    - *ING + WAIT_* statuses (directly from MongoDB secondaries)
    - remaining quota median for owners of last 100 tasks per purpose client tag

    Each collector listed in ``worker_settings`` runs in its own thread. ``tick`` starts the threads
    on its first call and restarts any thread it finds dead afterwards. ``on_stop`` signals every
    thread to stop and joins them.
    """

    tick_interval = 10

    # A background collector: the thread running it and the event used to ask it to stop.
    Worker = cs.namedtuple("Worker", "thread event")

    def __init__(self, *a, **kw):
        super(TasksStatistics, self).__init__(*a, **kw)
        # collector function name -> Worker
        self.workers = {}

    @property
    def worker_settings(self):
        """
        Mapping of collector name to ``(bound collector method, pause between runs in seconds)``.
        """
        return {
            func.__name__: (func, timeout)
            for func, timeout in (
                (self.fetch_queue_statistics, 60),
                (self.fetch_executing_and_wait, 20),
                (self.fetch_quota_median, 20),
            )
        }

    def fetch_queue_statistics(self, client):
        """
        Push ``QUEUE_STATISTICS`` signals built from the ServiceQ queue.

        Queued tasks are grouped by ``(owner, task type, priority)``. For every group, the signal
        carries the number of tasks and their summed time in queue, in seconds since
        ``task_info.enqueue_time``. It also carries the same two figures restricted to the tasks
        that ServiceQ reports as semaphore waiters.

        :param client: ServiceQ client
        """
        queue = client.queue(secondary=True)
        semaphore_waiters = set(client.semaphore_waiters([item.task_id for item in queue]))
        data = cs.defaultdict(DurationCounter)
        utcnow = dt.datetime.utcnow()

        for item in queue:
            data[(item.task_info.owner, item.task_info.type, item.priority)].add(
                (utcnow - dt.datetime.utcfromtimestamp(item.task_info.enqueue_time)).total_seconds(),
                wait_mutex=item.task_id in semaphore_waiters
            )
        signals = [
            dict(
                type=ctss.SignalType.QUEUE_STATISTICS,
                date=utcnow,
                timestamp=utcnow,
                owner=owner,
                task_type=task_type,
                priority=priority,
                amount=counter.amount,
                age=counter.duration,
                wait_mutex_amount=counter.wait_mutex_amount,
                wait_mutex_age=counter.wait_mutex_duration,
            )
            for (owner, task_type, priority), counter in data.iteritems()
        ]
        logger.debug("Queue statistics: %d tasks enqueued, %d unique task groups", len(queue), len(data))
        self.signaler.push(signals)

    def fetch_executing_and_wait(self, _):
        """
        Push ``EXECUTION_STATISTICS`` signals counted directly from MongoDB.

        Tasks in the EXECUTE and WAIT status groups, plus ENQUEUING and RELEASING, are read from a
        secondary and counted per ``(owner, type, priority, status)``. When there is anything to
        report, a per-status summary table is also logged.

        :param _: ServiceQ client; unused, accepted so all collectors share one signature
        """
        statuses = (
            tuple(ctt.Status.Group.EXECUTE) +
            (ctt.Status.ENQUEUING, ctt.Status.RELEASING) +
            tuple(ctt.Status.Group.WAIT)
        )
        with mapping.switch_db(mapping.Task, ctd.ReadPreference.SECONDARY) as Task:
            tasks = list(
                Task.objects(execution__status__in=statuses).only(
                    "owner", "type", "priority", "execution__status"
                )
            )

        utcnow = dt.datetime.utcnow()
        data = cs.Counter()
        counters = cs.Counter()
        for task in tasks:
            data[(task.owner, task.type, task.priority, task.execution.status)] += 1
            counters[task.execution.status] += 1

        signals = [
            dict(
                type=ctss.SignalType.EXECUTION_STATISTICS,
                date=utcnow,
                timestamp=utcnow,
                owner=owner,
                task_type=task_type,
                priority=priority,
                amount=amount,
                status=status
            )
            for (owner, task_type, priority, status), amount in data.iteritems()
        ]
        self.signaler.push(signals)

        if signals:
            table = [["Status", "Tasks count"], None]
            table.extend(
                [status, counters[status]]
                for status in sorted(counters.keys())
            )
            logger.debug("Execution statistics:")
            common.utils.print_table(table, printer=logger.debug)

    def fetch_quota_median(self, client):
        """
        Push ``QUOTA_MEDIAN`` signals, one per purpose tag reported by ServiceQ.

        For each tag, ``values`` holds ``[quota consumption, quota]`` pairs, one per owner.
        The signal carries:

        - the median remaining quota (``quota - consumption``);
        - the median consumption ratio in per mille, skipping owners whose quota is not positive.

        An empty list yields a median of ``0.0``.

        :param client: ServiceQ client
        """
        def median(list_):
            # sorts the (freshly built) list in place; for an empty list the slice is empty -> 0.0
            list_.sort()
            q, remainder = divmod(len(list_), 2)
            return list_[q] if remainder else sum(list_[q - 1: q + 1]) / 2.0

        data = client.last_quota_remnants()
        utcnow = dt.datetime.utcnow()
        signals = [
            dict(
                type=ctss.SignalType.QUOTA_MEDIAN,
                date=utcnow,
                timestamp=utcnow,
                purpose_tag=tag,
                median=median([v[1] - v[0] for v in values]),
                # NOTE(review): with integer inputs, Python 2 "/" truncates here — confirm that is intended
                ratio_median=median([v[0] * 1000 / v[1] for v in values if v[1] > 0]),
            )
            # values is a list of [quota consumption, quota] pairs, one per owner
            for tag, values in data.iteritems()
        ]
        logger.debug("Quota medians: collected for %d tags", len(signals))
        self.signaler.push(signals)

    def __make_worker(self, func, timeout):
        """
        Build an unstarted worker thread that calls ``func(client)`` every ``timeout`` seconds.

        The thread sleeps first and then collects. It exits as soon as the returned worker's
        event is set. Exceptions raised by ``func`` are logged and do not stop the loop.
        """
        stop_event = threading.Event()

        def wrapper():
            client = qclient.Client()
            while True:
                stop_event.wait(timeout)
                # wait() returns early when the stop event is set: leave instead of doing one more
                # (potentially slow) fetch that would only delay on_stop()
                if stop_event.is_set():
                    break
                try:
                    func(client)
                except Exception:
                    logger.exception("%s: caught exception", func.__name__)

        th = threading.Thread(
            target=wrapper,
            name=func.__name__,
        )
        return self.Worker(thread=th, event=stop_event)

    def __spawn_worker(self, func_name, func, timeout):
        """Create a worker for ``func``, register it under ``func_name`` and start its thread."""
        worker = self.__make_worker(func, timeout)
        self.workers[func_name] = worker
        worker.thread.start()
        return worker

    def tick(self):
        """
        Start all collector threads on the first call; afterwards restart any thread that has died.
        """
        settings = self.worker_settings
        if not self.workers:
            for func_name, (func, timeout) in settings.items():
                self.__spawn_worker(func_name, func, timeout)

        for func_name, worker in list(self.workers.items()):
            if worker.thread.is_alive():
                continue
            logger.warning("%s: respawning the dead thread", func_name)
            # The replacement must be started as well; an unstarted thread is never alive,
            # so it would be reported as dead and replaced again on every tick.
            func, timeout = settings[func_name]
            self.__spawn_worker(func_name, func, timeout)

    def on_stop(self):
        """Signal every worker to stop first, then wait for each thread to finish."""
        # Set all events before joining, so the threads wind down in parallel.
        for worker in self.workers.itervalues():
            worker.event.set()
        for key, worker in self.workers.iteritems():
            logger.info("%s: waiting for the stop", key)
            worker.thread.join()
