import abc
import logging
import datetime as dt
import requests
import collections

from sandbox.common import context
from sandbox.common.types import task as ctt
from sandbox.common.types import client as ctc

from sandbox.yasandbox.database import mapping

from sandbox.services.base import service as base_service

from . import solomon


# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
# Value of the common "metric" label that the reporters in this module attach to their Solomon sensors.
MAIN_METRIC_NAME = "kpi_task_execution"


class Reporter(object):
    """
    Base class for KPI reporters.

    Every reporter keeps its own state in a sub-dict of the shared context dict,
    stored under the reporter's class name. Subclasses implement `do_report`.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, context):
        # NOTE: inside this method the `context` argument (a dict) shadows the
        # module-level `context` import; `report` below still uses the module.
        own_state = context.setdefault(self.reporter_name, {})
        self.context = own_state

    @property
    def reporter_name(self):
        """Key under which this reporter's state is stored: the concrete class name."""
        return type(self).__name__

    def report(self):
        """Run `do_report` inside a `context.Timer` and log the timer afterwards."""
        name = self.reporter_name
        logger.debug("Reporting %s", name)
        with context.Timer() as timer:
            self.do_report(timer)
        logger.debug("%s reported %s", name, timer)

    @abc.abstractmethod
    def do_report(self, timer):
        """
        Collect and push the metrics; must be overridden.

        :param timer: the `context.Timer` opened by `report`; subclasses in this
            module index it (`timer[key]`) to time individual steps.
        """


class GeneralTaskExecution(Reporter):
    """
    Calculates product metrics for tasks run by certain schedulers.
    Those tasks have to create subtasks and wait for them.

    For every successful task of a scheduler from `SCHEDULERS` the following sensors
    are pushed to Solomon:

    * ``duration`` -- time between an audit record whose status is not in
      ``ctt.Status.Group.FINISH`` and the next audit record, labelled with the status
      and an iteration number that starts at 1 and grows after every ``WAIT_TASK``;
    * ``wakeup_duration`` -- for a ``WAIT_TASK`` record, time between the child's last
      audit record and the parent's next audit record.
    """

    # str(ctc.Tag) value (used as the "tag" label and as the context key) -> scheduler id
    SCHEDULERS = {
        str(ctc.Tag.MULTISLOT): 9900,
        str(ctc.Tag.GENERIC): 9899,
    }

    def _analyze_task(self, task_audit, child_audit, solomon_client):
        """
        Add per-status duration sensors of one scheduled task to `solomon_client`.

        :param task_audit: audit records of the scheduled task, ordered by date
        :param child_audit: audit records of its child task, ordered by date (may be empty)
        :param solomon_client: client the sensors are added to
        """
        iteration = 1
        # A status lasts until the next audit record. The final record has no
        # successor, so it is never measured (and can no longer cause an IndexError).
        for audit, next_audit in zip(task_audit, task_audit[1:]):
            if audit.status in ctt.Status.Group.FINISH:
                continue

            duration = (next_audit.date - audit.date).total_seconds()
            solomon_client.add_sensor(
                "duration", duration,
                labels={"status": audit.status, "iteration": iteration},
                ts=audit.date,
            )

            if audit.status == ctt.Status.WAIT_TASK:
                iteration += 1

                if not child_audit:
                    # The child has no audit records to measure the wake-up delay against.
                    continue
                # NOTE(review): every WAIT_TASK period is compared with the same (last)
                # child audit record, so this is only meaningful for a single wait — confirm.
                wakeup_duration = (next_audit.date - child_audit[-1].date).total_seconds()
                solomon_client.add_sensor(
                    "wakeup_duration", wakeup_duration,
                    ts=audit.date,
                )

    def _analyze_scheduler(self, tag, scheduler_id, solomon_client):
        """
        Analyze successful tasks of scheduler `scheduler_id` not reported yet.

        Only tasks created during the last 6 hours with an id greater than the one
        stored in ``self.context[tag]`` are considered. Afterwards the greatest id
        seen is stored back, so the same tasks are not reported twice.
        """
        logger.debug("Analyzing scheduler #%s", scheduler_id)
        last_task_id = self.context.get(tag, 0)

        scheduled_tasks = list(mapping.Task.objects(
            scheduler=scheduler_id,
            id__gt=last_task_id,
            execution__status=ctt.Status.SUCCESS,
            time__created__gt=dt.datetime.utcnow() - dt.timedelta(hours=6),
        ).scalar("id"))

        logger.debug("%s tasks found for scheduler #%s", len(scheduled_tasks), scheduler_id)
        if not scheduled_tasks:
            return

        # parent id -> child id; when a parent has several children, dict() keeps the last pair returned
        child_of = dict(mapping.Task.objects(parent_id__in=scheduled_tasks).scalar("parent_id", "id"))
        no_child = [task_id for task_id in scheduled_tasks if task_id not in child_of]
        if no_child:
            logger.warning("No child found for parent tasks %s", no_child)

        # list() keeps the concatenation valid where dict.values() is a view (Python 3)
        audited_ids = list(child_of.values()) + scheduled_tasks
        task_audit = collections.defaultdict(list)
        for audit in mapping.Audit.objects(task_id__in=audited_ids).order_by("date"):
            task_audit[audit.task_id].append(audit)

        for task_id in scheduled_tasks:
            if task_id in child_of:
                self._analyze_task(task_audit[task_id], task_audit[child_of[task_id]], solomon_client)

        self.context[tag] = max(scheduled_tasks)

    def do_report(self, timer):
        """Analyze every scheduler from `SCHEDULERS`, pushing each one's sensors with its own Solomon client."""
        for tag, scheduler_id in self.SCHEDULERS.items():
            solomon_client = solomon.SolomonClient(common_labels={"tag": tag, "metric": MAIN_METRIC_NAME})
            with timer[tag]:
                self._analyze_scheduler(tag, scheduler_id, solomon_client)

                try:
                    solomon_client.send_data()
                except requests.HTTPError as error:
                    logger.exception("Error while pushing to Solomon: %s", error)


class SandboxBundleBuild(Reporter):
    """
    Reports execution duration of Sandbox bundle build/test tasks.

    A task belongs to a test when it carries all tags of the corresponding `TEST_TAGS`
    entry. Its duration is the time between its first and last audit records.
    """

    # test name (used as the "test_name" label and as the context key) -> tags a task must have all of
    TEST_TAGS = {
        "TEST_SANDBOX": ("SANDBOX::TEST-BUILD_SANDBOX", "JOB-ID:TEST-SANDBOX", "CI:COMMIT"),
        "BUILD_SANDBOX_TASKS": ("BUILD_SANDBOX_TASKS:BINARY", "CI:COMMIT"),
        "BUILD_SANDBOX_TASKS_RAW": ("BUILD_SANDBOX_TASKS:RAW", "CI:COMMIT"),
    }
    # Execution statuses of the tasks this report takes into account.
    FINISH_STATUSES = [ctt.Status.SUCCESS, ctt.Status.FAILURE] + list(ctt.Status.Group.BREAK)

    def _analyze_test(self, test_name, solomon_client):
        """
        Add ``execution_duration`` sensors for finished tasks of `test_name` not reported yet.

        Only tasks created during the last 6 hours with an id greater than the one
        stored in ``self.context[test_name]`` are considered. Afterwards the greatest
        id seen is stored back.
        """
        logger.debug("Analyzing test %s", test_name)
        last_task_id = self.context.get(test_name, 0)

        task_ids = list(mapping.Task.objects(
            id__gt=last_task_id,
            execution__status__in=self.FINISH_STATUSES,
            tags__all=self.TEST_TAGS[test_name],
            time__created__gt=dt.datetime.utcnow() - dt.timedelta(hours=6),
        ).scalar("id"))

        logger.debug("%s tasks found for test %s", len(task_ids), test_name)
        if not task_ids:
            return

        task_audit_times = collections.defaultdict(list)
        for task_id, audit_time in mapping.Audit.objects(task_id__in=task_ids).fast_scalar("task_id", "date"):
            task_audit_times[task_id].append(audit_time)

        for task_id in task_ids:
            audit_times = task_audit_times.get(task_id)
            if not audit_times:
                # min()/max() of an empty sequence would raise ValueError and abort the report
                logger.warning("No audit records found for task #%s", task_id)
                continue
            start, finish = min(audit_times), max(audit_times)
            solomon_client.add_sensor(
                "execution_duration", (finish - start).total_seconds(),
                labels={"test_name": test_name},
                ts=start,
            )

        self.context[test_name] = max(task_ids)

    def do_report(self, timer):
        """Analyze all tests from `TEST_TAGS` and push the collected sensors in one batch."""
        solomon_client = solomon.SolomonClient(common_labels={"metric": MAIN_METRIC_NAME})
        for test_name in self.TEST_TAGS:
            with timer[test_name]:
                self._analyze_test(test_name, solomon_client)

        # Same handling as in GeneralTaskExecution.do_report: log push failures instead of raising.
        try:
            solomon_client.send_data()
        except requests.HTTPError as error:
            logger.exception("Error while pushing to Solomon: %s", error)


class MetricsReporter(base_service.ThreadedService):
    """
    Service running the KPI reporters of this module.

    `tick_interval` presumably sets how often `ThreadedService` runs the targets
    (seconds?) — confirm against the base class.
    """

    tick_interval = 60

    @property
    def targets(self):
        """
        Build one target per reporter, wrapping its `report` method.

        Reporters are recreated on every access. Their state lives in `self.context`
        (they `setdefault` a per-class sub-dict in it), so it is kept between accesses.
        """
        reporter_classes = (GeneralTaskExecution, SandboxBundleBuild)
        return [self.Target(reporter_cls(self.context).report) for reporter_cls in reporter_classes]
