import json
import os
import sandbox.projects.common.fasthttp.queriesloader as queriesloader
# import sandbox.projects.common.wizard.parameters as w_params
# from sandbox.projects import resource_types
from sandbox.projects.common.wizard import wizard_runner
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox import sdk2
import logging
import time
import duration_meta


class RulesDurationResource(sdk2.Resource):
    """ Rules duration resource: JSON file (result.json) with raw per-rule wizard durations dumped by WizardRulesDurationPerf """


def rules_selector(data):
    """
    Extract the "rules" object from a raw wizard JSON answer.

    :param data: raw response body; may be None (e.g. an entry the downloader never filled)
    :return: (rules, "") on success, otherwise (None, error_message):
        "Invalid json" if `data` cannot be parsed as JSON,
        "No rules field" if the parsed value is not an object or has no non-null "rules" field
    """
    try:
        wizard_output = json.loads(data)
    except (TypeError, ValueError):
        # ValueError covers malformed JSON (JSONDecodeError subclasses it);
        # TypeError covers a None / non-string body. Unlike a bare except,
        # this does not swallow KeyboardInterrupt / SystemExit.
        return None, "Invalid json"
    # Valid JSON that is not an object (list, number, ...) cannot carry rules;
    # previously this crashed with AttributeError on .get().
    if not isinstance(wizard_output, dict):
        return None, "No rules field"
    rules = wizard_output.get("rules")
    if rules is None:
        return None, "No rules field"
    return rules, ""


def get_rules_durations(rules):
    """
    Lazily yield (rule_name, duration) for every truthy entry of `rules`.

    The duration is int(entry["RuleDuration"]) / 1000; an entry without
    "RuleDuration" is treated as -1000. Falsy entries are skipped.
    NOTE(review): `/` floors for ints under Python 2 but yields a float
    under Python 3 — confirm which result is intended.
    """
    for rule_name in rules:
        rule_info = rules[rule_name]
        if rule_info:
            yield rule_name, int(rule_info.get("RuleDuration", -1000)) / 1000


def make_request(host, port, req):
    """
    Build a GET request to the wizard listening on host:port.

    `req` (whitespace-stripped) is appended right after the port and is
    immediately followed by "&format=json...", so it presumably is a path
    that already carries a query string. The appended parameters request
    JSON output with per-rule durations enabled.
    """
    url_template = "http://{}:{}{}&format=json&wizextra=enable_rule_duration&nocache=da"
    url = url_template.format(host, port, req.strip())
    return queriesloader.GetRequest(url)


class WizardRulesDurationPerf(sdk2.Task):
    """ Replay queries against a locally started wizard and collect per-rule durations """

    class Context(sdk2.Task.Context):
        # Table shown in the task footer; filled at the end of on_execute
        rules_stat = []

    class Requirements(sdk2.Task.Requirements):
        disk_space = 40 * 1024
        client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 6 * 60 * 60
        # custom parameters
        wizard_binary = sdk2.parameters.Resource(
            "Wizard executable",
            # type=resource_types.REMOTE_WIZARD,  # FIXME: invalid argument (SANDBOX-6404)
            required=True
        )
        wizard_config = sdk2.parameters.Resource(
            "Wizard conf",
            # type=w_params.TARGET_CONFIG_TYPES,  # FIXME: invalid argument (SANDBOX-6404)
            required=True
        )
        wizard_shard = sdk2.parameters.Resource(
            "Wizard data",
            # type=resource_types.WIZARD_SHARD,  # FIXME: invalid argument (SANDBOX-6404)
            required=True
        )
        runtime_data = sdk2.parameters.Resource(
            "Wizard runtime-data",
            # type=resource_types.WIZARD_RUNTIME_PACKAGE_UNPACKED  # FIXME: invalid argument (SANDBOX-6404)
        )

        request_repeat = sdk2.parameters.Integer("Repeat per request", default=16, required=True)
        session_repeat = sdk2.parameters.Integer("Repeat per session", default=4, required=True)
        async_size = sdk2.parameters.Integer("Async size", default=32, required=True)
        fail_retries = sdk2.parameters.Integer("Retries on fail", default=32, required=True)

        wizard_queries = sdk2.parameters.Resource("Queries")

    @property
    def footer(self):
        return [{
            "content": {
                "<h3>Rules duration</h3>": self.Context.rules_stat
            }
        }]

    def on_save(self):
        wizard_utils.setup_hosts(self)

    def _prepare_working_dir(self):
        """Create ./tmp/wizard with the shard (and optional runtime data) symlinked in; return the absolute path of ./tmp."""
        working_dir = os.path.abspath("tmp")
        # makedirs also creates "tmp" itself when it is missing (os.mkdir would fail)
        os.makedirs(os.path.join(working_dir, 'wizard'))
        os.symlink(str(sdk2.ResourceData(self.Parameters.wizard_shard).path), os.path.join(working_dir, 'wizard', 'WIZARD_SHARD'))
        if self.Parameters.runtime_data:
            os.symlink(str(sdk2.ResourceData(self.Parameters.runtime_data).path), os.path.join(working_dir, 'wizard.runtime'))
        return working_dir

    def _load_requests(self):
        """Return the unique, non-empty, stripped lines of the queries resource in sorted (deterministic) order."""
        # NOTE(review): wizard_queries is not marked required; a missing resource fails here
        with open(str(sdk2.ResourceData(self.Parameters.wizard_queries).path)) as requests_file:
            # Blank lines would produce a malformed URL ("http://host:port&format=..."), so drop them
            return sorted(set(line.strip() for line in requests_file if line.strip()))

    def on_execute(self):
        request_repeat = int(self.Parameters.request_repeat)
        session_repeat = int(self.Parameters.session_repeat)

        result_resource = sdk2.ResourceData(RulesDurationResource(
            self, "Rules duration file", "result.json"
        ))

        working_dir = self._prepare_working_dir()
        wizard = wizard_runner.WizardRunner(
            str(sdk2.ResourceData(self.Parameters.wizard_binary).path),
            working_dir,
            str(sdk2.ResourceData(self.Parameters.wizard_config).path),
            working_dir if self.Parameters.runtime_data else None
        )

        # start(wait=False) + wait() below: presumably lets the wizard boot
        # while the request list is being prepared
        wizard.start(wait=False)

        requests = self._load_requests()

        # log index: reqId-session-iteration: ruleId: duration
        result = {
            "rules_list": [],
            # (sic) key name is part of the result.json format; kept as-is
            "quries_list": requests,
            "log": [
                [[None] * request_repeat for _ in range(session_repeat)]
                for _ in requests
            ],
        }
        # Each request is repeated request_repeat times back to back, so a flat
        # query_id maps back to (request_id, iteration_id) via divmod.
        to_download = [
            make_request("localhost", wizard.port, request)
            for request in requests
            for _ in range(request_repeat)
        ]

        wizard.wait()
        rule_to_index = {}

        for session_id in range(session_repeat):
            # Fresh buffer per session: otherwise a query that fails in this
            # session would silently reuse the answer from a previous session.
            raw_log = [None] * len(to_download)
            for query_id, query_result in queriesloader.download_queries(
                    to_download,
                    async_size=int(self.Parameters.async_size),
                    retries=int(self.Parameters.fail_retries),
                    hash_validation_retries=0):
                raw_log[query_id] = query_result

            for query_id, raw_answer in enumerate(raw_log):
                # Integer division (also correct under Python 3, unlike `/`)
                request_id, iteration_id = divmod(query_id, request_repeat)

                rules, error = rules_selector(raw_answer)
                if error or not rules:
                    logging.error("sess_id={} iter_id={} req_id={}: failed ({})".format(
                        session_id, iteration_id, request_id, error or "empty rules"))
                    continue

                # duration_info[rule_id] = duration, padded with -1 for rule ids
                # that did not appear in this answer
                duration_info = []
                for rule, duration in get_rules_durations(rules):
                    rule_id = rule_to_index.get(rule)
                    if rule_id is None:
                        rule_id = len(rule_to_index)
                        rule_to_index[rule] = rule_id
                        result["rules_list"].append(rule)
                    if rule_id >= len(duration_info):
                        duration_info.extend([-1] * (rule_id + 1 - len(duration_info)))
                    duration_info[rule_id] = duration

                result["log"][request_id][session_id][iteration_id] = duration_info

            # Pause between sessions only; sleeping after the last one is wasted time
            if session_id + 1 < session_repeat:
                time.sleep(10)
        logging.info("OK, queries downloaded")

        with open(str(result_resource.path), "w") as fd:
            json.dump(result, fd)
        logging.info("Raw data saved")

        similar_log = duration_meta.select_similar_session(result["log"], session_repeat // 2)
        logging.info("Processed data saved")

        self.Context.rules_stat = duration_meta.get_rules_stat_table(result["rules_list"], similar_log)
        logging.info("Bye!")
