# coding=utf-8
import json
import logging
from collections import defaultdict

from sandbox import common
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types import (
    misc as ctm,
    resource as ctr,
    task as ctt,
)
from sandbox.projects.search_velocity.resources import SearchVelocityAnalyticsLxcImage
from sandbox.sandboxsdk import environments

# Percentile aggregations: result key -> percentile value handed to numpy.percentile.
# Each key is also the name of the boolean task parameter that switches it on.
PERCENTILES = dict(
    p25=25,
    p50=50,
    p75=75,
    p95=95,
    p99=99,
    p99_9=99.9,
)

# Minimal two-sided Mann-Whitney confidence (in percent) for a metric change to be flagged as significant.
SIGNIFICANCE_THRESHOLD = 99.9


class VelocityMetricsContainer(sdk2.parameters.Container):
    """LXC container parameter whose default is a READY analytics image released as stable."""

    resource_type = ("SANDBOX_CI_LXC_IMAGE", SearchVelocityAnalyticsLxcImage)

    @common.utils.classproperty
    def default_value(cls):
        """Return the first READY stable-released analytics image, or None if the lookup fails."""
        try:
            ready_stable_images = sdk2.Resource.find(
                type=SearchVelocityAnalyticsLxcImage,
                state=ctr.State.READY,
                attrs={'released': ctt.ReleaseStatus.STABLE},
            )
            return ready_stable_images.first()
        except LookupError:
            # No usable image: leave the parameter without a default.
            return None


class CalcVelocityMetrics(sdk2.Task):
    """
    Calculate speed-related metrics for two samples and build a diff between them.

    Takes two resources with input data.
    Each resource must be a single JSON file in which the raw values are grouped by metric:

    `{"foo": [2, 3, 4, 5, 6], "bar": [2, 3, 4, 5, 6]}`

    If only one resource is given, aggregations are calculated but no diffs between them.

    Results are written to the `results` output parameter in the following format:
    ```
    {
        'metric_name1': {
            'hits': {…},
            'p25': {
                'base': 25,
                'actual': 26,
                'diff': {
                    'abs': 1,
                    'percent': 4,  # None when the base value is 0
                },
            },
            …
            'p99_9': {…},
            'mean': {…},
            'variance': {…},
            'stddev': {…},
            'mw_test': {
                'less': 0,
                'greater': 100,
                'two_sided': 98,
                'error': False,
            },
            'is_significant': False,
        },
        …
    }
    ```
    """

    class Requirements(sdk2.Requirements):
        cores = 1
        dns = ctm.DnsType.DNS64
        disk_space = 5 * 1024

        # numpy and scipy come from these pip environments, which is why they are
        # imported inside the methods below rather than at module level.
        environments = (
            environments.PipEnvironment('numpy', '1.12.1', use_wheel=True),
            environments.PipEnvironment('scipy', '0.19.0', use_wheel=True),
        )

        # NOTE(review): an empty Caches subclass presumably disables shared caches
        # for this task — confirm against the sdk2 documentation.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        kill_timeout = 60 * 60

        _container = VelocityMetricsContainer('LXC Container', required=True)

        base_metrics_resource = sdk2.parameters.Resource(
            'Resource with base raw metrics log',
            required=True,
        )

        actual_metrics_resource = sdk2.parameters.Resource(
            'Resource with actual raw metrics log',
            required=False,
        )

        # Parameter names here must match the keys of PERCENTILES and the aggregation names
        # used in _calculate_aggregations.
        with sdk2.parameters.Group('Aggregations') as aggregation_params:
            p25 = sdk2.parameters.Bool('25%', default=True)
            p50 = sdk2.parameters.Bool('50%', default=True)
            p75 = sdk2.parameters.Bool('75%', default=True)
            p95 = sdk2.parameters.Bool('95%', default=True)
            p99 = sdk2.parameters.Bool('99%', default=True)
            p99_9 = sdk2.parameters.Bool('99.9%', default=False)

            mean = sdk2.parameters.Bool('Mean', default=True)

            mw_test = sdk2.parameters.Bool('MW test', default=True)

            variance = sdk2.parameters.Bool('Variance', default=False)
            stddev = sdk2.parameters.Bool('STDDEV', default=True)

        with sdk2.parameters.Output():
            results = sdk2.parameters.JSON('Results')

    def on_execute(self):
        """Aggregate the base sample and, if an actual sample is given, diff and MW-test it; publish results."""
        base_resource_data = self._extract_resource_data(self.Parameters.base_metrics_resource)
        base_results = self._calculate_aggregations(base_resource_data)

        actual_results = None
        diff = None

        if self.Parameters.actual_metrics_resource:
            actual_resource_data = self._extract_resource_data(self.Parameters.actual_metrics_resource)
            actual_results = self._calculate_aggregations(actual_resource_data)

            diff = self._calculate_diffs(base_results, actual_results)

            # The MW test is switchable like every other entry of the 'Aggregations' group.
            if self.Parameters.mw_test:
                mw_results = self._calculate_mw_test(base_resource_data, actual_resource_data)
                for metric_name, mw_value in mw_results.iteritems():
                    # `diff` is a plain dict: do not assume it already has an entry for the metric.
                    diff.setdefault(metric_name, {})['mw_test'] = mw_value

        self.Parameters.results = self._group_results(base_results, actual_results, diff)

    def _extract_resource_data(self, metrics_resource):
        """
        Load raw metric values from a resource.

        :param metrics_resource: resource whose data is a single JSON file mapping metric names to value lists
        :return: dict of metric name -> numpy array of that metric's values
        :raises TaskFailure: if the resource data is a directory or the JSON top level is not an object
        """
        import numpy as np

        logging.info('Extracting data for resource %s', metrics_resource)

        data_path = sdk2.ResourceData(metrics_resource).path
        if data_path.is_dir():
            raise TaskFailure('Data resource is a directory but not a file')

        with open(str(data_path)) as fp:
            raw_data = json.load(fp)

        if not isinstance(raw_data, dict):
            raise TaskFailure('Data resource contains not a dict but {}'.format(raw_data.__class__.__name__))

        metrics = {
            metric_name: np.array(metric_values)
            for metric_name, metric_values in raw_data.iteritems()
        }

        logging.info('We got data for metrics: %s', metrics.keys())

        return metrics

    def _calculate_aggregations(self, resource_data):
        """
        Compute the enabled aggregations for every metric.

        :param resource_data: dict of metric name -> numpy array of values
        :return: dict of metric name -> {aggregation name: value}; 'hits' (sample size) is always present
        """
        import numpy as np

        results = {}

        for metric_name, metric_values in resource_data.iteritems():
            aggregations = {'hits': len(metric_values)}
            results[metric_name] = aggregations

            if not len(metric_values):
                # np.percentile fails on an empty array and mean/var/std would be NaN,
                # so an empty sample only reports its (zero) hit count.
                logging.warning('There are no values for %s, only hits are reported', metric_name)
                continue

            for name, val in PERCENTILES.items():
                if getattr(self.Parameters, name, False):
                    aggregations[name] = np.percentile(metric_values, val)

            if self.Parameters.mean:
                aggregations['mean'] = np.mean(metric_values)

            if self.Parameters.variance:
                aggregations['variance'] = np.var(metric_values)

            if self.Parameters.stddev:
                aggregations['stddev'] = np.std(metric_values)

        return results

    @staticmethod
    def _percent_diff(base_val, actual_val):
        """Return the relative change of actual_val over base_val in percent, or None if base_val is zero."""
        if base_val == 0:
            # Avoid ZeroDivisionError (int base) and inf/nan (numpy base), which are not valid JSON.
            return None
        return float(actual_val) * 100.0 / base_val - 100

    def _calculate_diffs(self, base_results, actual_results):
        """
        Compare aggregations of the two samples.

        :return: dict of metric name -> {aggregation name: {'abs': ..., 'percent': ...}}; metrics or
                 aggregations missing from actual_results are skipped with a warning
        """
        diff = {}

        for metric_name, base_aggregations in base_results.iteritems():
            actual_aggregations = actual_results.get(metric_name)
            if not actual_aggregations:
                logging.warning('There is no actual aggregations for %s', metric_name)
                continue

            metric_diff = {}
            for aggregation_name, base_val in base_aggregations.iteritems():
                actual_val = actual_aggregations.get(aggregation_name)
                if actual_val is None:
                    logging.warning('There is no aggregation %s for actual values', aggregation_name)
                    continue

                metric_diff[aggregation_name] = {
                    'abs': actual_val - base_val,
                    'percent': self._percent_diff(base_val, actual_val),
                }

            # Only metrics with at least one comparable aggregation get an entry.
            if metric_diff:
                diff[metric_name] = metric_diff

        return diff

    def _calculate_mw_test(self, base_resource_data, actual_resource_data):
        """
        Run Mann-Whitney U tests between base and actual values of each metric.

        :return: dict of metric name -> {'less', 'greater', 'two_sided', 'error'}; values are in percent.
                 'less' is the one-sided ('less') p-value * 100, 'greater' is 100 minus it,
                 'two_sided' is 100 minus the two-sided p-value * 100. On a ValueError from
                 scipy all three are 0 and 'error' is True.
        """
        from scipy.stats import mannwhitneyu

        results = {}

        for metric_name, base_metric_values in base_resource_data.iteritems():
            actual_metric_values = actual_resource_data.get(metric_name)
            if actual_metric_values is None or not len(actual_metric_values):
                logging.warning('There is no actual values for %s', metric_name)
                continue
            if not len(base_metric_values):
                logging.warning('There is no base values for %s', metric_name)
                continue

            was_error = False
            try:
                _, p_value_less = mannwhitneyu(base_metric_values, actual_metric_values, alternative='less')
                _, p_value_two = mannwhitneyu(base_metric_values, actual_metric_values, alternative='two-sided')

                less_val = p_value_less * 100
                greater_val = 100 - less_val
                two_val = 100 - p_value_two * 100
            except ValueError as e:
                # Some samples (e.g. identical ones) make mannwhitneyu raise ValueError.
                logging.warning('MW test error: %s', e)

                was_error = True
                less_val = 0
                greater_val = 0
                two_val = 0

            results[metric_name] = {
                'less': less_val,
                'greater': greater_val,
                'two_sided': two_val,
                'error': was_error,
            }

        return results

    def _group_results(self, base_results, actual_results, diff):
        """
        Merge base/actual aggregations, diffs and MW results into the output structure.

        :param actual_results: actual aggregations, or None when there is no actual sample
        :param diff: diffs (optionally with 'mw_test' entries), or None when there is no actual sample
        :return: dict of metric name -> {aggregation: {'base', 'actual', 'diff'}, 'mw_test', 'is_significant'}
        """
        if actual_results is None:
            actual_results = {}
        if diff is None:
            diff = {}

        results = {}

        for metric_name, aggregations in base_results.iteritems():
            actual_aggregations = actual_results.get(metric_name, {})
            metric_diff = diff.get(metric_name, {})

            metric_result = {
                aggregation_name: {
                    'base': aggregation_value,
                    'actual': actual_aggregations.get(aggregation_name),
                    'diff': metric_diff.get(aggregation_name),
                }
                for aggregation_name, aggregation_value in aggregations.iteritems()
            }

            mw_test = metric_diff.get('mw_test', {})
            metric_result['mw_test'] = mw_test

            if mw_test:
                # Cast numpy.bool_ to bool: numpy booleans are not JSON-serializable.
                metric_result['is_significant'] = bool(mw_test['two_sided'] >= SIGNIFICANCE_THRESHOLD)
            else:
                # Without an MW result the metric is flagged as significant.
                metric_result['is_significant'] = True

            results[metric_name] = metric_result

        return results
