# -*- coding: utf-8 -*-

import logging

from sandbox.common import rest

from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects import resource_types
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import utils

import stats


class MaxReqsDiffPercent(parameters.SandboxFloatParameter):
    """Threshold, in percent, for the relative change of max requests/sec.

    A change below minus this value counts as a degradation, above it as an improvement.
    """
    group = 'Checking parameters'
    description = 'Max requests/sec diff (%)'
    name = 'max_diff_percent'
    default_value = 2.0


class ImprovementIsDiff(parameters.SandboxBoolParameter):
    """When enabled, an improvement beyond the thresholds is also flagged as a diff."""
    group = 'Checking parameters'
    description = 'Treat improvements as diffs'
    name = 'increase_is_diff'
    default_value = True


class MaxMemDiffPercent(parameters.SandboxFloatParameter):
    """Threshold, in percent, for the relative change of the RSS increase.

    Only a positive value enables the memory check; the default (0) disables it.
    """
    # The class is called "Mem" although the parameter key says "rss" on purpose:
    # a class named "MaxRssDiffPercent" is too easy to misread as MaxReqsDiffPercent.
    name = 'max_rss_diff_percent'
    description = 'Max RSS diff (%)'
    # Float literal to match the parameter type (and MaxReqsDiffPercent's default).
    default_value = 0.0
    group = 'Checking parameters'


class YabsPerformanceBestCmp(SandboxTask):
    """Compare a baseline ("pre") and a test Sandbox task by requests/sec and RSS stats.

    Provides helpers to fetch both tasks' contexts and to build a text report,
    storing the verdict in ``self.ctx['has_diff']``.
    """
    cores = 1
    execution_space = 1024    # We work with task contexts and perf, no heavy resources

    def _get_ctxes(self, pre_result, test_result):
        """Read the contexts of the baseline and test tasks through the REST API.

        Args:
            pre_result: parameter class whose ``name`` keys the baseline task id in ``self.ctx``.
            test_result: parameter class whose ``name`` keys the test task id in ``self.ctx``.

        Returns:
            ``(pre_ctx, test_ctx)`` as returned by ``cl.task[id].context.read()``.

        Raises:
            SandboxTaskFailureError: if either task is not in SUCCESS status.
        """
        cl = rest.Client()

        pre_id = self.ctx[pre_result.name]
        test_id = self.ctx[test_result.name]

        pre_task = cl.task[pre_id].read()
        test_task = cl.task[test_id].read()
        if any(task['status'] != 'SUCCESS' for task in (pre_task, test_task)):
            raise SandboxTaskFailureError('Both tasks should be in SUCCESS status')

        pre_ctx = cl.task[pre_id].context.read()
        test_ctx = cl.task[test_id].context.read()

        return pre_ctx, test_ctx

    def _make_report(self, pre_ctx, test_ctx, description, report_filename='report.txt', rps_ctx_key='requests_per_sec'):
        """Compare baseline and test contexts, record the verdict and publish a report.

        Args:
            pre_ctx: context of the baseline task.
            test_ctx: context of the test task.
            description: description of the created PLAIN_TEXT resource.
            report_filename: path the report text is written to.
            rps_ctx_key: context key holding the requests/sec values.

        Side effects:
            Sets ``self.ctx['has_diff']``, shows the report via ``set_info``,
            writes ``report_filename`` and registers it as a PLAIN_TEXT resource.

        Raises:
            SandboxTaskFailureError: if the baseline requests/sec aggregate is zero,
                since no relative diff can be computed against it.
        """
        pre_reqs = _aggregate_reqs(pre_ctx, rps_ctx_key)
        test_reqs = _aggregate_reqs(test_ctx, rps_ctx_key)
        if not pre_reqs:
            raise SandboxTaskFailureError(
                'Baseline requests/sec aggregate is zero, cannot compute relative diff'
            )
        reqs_diff = 100.0 * _diff_reqs(pre_ctx, test_ctx, rps_ctx_key) / pre_reqs

        # Read all thresholds the same way so missing context keys fall back to defaults.
        reqs_thr = utils.get_or_default(self.ctx, MaxReqsDiffPercent)
        mem_thr = utils.get_or_default(self.ctx, MaxMemDiffPercent)

        try:
            pre_mem = _aggregate_mem(pre_ctx)
            test_mem = _aggregate_mem(test_ctx)
            mem_diff = 100.0 * _diff_mem(pre_ctx, test_ctx) / pre_mem
        except Exception:
            # Memory statistics are best-effort: report zeros instead of failing,
            # but keep the reason visible in the task log.
            logging.exception('Failed to compute RSS statistics, reporting zeros')
            pre_mem = 0.0
            test_mem = 0.0
            mem_diff = 0.0

        degraded = reqs_diff < -1.0 * reqs_thr
        improved = reqs_diff > reqs_thr

        # Memory participates in the verdict only when its threshold is positive;
        # any degradation then cancels a reported improvement.
        if mem_thr > 0:
            degraded = degraded or mem_diff > mem_thr
            improved = not degraded and (improved or mem_diff < -1.0 * mem_thr)

        has_diff = degraded

        report_tail = ''
        if utils.get_or_default(self.ctx, ImprovementIsDiff) and improved:
            has_diff = True
            report_tail = (
                "This is a performance/memory usage IMPROVEMENT.\n"
                "If your commit can indeed be expected to increase performance or decrease memory usage of yabs/server,\n"
                "please mark the problem as RESOLVED and do not complain."
            )

        report = (
            "Max requests/sec changed by {:+.1f}% ({} -> {}).\n"
            "RSS increase changed by {:+.2f}% ({:.3f} GiB -> {:.3f} GiB).\n"
            "{}").format(
                reqs_diff, pre_reqs, test_reqs,
                mem_diff, pre_mem, test_mem,
                report_tail
            )

        self.ctx['has_diff'] = has_diff
        self.set_info(report)
        fu.write_file(report_filename, report)

        self.create_resource(
            description=description,
            resource_path=report_filename,
            resource_type=resource_types.PLAIN_TEXT,
        )


def _aggregate_reqs(ctx, rps_ctx_key):
    """Return the Hodges-Lehmann median of the values stored under ``ctx[rps_ctx_key]``."""
    return stats.hodges_lehmann_median(ctx[rps_ctx_key])


def _diff_reqs(ctx_baseline, ctx_test, rps_ctx_key):
    """Return ``stats.hodges_lehmann_diff`` of baseline vs. test requests/sec values.

    Both contexts are read under the same key ``rps_ctx_key``, baseline first.
    """
    baseline_and_test = [ctx[rps_ctx_key] for ctx in (ctx_baseline, ctx_test)]
    return stats.hodges_lehmann_diff(*baseline_and_test)


def _aggregate_mem(ctx):
    """Return ``ctx['rss_stat_increase_hl']`` as a float divided by 2**20.

    NOTE(review): the report prints this as GiB, which implies the stored stat
    is in KiB — confirm against the producer of the context.
    """
    hl_increase = float(ctx['rss_stat_increase_hl'])
    return hl_increase / 2 ** 20


def _diff_mem(ctx_baseline, ctx_test):
    """Return the Hodges-Lehmann diff of baseline vs. test ``rss_stat_increase``, divided by 2**20."""
    key = 'rss_stat_increase'
    shift = stats.hodges_lehmann_diff(ctx_baseline[key], ctx_test[key])
    return float(shift) / 2 ** 20
