# -*- coding: utf-8 -*-

import json
import logging
from copy import deepcopy

from sandbox.projects import resource_types
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects.common.app_host.options import AppHostBundle, HttpAdapterBundle, AppHostTestServants, BaseBenchmarkPlan
from sandbox.projects.common.app_host.options import DplannerExecutable, ALL_APPHOST, MINIMAL_APPHOST, HTTP_ADAPTER_BLIST
from sandbox.projects.common.app_host.options import ALL_BENCHMARKS, COMPRESSING_BLIST, SIZE_BLIST, FUZZING_BLIST
from sandbox.projects.common.app_host.options import DELAY_BLIST, BREADTH_BLIST, BREADTH_WITH_DELAY_BLIST, STREAMING_BLIST
from sandbox.projects.common.app_host.options import BALANCING_CONT_BLIST, BALANCING_EFFICIENCY_BLIST, REQUEST_FRACTION_BLIST
from sandbox.projects.common.app_host.options import ALL_BLIST, MIN_BLIST, get_benchmark_keys
from sandbox.projects.common.app_host.options import SUB_FIELDS, SUBHOST_BLIST

from benchmarks import SynthPointGraphBenchmark
from benchmarks import WideGraphBenchmark
from benchmarks import BalancingBenchmark
from benchmarks import RequestFractionBenchmark
from benchmarks import HttpAdapterBasicBenchmark
from benchmarks import StreamingBenchmark
from benchmarks import SubhostBenchmark


class BenchmarkList(parameters.SandboxSelectParameter):
    """Select parameter naming which benchmark set to run.

    Each choice's value and label are both a key taken from ``ALL_BENCHMARKS``.
    """
    name = 'benchmark_list'
    description = 'list of benchmarks'
    required = True
    # A generator expression (unlike a Python 2 list comprehension) does not
    # leak its loop variables into the class namespace as stray attributes.
    choices = list((key, key) for key, _ in ALL_BENCHMARKS)
    default_value = 'ALL_apphost'
    group = "settings"
    sub_fields = SUB_FIELDS


class Concurrency(parameters.SandboxIntegerParameter):
    """Required integer input: number of concurrent threads / sockets (16 by default)."""
    group = "settings"
    name = 'concurrency'
    required = True
    default_value = 16
    description = 'number of concurrent threads / sockets'


class TimeMultiplier(parameters.SandboxFloatParameter):
    """Required float input that scales benchmark request limits (1.0 by default)."""
    group = "settings"
    name = 'time_multiplier'
    required = True
    default_value = 1.0
    description = 'time multiplier (0.5 -- twice faster, 2.0 -- twice slower)'


class BenchmarkAppHostAll(SandboxTask):
    """Parent task that launches a suite of app_host benchmark child tasks.

    On the first run it builds the selected benchmark objects, creates one
    child task per benchmark and waits for them.  On the following run it
    checks that every child ended in the "FINISHED" status and publishes the
    ``{benchmark description: child task id}`` mapping as an
    ``APP_HOST_BENCHMARK_TESTS`` resource.
    """
    type = "BENCHMARK_APP_HOST_ALL"

    execution_space = 4096
    default_cpu_model = None

    input_parameters = [BenchmarkList, Concurrency, TimeMultiplier, AppHostBundle, HttpAdapterBundle,
                        AppHostTestServants, DplannerExecutable, BaseBenchmarkPlan]

    def init_benchmarks(self, benchmark_types):
        """Build benchmark objects for every selected benchmark family.

        :param benchmark_types: collection of benchmark keys; a family is
            included when its ``*_BLIST[0]`` key is present in it.
        :return: list of benchmark objects, grouped by family in a fixed order.
        """
        benchmarks_list = []

        base_benchmark_plan = self.ctx[BaseBenchmarkPlan.name]
        concurrency = int(self.ctx[Concurrency.name])
        # The key may exist with an empty (None) value; treat that as the default too.
        raw_time_multiplier = self.ctx.get(TimeMultiplier.name)
        time_multiplier = 1.0 if raw_time_multiplier is None else float(raw_time_multiplier)
        self.ctx['abs_path'] = self.abs_path()
        common_resources = {
            "app_host": {
                "type": "app_host",
                "binary": {"id": self.ctx[AppHostBundle.name], "type": "file"}
            },
            "perftest_servant": {
                "type": "perftest_servant",
                "binary": {"id": self.ctx[AppHostTestServants.name], "type": "file"}
            }
        }
        logging.debug("common_resources: {}".format(common_resources))

        # Compressing
        if COMPRESSING_BLIST[0] in benchmark_types:
            for codec in ['null', 'zstd_1', 'zstd_5', 'zstd_6', 'lz4', 'brotli-1']:
                for rsize in [8192]:
                    for use_grpc in [True, False]:
                        benchmarks_list.append(
                            SynthPointGraphBenchmark(
                                desc="Compressing.{}.{}KB, use_grpc = {}".format(codec, rsize // 1024, use_grpc),
                                requests_limit=int(100000 * concurrency * time_multiplier),
                                use_grpc=use_grpc,
                                resp_size_distribution='{},0,0'.format(rsize),
                                source_codecs=[codec],
                                plan_id=base_benchmark_plan,
                                concurrency=concurrency,
                                resources=deepcopy(common_resources)
                            )
                        )

        # Streaming
        if STREAMING_BLIST[0] in benchmark_types:
            benchmarks_list.append(
                StreamingBenchmark(
                    desc="Streaming, use_grpc = True",
                    requests_limit=int(100000 * concurrency * time_multiplier),
                    use_grpc=True,
                    plan_id=base_benchmark_plan,
                    resources=deepcopy(common_resources)
                )
            )

        # Subhost
        if SUBHOST_BLIST[0] in benchmark_types:
            # One variable feeds both the description and the kwarg so they cannot diverge.
            use_subhost_streaming = False
            for use_grpc in [True, False]:
                benchmarks_list.append(
                    SubhostBenchmark(
                        desc="Subhost, use_grpc = {}, use_subhost_streaming = {}".format(use_grpc, use_subhost_streaming),
                        requests_limit=int(100000 * concurrency * time_multiplier),
                        use_grpc=use_grpc,
                        use_subhost_streaming=use_subhost_streaming,
                        plan_id=base_benchmark_plan,
                        resources=deepcopy(common_resources)
                    )
                )

        # Fuzzing
        if FUZZING_BLIST[0] in benchmark_types:
            # (name, trash_prob, structured_trash_prob, consistency_check_aware_trash_prob,
            #  break_some_bytes_prob, broken_byte_count)
            for name, tp, stp, ccatp, bsbp, bbc in [
                    ('Trash', 1.0, 0.0, 0.0, 0.0, 0),
                    ('StructuredTrash', 0.0, 1.0, 0.0, 0.0, 0),
                    ('ConsistencyCheckAwareTrash', 0.0, 0.0, 1.0, 0.0, 0),
                    ('BreakSomeBytes.', 0.0, 0.0, 0.0, 1.0, 10)]:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        SynthPointGraphBenchmark(
                            desc="Fuzzing.{}.100%.8KB, use_grpc = {}".format(name, use_grpc),
                            requests_limit=int(100000 * concurrency * time_multiplier),
                            use_grpc=use_grpc,
                            trash_prob=tp,
                            structured_trash_prob=stp,
                            consistency_check_aware_trash_prob=ccatp,
                            break_some_bytes_prob=bsbp,
                            broken_byte_count=bbc,
                            resp_size_distribution='8192,0,0',
                            resp_time_distribution='0,0,0',
                            plan_id=base_benchmark_plan,
                            concurrency=concurrency,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Sizes
        if SIZE_BLIST[0] in benchmark_types:
            # (response size in bytes, base request count per unit of concurrency)
            data = [(1024, 100000), (8192, 50000), (32768, 25000), (262144, 6000)]
            for rsize, reqs in data:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        SynthPointGraphBenchmark(
                            desc="Size.{}KB, use_grpc = {}".format(rsize // 1024, use_grpc),
                            requests_limit=int(reqs * concurrency * time_multiplier),
                            use_grpc=use_grpc,
                            resp_size_distribution='{},0,0'.format(rsize),
                            resp_time_distribution='0,0,0',
                            plan_id=base_benchmark_plan,
                            concurrency=concurrency,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Delays
        if DELAY_BLIST[0] in benchmark_types:
            for delay in [10, 25, 50]:
                # Longer delays get proportionally fewer requests.
                rlim = int(62500 * concurrency * time_multiplier / delay)
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        SynthPointGraphBenchmark(
                            requests_limit=rlim,
                            desc="Delay.{}ms,1KB, use_grpc = {}".format(delay, use_grpc),
                            use_grpc=use_grpc,
                            resp_size_distribution='1024,0,0',
                            resp_time_distribution='{},0,0'.format(delay),
                            plan_id=base_benchmark_plan,
                            concurrency=concurrency,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Breadth
        if BREADTH_BLIST[0] in benchmark_types:
            # Graph widths below `concurrency`, sampled more sparsely as they grow.
            # list(range(...)) keeps the concatenation valid on Python 3 as well.
            if concurrency <= 4:
                wide = list(range(1, concurrency))
            elif concurrency <= 16:
                wide = list(range(1, 4)) + list(range(4, concurrency, 2))
            else:
                wide = list(range(1, 4)) + list(range(4, 16, 2)) + list(range(16, concurrency, 4))
            for breadth in wide:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        WideGraphBenchmark(
                            requests_limit=int(120000 * time_multiplier),
                            desc="Breadth.{}nodes,1KB, use_grpc = {}".format(breadth, use_grpc),
                            use_grpc=use_grpc,
                            resp_size_distribution='1024,0,0',
                            resp_time_distribution='0,0,0',
                            plan_id=base_benchmark_plan,
                            concurrency=1,
                            breadth=breadth,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Breadth with delay
        if BREADTH_WITH_DELAY_BLIST[0] in benchmark_types:
            wide = [25, 50, 75, 100, 125, 150]
            for breadth in wide:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        WideGraphBenchmark(
                            requests_limit=int(4000 * time_multiplier),
                            desc="BreadthWithDelay.{}nodes,1KB,Delay10.10.0, use_grpc = {}".format(
                                breadth, use_grpc),
                            use_grpc=use_grpc,
                            resp_size_distribution='1024,0,0',
                            resp_time_distribution='10,10,0',
                            plan_id=base_benchmark_plan,
                            concurrency=1,
                            breadth=breadth,
                            store_unistat=False,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Balancing contention
        if BALANCING_CONT_BLIST[0] in benchmark_types:
            schemes = ['urandom', 'rrobin', 'weighted', 'weighted-fd']
            # `concurrency` identical servant specs; the tuples are immutable,
            # so sharing them through list multiplication is safe.
            perftest_servants = [(True, 1, '1024,0,0', '0,0,0', 0)] * concurrency
            for scheme in schemes:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        BalancingBenchmark(
                            requests_limit=int(200000 * concurrency * time_multiplier),
                            desc="BalancingContention.Scheme:{}, use_grpc = {}".format(scheme, use_grpc),
                            perftest_servants=perftest_servants,
                            use_grpc=use_grpc,
                            plan_id=base_benchmark_plan,
                            concurrency=concurrency,
                            source_codecs=['null'],
                            balancing_scheme=scheme,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Balancing Efficiency
        if BALANCING_EFFICIENCY_BLIST[0] in benchmark_types:
            schemes = ['urandom', 'rrobin', 'weighted', 'weighted-fd']
            # Servants with 0/5/50 ms response times plus one inactive servant.
            perftest_servants = [
                (True, concurrency, '1024,0,0', '0,0,0', 0),
                (True, concurrency, '1024,0,0', '5,0,0', 0),
                (True, concurrency, '1024,0,0', '50,0,0', 0),
                (False, concurrency, '1024,0,0', '0,0,0', 0)]
            for scheme in schemes:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        BalancingBenchmark(
                            requests_limit=int(25000 * concurrency * time_multiplier),
                            desc="BalancingEfficiency.Scheme:{}, use_grpc = {}".format(scheme, use_grpc),
                            perftest_servants=perftest_servants,
                            use_grpc=use_grpc,
                            plan_id=base_benchmark_plan,
                            concurrency=concurrency,
                            source_codecs=['null'],
                            timeout=10,
                            balancing_scheme=scheme,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Request Fraction
        if REQUEST_FRACTION_BLIST[0] in benchmark_types:
            # Half of `concurrency`, rounded up; `//` keeps this an int on Python 3.
            threads = (concurrency + 1) // 2
            tests = [
                # (is_active, concurrency, resp_size_distr, resp_time_distr, trash)
                ("w/o test source", None),
                ("fast success", (True, threads, "1024,0,0", "0,0,0", 0)),
                ("fast refused", (False, threads, "1024,0,0", "0,0,0", 0)),
                ("slow reply", (True, threads, "1024,0,0", "10,0,0", 0))
            ]
            for test_name, test in tests:
                for use_grpc in [True, False]:
                    benchmarks_list.append(
                        RequestFractionBenchmark(
                            requests_limit=int(100000 * concurrency * time_multiplier),
                            desc="RequestFraction.Scheme:{}, use_grpc = {}".format(test_name, use_grpc),
                            use_grpc=use_grpc,
                            main_servant=(True, threads, "1024,0,0", "0,0,0", 0),
                            test_servant=test,
                            plan_id=base_benchmark_plan,
                            concurrency=threads,
                            timeout=100,
                            resources=deepcopy(common_resources)
                        )
                    )

        # Http adapter
        if HTTP_ADAPTER_BLIST[0] in benchmark_types:
            threads = (concurrency + 1) // 2

            http_adapter_resources = deepcopy(common_resources)
            http_adapter_resources["http_adapter"] = {
                "type": "http_adapter",
                "binary": {"id": self.ctx[HttpAdapterBundle.name], "type": "file"}
            }
            logging.debug("http_adapter_resources: {}".format(http_adapter_resources))

            benchmarks_list.append(
                HttpAdapterBasicBenchmark(
                    requests_limit=int(1000 * concurrency * 5 * time_multiplier),
                    desc="Http Adapter basic benchmark",
                    dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
                    concurrency=threads,
                    resources=deepcopy(http_adapter_resources)
                )
            )

            benchmarks_list.append(
                HttpAdapterBasicBenchmark(
                    requests_limit=int(1000 * concurrency * 5 * time_multiplier),
                    desc="Http Adapter basic benchmark (no compression)",
                    dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
                    concurrency=threads,
                    compression=False,
                    resources=deepcopy(http_adapter_resources)
                )
            )
            # Disabled HTTP-adapter benchmark variants, kept for reference:
            #
            # benchmarks_list.append(
            #     HttpAdapterBasicBenchmark(
            #         desc="Http Adapter benchmark (shoot to perftest servant)",
            #         requests_limit=int(1000 * concurrency * 5 * time_multiplier),
            #         dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
            #         concurrency=threads,
            #         shoot_to_perf_servants=True,
            #         resources=deepcopy(http_adapter_resources)
            #     )
            # )
            #
            # for name, tp, stp, ccatp, bsbp, bbc in [
            #         ('Trash', 1.0, 0.0, 0.0, 0.0, 0),
            #         ('StructuredTrash', 0.0, 1.0, 0.0, 0.0, 0),
            #         ('ConsistencyCheckAwareTrash', 0.0, 0.0, 1.0, 0.0, 0),
            #         ('BreakSomeBytes.', 0.0, 0.0, 0.0, 1.0, 10)]:
            #     benchmarks_list.append(HttpAdapterBasicBenchmark(
            #         desc="HttpAdapter Fuzzing.{}.100%.8KB".format(name),
            #         trash_prob=tp,
            #         structured_trash_prob=stp,
            #         consistency_check_aware_trash_prob=ccatp,
            #         break_some_bytes_prob=bsbp,
            #         broken_byte_count=bbc,
            #         resp_size_distribution='8192,0,0',
            #         resp_time_distribution='0,0,0',
            #         requests_limit=int(1000 * concurrency * 5 * time_multiplier),
            #         dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
            #         concurrency=threads,
            #         shoot_to_perf_servants=True,
            #         resources=deepcopy(http_adapter_resources)
            #     ))
            #
            # # Sizes
            # data = [(1024, 100000), (8192, 50000), (32768, 25000), (262144, 6000)]
            # for rsize, reqs in data:
            #     benchmarks_list.append(
            #         HttpAdapterBasicBenchmark(
            #             desc="HttpAdapter Size.{}KB".format(rsize / 1024),
            #             resp_size_distribution='{},0,0'.format(rsize),
            #             resp_time_distribution='0,0,0',
            #             requests_limit=int(500 * concurrency * time_multiplier),
            #             dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
            #             concurrency=threads,
            #             shoot_to_perf_servants=True,
            #             resources=deepcopy(http_adapter_resources)
            #         )
            #     )
            #
            # # Delays
            # for delay in [10, 25, 50]:
            #     rlim = int(30000 * concurrency * time_multiplier / delay)
            #     benchmarks_list.append(
            #         HttpAdapterBasicBenchmark(
            #             desc="HttpAdapter Delay.{}ms,1KB".format(delay),
            #             resp_size_distribution='1024,0,0',
            #             resp_time_distribution='{},0,0'.format(delay),
            #             requests_limit=rlim,
            #             dplanner_executable_id=self.ctx.get(DplannerExecutable.name),
            #             concurrency=threads,
            #             shoot_to_perf_servants=True,
            #             resources=deepcopy(http_adapter_resources)
            #         )
            #     )

        logging.debug("Benchmarks: {}".format(benchmarks_list))
        return benchmarks_list

    def start_benchmarks(self, benchmarks_list):
        """Create one child task per benchmark.

        :param benchmarks_list: benchmark objects produced by ``init_benchmarks``.
        :return: dict mapping each benchmark's ``desc`` to its child task id.
            Benchmarks sharing a ``desc`` overwrite each other's entry.
        """
        tasks = {}

        for bench in benchmarks_list:
            bench.app_host_bundle_id = self.ctx[AppHostBundle.name]
            bench.app_host_test_servants_id = self.ctx[AppHostTestServants.name]
            bench.prepare_plan(self)
            task = bench.create_task(self)
            tasks[bench.desc] = task.id
            logging.info('starting benchmark task {}'.format(task))
        return tasks

    def on_execute(self):
        """Launch the benchmark children, or verify them and publish results.

        Whether ``ctx['children']`` is already set decides which phase runs.

        :raises SandboxTaskFailureError: if any child task is not "FINISHED".
        """
        children = self.ctx.get('children')
        blist = self.ctx.get(BenchmarkList.name, ALL_BLIST[0])
        if blist == ALL_BLIST[0]:
            blist = get_benchmark_keys(ALL_APPHOST)
        elif blist == MIN_BLIST[0]:
            blist = get_benchmark_keys(MINIMAL_APPHOST)
        else:
            blist = [blist]

        logging.debug("Benchmarks to run: {}".format(blist))

        # ctx is accessed as a dict, so test its key/value; hasattr() would
        # always be False here and force re-resolution every time.
        if not self.ctx.get(AppHostBundle.name):
            self.ctx[AppHostBundle.name] = AppHostBundle.get_resource_from_ctx(self.ctx)

        if not self.ctx.get(HttpAdapterBundle.name):
            self.ctx[HttpAdapterBundle.name] = HttpAdapterBundle.get_resource_from_ctx(self.ctx)

        if children is None:
            with self.current_action('starting benchmarks'):
                benchmarks_list = self.init_benchmarks(blist)
                self.ctx['children'] = self.start_benchmarks(benchmarks_list)
            with self.current_action('waiting benchmark results'):
                self.wait_tasks(
                    self.ctx['children'].values(),
                    list(self.Status.Group.FINISH + self.Status.Group.BREAK),
                    True)
        else:
            for name, task_id in children.items():
                task = channel.sandbox.get_task(task_id)
                logging.debug("{} {} {}".format(name, task_id, task.status))
                if task.status != "FINISHED":
                    raise SandboxTaskFailureError("Failed to launch test {} {}".format(name, task_id))

            with self.current_action('saving tests resource'):
                tests_path = 'tests.json'
                with open(tests_path, 'w') as fh:
                    json.dump(self.ctx['children'], fh)
                resource = self.create_resource(
                    self.descr,
                    tests_path,
                    resource_types.APP_HOST_BENCHMARK_TESTS)
                self.mark_resource_ready(resource)
