import os
import json
import logging
import tarfile
import subprocess
from sandbox import sdk2
from sandbox.common.types import task as ctt
import sandbox.common.types.client as ctc
from sandbox.sandboxsdk import svn
from sandbox.sandboxsdk import environments
from sandbox.sdk2 import yav

from sandbox.projects.vqe.measure_performance import utils


# Yav secret id used as the default value of the ``pulsar_token`` task parameter.
PULSAR_TOKEN_ID = "sec-01en36ndssfcdm692y11vqmypr"


class PerfStatParserRule:
    """Named rule that recognises one kind of perf-stat line and parses it.

    ``predicate`` is called with a single output line and tells whether the
    rule applies to it; ``parser`` is called with the same line and produces
    the value that gets stored under ``name``.
    """

    def __init__(self, name, predicate, parser):
        self.name, self.predicate, self.parser = name, predicate, parser


def _contains(marker):
    """Build a predicate that accepts lines containing *marker*."""
    return lambda line: marker in line


def _leading_float(line):
    """Parse the first whitespace-separated token of *line* as a float."""
    tokens = line.split()
    return float(tokens[0])


def _float_after_hash(line):
    """Parse the first token that follows the first ``#`` in *line* as a float."""
    after_hash = line.split("#")[1]
    return float(after_hash.strip().split()[0])


# Rules are applied in this order to every perf output line.
STAT_PARSER_RULES = [
    PerfStatParserRule('task_clock', _contains('task-clock'), _leading_float),
    PerfStatParserRule('time_user', _contains('seconds user'), _leading_float),
    PerfStatParserRule('time_sys', _contains('seconds sys'), _leading_float),
    # sample: " 95.32 msec task-clock # 0.982 CPUs utilized ( +-  1.16% ) "
    # The same line also matches the 'task_clock' rule; here the value
    # after '#' is taken instead of the leading one.
    PerfStatParserRule('cpu_utilized', _contains('CPUs utilized'), _float_after_hash),
]


def parse_perf_stats(lines):
    """Extract the known metrics from perf-stat output lines.

    Every line is checked against each rule of ``STAT_PARSER_RULES`` in order;
    when a rule's predicate accepts the line, the rule's parser result is
    stored under the rule's name.  A later matching line overwrites the value
    stored earlier for the same name.

    Args:
        lines: iterable of output lines.

    Returns:
        dict mapping rule name to the parsed value; rules that never matched
        are absent from the dict.
    """
    stats = {}
    for text in lines:
        # Lazy generator: each predicate is evaluated right before its parser,
        # exactly as in a plain nested loop.
        applicable = (rule for rule in STAT_PARSER_RULES if rule.predicate(text))
        for rule in applicable:
            stats[rule.name] = rule.parser(text)
    return stats


class SampleInfo:
    """Pairs a record id with the folder its archive was extracted into."""

    def __init__(self, id_, folder):
        self.folder = folder
        self.id = id_


class VqeInputRecordResource(sdk2.Resource):
    """ Preuploaded archive with single record for VQE"""
    # Attribute the task searches by when loading samples: one resource is
    # looked up per requested record id and its file is unpacked as a tar archive.
    record_id = sdk2.parameters.String


class VqeRawPerfResultResource(sdk2.Resource):
    """ Result of raw perf execution on VQE binary"""
    # Created once per sample by the task; holds the stderr captured from the
    # perf run for that sample (file name "perf_output_<sample id>.txt").


class VqePerfResultResource(sdk2.Resource):
    """ Result of parsed perf execution on VQE binary"""
    # Created once per task run; holds "perf_output.json" with the parsed
    # per-sample metrics and their averages.


class VqeBinary(sdk2.Resource):
    """vqe binary
    """
    # Prebuilt binary that can be passed to the task by resource id instead of
    # building from Arcadia.
    # NOTE(review): presumably makes the resource file executable after
    # download -- confirm against the sdk2.Resource documentation.
    executable = True


class VqePerformanceOnSingleRecord(sdk2.Task):
    """Task to measure VQE performance"""
    # enable yav support

    class Requirements(sdk2.Requirements):
        # Host resources requested for the task.
        # NOTE(review): units of `ram` / `disk_space` are whatever
        # sdk2.Requirements defines (presumably MiB) -- confirm.
        cores = 1
        ram = 1024
        disk_space = 1000
        # NOTE(review): this line used to carry the remark "uncomment before
        # push to production sandbox", but the restriction is already active.
        # It limits execution to hosts tagged LINUX_TRUSTY and INTEL_E5_2650,
        # presumably to keep measurements comparable between runs -- confirm
        # that this is the intended production setting.
        client_tags = ctc.Tag.LINUX_TRUSTY & ctc.Tag.INTEL_E5_2650

        # Pinned pip package installed into the task environment.
        # NOTE(review): presumably required by utils.send_to_pulsar -- confirm.
        environments = [
            environments.PipEnvironment("yandex-pulsar", version="0.2.2.post7457710"),
        ]

    class Parameters(sdk2.Task.Parameters):
        # --- What to measure -------------------------------------------------
        vqe_arcadia_folder = sdk2.parameters.String(
            "path to required vqe folder",
            default="voicetech/vqe/yandex_vqe",
            required=True,
        )
        binary_name = sdk2.parameters.String(
            "Binary name(if not equal to folder name",
            default="",
            required=False,
        )

        vqe_binary_resource_id = sdk2.parameters.Integer(
            "VQE_BINARY resource id. If id is provided this resource will be used instead of building executable from Arcadia",
            default=0,
            required=False,
        )

        # The attribute name is misspelled ("reviosion"); it is kept as is
        # because it is the externally visible parameter name.
        svn_reviosion = sdk2.parameters.String(
            "revision",
            default="",
            required=False,
        )

        # --- How to run it ---------------------------------------------------
        preset = sdk2.parameters.String(
            "preset",
            default="yandexstation",
            required=True,
        )
        n_runs = sdk2.parameters.Integer(
            "Run test N times",
            default=10,
            required=True,
        )
        flags = sdk2.parameters.String(
            "VQE flags",
            default="",
            required=False,
        )
        record_ids = sdk2.parameters.String(
            "Record Id, can be comma (`,`) seprated",
            default="",
            required=True,
        )

        # --- Reporting -------------------------------------------------------
        send_result_to_pulsar = sdk2.parameters.Bool(
            "Send result to pulsar",
        )
        pulsar_tags = sdk2.parameters.String(
            "Pulsar tags, comma separated",
            default="",
            required=False,
        )
        pulsar_token = sdk2.parameters.YavSecret(
            "Pulsar secret id",
            default=PULSAR_TOKEN_ID,
        )

    def _extract_tar(self, src, dst=None):
        logging.info('Extracting tar files from "{src}" to "{dst}"'.format(src=src, dst=dst))
        if not dst:
            idx = src.find('.tar')
            dst = src[:idx]
        tar = tarfile.open(src)
        tar.extractall(path=dst)
        tar.close()
        return dst

    def _run_build_task(self):
        """Create and enqueue a ``YA_MAKE_2`` child task that builds the VQE binary.

        The build target is the ``vqe_arcadia_folder`` parameter.  The artifact
        is ``<folder>/<binary name>``, where the binary name is the
        ``binary_name`` parameter or, when that is empty, the last path
        component of the folder.  A non-empty ``svn_reviosion`` other than
        ``HEAD`` is appended to the checkout URL as ``@<revision>``.

        Returns:
            Id of the enqueued child task.
        """
        # Drop trailing slashes: otherwise the last path component, which is
        # the default binary name, would be an empty string.
        target = self.Parameters.vqe_arcadia_folder.rstrip("/")
        binary_name = self.Parameters.binary_name or target.split("/")[-1]
        arts = target + "/" + binary_name

        checkout_arcadia_from_url = "arcadia:/arc/trunk/arcadia"
        revision = self.Parameters.svn_reviosion
        if revision and revision != "HEAD":
            checkout_arcadia_from_url += "@" + revision

        task = sdk2.Task["YA_MAKE_2"](
            self,
            checkout_arcadia_from_url=checkout_arcadia_from_url,
            description="build for VQE performance test",
            result_single_file=True,
            targets=target,
            arts=arts,
            owner=self.owner,
            build_type="release",
            build_arch="linux",
            use_aapi_fuse=True,
            use_arc_instead_of_aapi=True,
            aapi_fallback=True,
        ).save().enqueue()
        return task.id

    def build_vqe(self):
        """Build the VQE binary in a child task and return the path to it.

        A build task is started via ``_run_build_task`` inside the
        ``create_children_vqe`` memoize stage, its id is saved in the task
        context, and ``sdk2.WaitTask`` is raised to wait for the child to reach
        a FINISH or BREAK status.
        NOTE(review): presumably the memoize stage body runs only once, so
        after the wait execution continues with the resource lookup below --
        confirm against the sdk2 documentation.

        Returns:
            Local path (str) of the ``ARCADIA_PROJECT`` resource produced by
            the build task.

        Raises:
            RuntimeError: the build task has no ``ARCADIA_PROJECT`` resource.
        """
        logging.info("Start build binary")

        with self.memoize_stage.create_children_vqe:
            build_vqe_task = self._run_build_task()
            self.Context.build_vqe_task_id = build_vqe_task
            raise sdk2.WaitTask(
                [build_vqe_task],
                [ctt.Status.Group.FINISH, ctt.Status.Group.BREAK],
                wait_all=True,
            )

        build_task_id = self.Context.build_vqe_task_id
        vqe_resource = sdk2.Resource.find(
            task_id=build_task_id,
            type='ARCADIA_PROJECT',
        ).first()
        # A failed or broken build leaves nothing to measure; fail with a
        # clear message instead of passing None further down.
        if vqe_resource is None:
            raise RuntimeError(
                "Build task {} produced no ARCADIA_PROJECT resource".format(build_task_id)
            )
        return str(sdk2.ResourceData(vqe_resource).path)

    def load_samples(self):
        """Fetch and unpack the input record for every requested record id.

        The ``record_ids`` parameter is split on commas; surrounding whitespace
        is removed and empty entries are skipped.  For each id the matching
        ``VqeInputRecordResource`` (attribute ``record_id``) is located and its
        archive is extracted into ``./input_record_<index>``.

        Returns:
            list of ``SampleInfo`` in the order the ids were given.

        Raises:
            RuntimeError: no resource exists for one of the ids.
        """
        logging.info("Start load wavs")
        out_folder_pattern = "./input_record_{}"
        # Strip first and filter afterwards, so that entries such as " " or a
        # trailing ", " do not turn into an empty record id.
        stripped_ids = [it.strip() for it in self.Parameters.record_ids.split(",")]
        record_ids = [it for it in stripped_ids if it]

        result_samples = []
        for idx, id_ in enumerate(record_ids):
            sound_resource = sdk2.Resource.find(
                VqeInputRecordResource,
                attrs=dict(record_id=id_),
            ).first()
            if sound_resource is None:
                raise RuntimeError(
                    'No VqeInputRecordResource found for record id "{}"'.format(id_)
                )
            src_path = str(sdk2.ResourceData(sound_resource).path)
            dst_path = out_folder_pattern.format(idx)
            self._extract_tar(src_path, dst_path)
            result_samples.append(SampleInfo(id_, dst_path))
        logging.info("Done load wavs")
        return result_samples

    def load_predefined_vqe_binary(self, resource_id):
        """Return the local path of an existing ``VqeBinary`` resource.

        Args:
            resource_id: id of the resource to use.

        Returns:
            Local path (str) of the resource data.

        Raises:
            RuntimeError: no ``VqeBinary`` resource with this id was found
                (for example, the id belongs to a resource of another type).
        """
        logging.info("Start load predefined vqe binary")
        resource = sdk2.Resource.find(
            VqeBinary,
            id=resource_id,
        ).first()
        if resource is None:
            raise RuntimeError(
                "No VqeBinary resource found with id {}".format(resource_id)
            )

        binary_path = str(sdk2.ResourceData(resource).path)
        logging.info("Done load predefined vqe binary")
        return binary_path

    def get_ya_tool(self):
        """Export the ``ya`` tool from Arcadia trunk and return the export result.

        The export target is ``ya`` resolved against the current working
        directory.  The value returned by ``svn.Arcadia.export`` is passed
        through unchanged.
        """
        logging.info("Start get ya tool")
        ya_url = 'arcadia:/arc/trunk/arcadia/ya'
        local_ya = os.path.realpath('ya')
        return svn.Arcadia.export(ya_url, local_ya)

    def run_perf(self, command):
        """Run *command* under ``ya tool perf stat`` and capture its output.

        The command is repeated ``n_runs`` times by perf (``-r``).

        Args:
            command: shell command line to measure.  It is appended verbatim
                to the perf invocation and executed through the shell, so it
                must come from a trusted source (here: task parameters).

        Returns:
            ``(stdout, stderr)`` of the perf process.
        """
        ya = self.get_ya_tool()
        n_runs = str(self.Parameters.n_runs)
        cmd = ya + " tool perf stat -r {n_runs} ".format(n_runs=n_runs) + command
        process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # communicate() drains both pipes while waiting for the process.
        # Calling wait() first can deadlock once the child fills a pipe buffer.
        out, err = process.communicate()
        if process.returncode != 0:
            # Best effort: the caller still gets the output, but a failed run
            # is made visible in the task log.
            logging.warning("perf exited with code %s for command: %s", process.returncode, cmd)
        return out, err

    def on_execute(self):
        """Measure VQE performance on every requested record and publish results.

        Steps:
          1. Obtain the VQE binary: the predefined resource when
             ``vqe_binary_resource_id`` is positive, otherwise a fresh build.
          2. Run the binary under perf for every sample; the raw perf stderr
             of each run is stored as a ``VqeRawPerfResultResource``.
          3. Store all parsed metrics and per-metric averages as
             ``perf_output.json`` in a ``VqePerfResultResource``.
          4. If ``send_result_to_pulsar`` is set, send the metrics to Pulsar.
        """
        logging.info("Start on_execute")
        logging.info("VQE is {}".format(self.Parameters.vqe_arcadia_folder))

        if self.Parameters.vqe_binary_resource_id > 0:
            logging.info("Load existing VQE_BINARY resource with id {}".format(self.Parameters.vqe_binary_resource_id))
            vqe_binary_path = self.load_predefined_vqe_binary(self.Parameters.vqe_binary_resource_id)
        else:
            logging.info("Build vqe binary from arcadia")
            vqe_binary_path = self.build_vqe()

        # TODO: Test that it works
        samples = self.load_samples()
        result_data = self._measure_samples(vqe_binary_path, samples)

        parsed_perf = sdk2.ResourceData(VqePerfResultResource(self, "VQE Perf result", "perf_output.json"))
        parsed_perf.path.write_bytes(json.dumps(result_data, indent=4))

        if self.Parameters.send_result_to_pulsar:
            self._send_to_pulsar(samples, result_data)

        logging.info("Done on_execute")

    def _measure_samples(self, vqe_binary_path, samples):
        """Run perf on every sample and return the collected metrics.

        The result maps ``<sample id>_<metric>`` to the value parsed for that
        sample and ``avg_<metric>`` to the mean over the samples for which the
        metric was actually parsed.
        """
        result_data = {}
        totals = {}
        counts = {}
        for sample in samples:
            cmd = vqe_binary_path + " --dir " + sample.folder + " -p " + self.Parameters.preset + " " + self.Parameters.flags
            _, err = self.run_perf(cmd)
            perf_stderr = sdk2.ResourceData(VqeRawPerfResultResource(
                self,
                "VQE Perf stderr {}".format(sample.id),
                "perf_output_{}.txt".format(sample.id),
            ))
            perf_stderr.path.write_bytes(str(err))

            for key, val in parse_perf_stats(err.split("\n")).items():
                result_data[sample.id + "_" + key] = val
                totals[key] = totals.get(key, 0) + val
                counts[key] = counts.get(key, 0) + 1

        # Divide by the number of samples that reported the metric, so that a
        # sample with missing perf output does not drag the average down.
        for key, total in totals.items():
            result_data["avg_" + key] = float(total) / counts[key]
        return result_data

    def _send_to_pulsar(self, samples, result_data):
        """Send the collected metrics to Pulsar via ``utils.send_to_pulsar``.

        Raises:
            RuntimeError: the Pulsar token read from the secret is empty.
        """
        pulsar_token = self.Parameters.pulsar_token.data()["secret"]
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if not pulsar_token:
            raise RuntimeError("ERROR, EMPTY PULSAR TOKEN!")
        # NOTE(review): presumably the Pulsar client reads the token from this
        # environment variable -- confirm in utils.send_to_pulsar.
        os.environ["PULSAR_TOKEN"] = pulsar_token

        dataset_name = "_".join(sorted(sample.id for sample in samples))

        # Ignore trailing slashes so the name is the last real path component.
        vqe_name = self.Parameters.vqe_arcadia_folder.rstrip("/").split("/")[-1]
        # Strip first and filter afterwards to drop whitespace-only tags.
        stripped_tags = [it.strip() for it in self.Parameters.pulsar_tags.split(",")]
        tags = [it for it in stripped_tags if it]
        utils.send_to_pulsar(
            preset=self.Parameters.preset,
            vqe_name=vqe_name,
            dataset_name=dataset_name,
            n_runs=self.Parameters.n_runs,
            flags=self.Parameters.flags,
            metrics=result_data,
            user_tags=tags,
        )


# Module-level alias for the task class defined in this file.
# NOTE(review): presumably read by the Sandbox task loader -- confirm.
__TASK__ = VqePerformanceOnSingleRecord
