# -*- coding: utf-8 -*-

import time

from sandbox import sdk2

import sandbox.common.types.misc as ctm


class ExecutionTimeTracker(sdk2.Task):
    """
    Базовый класс для sdk2 задач, считающий время выполнения задачи и подзадач.
    Интервалы времен выполнения задач с одинаковым stage_name объединяются.
    Если включить track_execution_time, будет считаться только время выполнения данной задачи,
    иначе считается время от создания задачи до ее успешного завершения.
    """

    class Context(sdk2.Task.Context):
        execution_time_intervals = {}
        stopwatch = {}

    @property
    def stage_name(self):
        raise NotImplementedError

    @property
    def track_execution_time(self):
        return False

    def stage_time(self, stage):
        return sum(map(
            lambda interval: interval[1] - interval[0],
            self.Context.execution_time_intervals[stage]
        ))

    def stages(self):
        return list(self.Context.execution_time_intervals.keys())

    def on_enqueue(self):
        if self.track_execution_time or self.stage_name not in self.Context.stopwatch:
            self._reset_stopwatch(self.stage_name)

    def on_wait(self, prev_status, status):
        if self.track_execution_time:
            self._add_time(self.stage_name)

    def on_success(self, prev_status):
        self._on_finish()

    def _on_finish(self):
        self._add_time(self.stage_name)

        sub_tasks = self.find()
        for sub_task in sub_tasks:
            sub_task_intervals = sub_task.Context.execution_time_intervals
            if sub_task_intervals is not ctm.NotExists:
                self._add_intervals(sub_task_intervals)

    def _reset_stopwatch(self, stage):
        self.Context.stopwatch[stage] = time.time()

    def _add_time(self, stage):
        self._add_interval(stage, (self.Context.stopwatch[stage], time.time()))
        self._reset_stopwatch(stage)

    def _add_interval(self, stage, interval):
        self._add_execution_intervals(stage, [interval])

    def _add_intervals(self, intervals):
        for stage, interval in intervals.items():
            self._add_execution_intervals(stage, interval)

    def _add_execution_intervals(self, stage, intervals_to_add):
        def add(new_intervals, i, j, intervals_lhs, intervals_rhs):
            while i < len(intervals_lhs) and (j == len(intervals_rhs) or intervals_lhs[i][0] <= intervals_rhs[j][0]):
                interval = intervals_lhs[i]
                while j < len(intervals_rhs) and interval[1] >= intervals_rhs[j][0]:
                    interval[1] = max(interval[1], intervals_rhs[j][1])
                    j += 1
                new_intervals.append(interval)
                i += 1
            return i, j

        if stage not in self.Context.execution_time_intervals:
            self.Context.execution_time_intervals[stage] = []

        intervals = self.Context.execution_time_intervals[stage]

        new_intervals = []
        i, j = 0, 0
        while i < len(intervals) or j < len(intervals_to_add):
            i, j = add(new_intervals, i, j, intervals, intervals_to_add)
            j, i = add(new_intervals, j, i, intervals_to_add, intervals)

        self.Context.execution_time_intervals[stage] = new_intervals
