#!/usr/bin/python

import logging
import os
from sandbox.common import rest
from sandbox.common.proxy import OAuth

from collections import defaultdict

from sandbox.projects.yabs.release.performance import testenv_data


def get_big_sandbox_api_client():
    """Create a Sandbox REST client authenticated with the token from ``~/.sandbox_token``.

    If the token file cannot be read (any exception), an error is logged and a
    ``rest.Client()`` with default arguments is returned instead.
    """
    # FIXME First check if this is local SB
    try:
        token_path = os.path.expanduser("~/.sandbox_token")
        with open(token_path) as token_file:
            token = token_file.read().strip()
    except Exception:
        logging.error("Failed to read ~/.sandbox_token, is this really local Sandbox?")
        return rest.Client()
    # An explicit (non-None) base_url is passed on purpose: the local Sandbox
    # must NOT be used here. NOTE(review): relies on how rest.Client treats ''.
    return rest.Client(base_url='', auth=OAuth(token))


def get_release_revisions():
    """Return ``{base_revision: branch}`` for recent stable BS_RELEASE_TAR resources.

    Queries the last 100 resources (newest id first) whose ``released``
    attribute is ``stable``. Resources without a ``base_revision`` attribute
    are skipped. If several resources share a base revision, the one listed
    later in the response wins.
    """
    client = get_big_sandbox_api_client()
    response = client.resource.read(
        type="BS_RELEASE_TAR",
        attr_name="released",
        attr_value="stable",
        order="-id",
        limit=100,
    )
    logging.info(response)
    revision_branch_pairs = [
        (int(item["attributes"]["base_revision"]), item["attributes"]["branch"])
        for item in response["items"]
        if "base_revision" in item["attributes"]
    ]
    return dict(revision_branch_pairs)


class PerformancePlotter(object):
    """Build per-subplot performance series from TestEnv results.

    For every item of ``plot_params`` the plotter:
      1. fetches test results via ``testenv_data.get_test_results``;
      2. extracts a value from each Sandbox task context with the subplot's
         ``value_extractor``;
      3. rescales each test's series across resource switches;
      4. forward-fills the series over all known revisions (including the base
         revisions of stable releases);
      5. normalizes each test by its value at the base revision and combines
         the tests with the subplot's ``aggregator``.

    Each ``plot_params`` item is read for the attributes ``tests``, ``metric``,
    ``value_extractor``, ``use_base_revision``, ``aggregator`` and
    ``get_switch_points``.
    """

    def __init__(self, plot_params):
        self.sandbox = get_big_sandbox_api_client()
        self.plot_params = plot_params
        # Revision that plots are normalized against and start from; set by
        # __download_base_revision when a subplot requests use_base_revision.
        self.plot_base_revision = None
        # task_id -> task context, so each context is fetched at most once.
        self._context_cache = {}
        # Result of get_release_revisions(), fetched lazily once per plotter.
        self._release_revisions = None

    def _get_release_revisions(self):
        """Return ``{base_revision: branch}`` of stable releases, fetched once per plotter.

        A fresh copy is returned each time so callers may mutate it freely.
        """
        if self._release_revisions is None:
            self._release_revisions = get_release_revisions()
        return dict(self._release_revisions)

    def __download_base_revision(self, min_revision):
        """Set ``plot_base_revision`` from the stable-release list.

        Walks the last 100 stable BS_RELEASE_TAR resources (newest id first)
        and keeps the last base revision seen before hitting one that is below
        ``min_revision``. Falls back to ``min_revision`` itself when no release
        qualifies.
        """
        response = self.sandbox.resource.read(
            type="BS_RELEASE_TAR",
            attr_name="released",
            attr_value="stable",
            order="-id",
            limit=100
        )
        logging.info(response)
        base_revision = min_revision
        for resource in response["items"]:
            attributes = resource["attributes"]
            if "base_revision" not in attributes:
                continue
            release_revision = int(attributes["base_revision"])
            if release_revision < min_revision:
                # Releases further down the list are assumed to be older still.
                break
            base_revision = release_revision
        self.plot_base_revision = base_revision
        logging.info("Find base revision %s", self.plot_base_revision)

    def __get_base_revision(self, min_revision):
        self.__download_base_revision(min_revision)

    def _get_task_context(self, task_id):
        """Return the Sandbox context of ``task_id``, memoized per plotter."""
        try:
            return self._context_cache[task_id]
        except KeyError:
            ctx = self.sandbox.task[task_id].context.read()
            self._context_cache[task_id] = ctx
            return ctx

    def __get_values_and_resource(self, points, value_extractor):
        """Delegate to ``testenv_data.get_values_and_resources`` with a per-task value getter.

        The getter returns ``value_extractor(task context)``. It returns None
        (after logging a warning) when the context cannot be fetched or the
        extractor raises.
        """
        def extract_task_value(task_id):
            try:
                ctx = self._get_task_context(task_id)
            except Exception:
                logging.warning("Failed to get context of task %s", task_id, exc_info=True)
            else:
                try:
                    return value_extractor(ctx)
                except Exception:
                    logging.warning("Failed to extract data from task %s, context:\n%s", task_id, ctx, exc_info=True)
            return None

        return testenv_data.get_values_and_resources(points, extract_task_value)

    def __extract_corrected_points(self, data):
        """Rescale a test's series so it stays continuous across resource switches.

        ``data`` maps revision -> ``(new_res, value, old_value)``. Whenever
        ``old_value`` is not None, the running coefficient is multiplied by
        ``old_value / value``. Every later point is scaled by that coefficient.

        Returns ``{revision: value * coefficient}`` (floats).
        """
        coeff = 1.0
        corrected_points = {}
        for revision in sorted(data):
            _new_res, value, old_value = data[revision]
            if old_value is not None:
                # previous point was a resource switch; float() avoids integer
                # (floor) division under Python 2 when values are ints.
                coeff *= float(old_value) / value
                logging.info("Coeff changed to %s at revision %s", coeff, revision)
            corrected_points[revision] = value * coeff
        return corrected_points

    def __approximate_corrected_points(self, data_dict, all_revisions):
        """Forward-fill ``data_dict`` over the sorted ``all_revisions``.

        Revisions before the first known point stay absent. ``data_dict`` is
        modified in place and also returned.
        """
        prev_data = None
        for revision in all_revisions:
            if revision in data_dict:
                prev_data = data_dict[revision]
            elif prev_data is not None:
                data_dict[revision] = prev_data
        return data_dict

    def __get_corrected_test_data(self, test_results, use_base_revision, value_extractor):
        """Return ``{test: {revision: corrected value}}`` forward-filled on a shared axis.

        The axis is the union of revisions with data for any test and the base
        revisions of stable releases. When ``use_base_revision`` is set and any
        data exists, ``plot_base_revision`` is updated from the earliest
        revision with data.
        """
        all_revisions = set()
        corrected_test_data = {}
        for test in test_results:
            values = self.__get_values_and_resource(test_results[test], value_extractor)
            corrected_test_data[test] = self.__extract_corrected_points(values)
            all_revisions.update(corrected_test_data[test])

        if use_base_revision:
            if all_revisions:
                self.__get_base_revision(min(all_revisions))
            else:
                # min() of an empty set would raise; there is nothing to anchor to.
                logging.warning("No test data, base revision is left as %s", self.plot_base_revision)

        all_revisions.update(self._get_release_revisions())
        all_revisions = sorted(all_revisions)
        for test in test_results:
            corrected_test_data[test] = self.__approximate_corrected_points(corrected_test_data[test], all_revisions)
        return corrected_test_data

    def __get_aggregate_value(self, tests, test_to_data, aggregator):
        """Normalize each test's series and combine them per revision.

        Each test is divided by its value at ``plot_base_revision``. If that
        revision is unset or missing, its earliest revision is used instead.
        A revision is aggregated only when every test has a value there.
        Otherwise a warning is logged and the revision is skipped. A test with
        no data at all therefore makes every revision incomplete.

        Returns ``{revision: aggregator([normalized value per test, in tests order])}``.
        """
        all_test_by_revision = defaultdict(lambda: [None for _ in tests])

        # here order of tests important for unsymmetrical aggregators:
        for test_idx, test in enumerate(tests):
            data = test_to_data[test]
            if not data:
                # min() below would raise on an empty series; treat as missing data.
                logging.warning("No data for test %s", test)
                continue

            base_revision = self.plot_base_revision
            if base_revision is None or base_revision not in data:
                base_revision = min(data)
            # NOTE(review): a zero base value raises ZeroDivisionError here.
            base_value = data[base_revision]

            for revision, value in data.items():
                all_test_by_revision[revision][test_idx] = value / base_value

        aggregated = {}
        for revision, all_revision_data in all_test_by_revision.items():
            if any(value is None for value in all_revision_data):
                logging.warning("At revision %s data for some of %s is missing: %s", revision, tests, all_revision_data)
            else:
                aggregated[revision] = aggregator(all_revision_data)
        return aggregated

    def __get_switch_points(self, test_results, value_extractor):
        """Return the set of ``(revision, new_res)`` pairs where any test switched resource."""
        switch_points = set()
        for test in test_results:
            values_resources = self.__get_values_and_resource(test_results[test], value_extractor)
            switch_points.update(
                (revision, values_resources[revision][0])
                for revision in values_resources
                if values_resources[revision][0] is not None
            )
        return switch_points

    def get_plots_data(self):
        """Return one dict per subplot with keys ``x``, ``y``, ``releases`` and optionally ``switches``.

        ``x`` is the sorted list of aggregated revisions at or after the base
        revision, or all of them when no base revision is set. ``y`` holds the
        matching aggregated values.
        """
        data_list = []

        for subplot_params in self.plot_params:
            test_results = testenv_data.get_test_results(subplot_params.tests, subplot_params.metric)
            corrected_test_data = self.__get_corrected_test_data(
                test_results, subplot_params.use_base_revision, subplot_params.value_extractor
            )
            aggregated = self.__get_aggregate_value(
                subplot_params.tests, corrected_test_data, subplot_params.aggregator
            )

            # NOTE(review): plot_base_revision is instance state and is not reset
            # between subplots, so a base set by an earlier subplot also applies
            # to later ones with use_base_revision=False. Confirm this is intended.
            base_revision = self.plot_base_revision
            if base_revision is None:
                base_revision = 0

            logging.info("base revision %s", base_revision)
            x = sorted(revision for revision in aggregated if revision >= base_revision)
            y = [aggregated[revision] for revision in x]
            subplot_data = {"x": x, "y": y}
            if subplot_params.get_switch_points:
                # Computed only on demand; task contexts are cached, so this
                # mostly reuses data fetched for the corrected series.
                subplot_data["switches"] = self.__get_switch_points(
                    test_results, subplot_params.value_extractor
                )
            subplot_data["releases"] = self._get_release_revisions()
            data_list.append(subplot_data)

        return data_list
