import json
import logging
import shutil

from sandbox import sdk2
import sandbox.common.types.resource as ctr
from sandbox.common import errors
import sandbox.common.types.task as ctt
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.sandboxsdk import environments as envs

from sandbox.projects.ads.eshow.common.utils import get_binary_path, get_yt_client
from sandbox.projects.ads.eshow.resources import AdsReachZcCalculatorBinaryV2

from sandbox.projects.ads.eshow.calculate_reach_zc.lib.yttools import read_yt_table_to_file
from sandbox.projects.ads.eshow.calculate_reach_zc.lib.constants import ZC_APC_METHODS, MODE_RUN_APC_CHECK
from sandbox.projects.ads.eshow.calculate_reach_zc.read_table_from_yt import AdsReachZcReadPremappedTableFromYt

# Number of cores requested for the task; also passed to the calculator as
# --threads and used to derive the RAM requirement.
CORES_REQ = 32
# File (relative to the working directory) that the calculator is told to dump
# its zC result to via --dump-zc-to; it is read back after the subprocess ends.
TMP_JSON = "zc_result.json"


class AdsCalculateReachZcIntermediateExpData(sdk2.Resource):
    """Temporary resource holding the exp data file fed to the zC calculator."""


class AdsRunApcCheckOnPreparedTable(sdk2.Task):
    """Run apc_check to calculate zC on premapped table.

    The task waits for ``data_download_task`` to finish, builds a temporary
    resource containing that task's etalon data followed by the rows of
    ``premapped_table`` read from YT, runs the zC calculator binary on it and
    publishes the reported expectation/stddev as output parameters.
    """

    class Requirements(sdk2.Task.Requirements):
        cores = CORES_REQ
        disk_space = 4 << 10
        # RAM requirement scales with the requested core count.
        ram = CORES_REQ << 10
        environments = [envs.PipEnvironment("yandex-yt")]

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 4 * 60 * 60

        binary_resource = sdk2.parameters.Resource(
            "Reach zC calculator binary (if none - will find latest released)",
            resource_type=AdsReachZcCalculatorBinaryV2,
            state=(ctr.State.READY, )
        )
        secret_name = sdk2.parameters.YavSecret(
            "Secret in YAV with that can access yt. Use #key for key",
            required=True
        )

        with sdk2.parameters.Group("Data") as data_params:
            premapped_table = sdk2.parameters.String(
                "Premapped table on YT",
                required=True
            )
            data_download_task = sdk2.parameters.Task(
                "Task for downloading etalon data",
                task_type=AdsReachZcReadPremappedTableFromYt,
                required=True
            )

        with sdk2.parameters.Group("Calculation Parameters") as calc_params:
            zc_method = sdk2.parameters.String(
                "zC calculation method",
                required=True, default="bids_ratio"
            )
            zc_min_conversion_rate = sdk2.parameters.Float(
                "Minimal conversion rate for the group",
                default=0.0005
            )
            zc_min_group_size = sdk2.parameters.Integer(
                "Minimal group size",
                default=1
            )
            zc_bootstrap_iterations = sdk2.parameters.Integer(
                "Bootstrap iterations",
                default=1000
            )

        with sdk2.parameters.Group("YT Parameters") as yt_params:
            yt_proxy = sdk2.parameters.String(
                "YT proxy",
                required=True, default="hahn"
            )

        with sdk2.parameters.Output():
            zc_mean = sdk2.parameters.Float("zC mean")
            zc_std = sdk2.parameters.Float("zC std")

    def on_save(self):
        """Reject unknown zC methods (anything not listed in ZC_APC_METHODS)."""
        assert self.Parameters.zc_method in ZC_APC_METHODS, "Unknown zC calculation method: \"{}\"".format(
            self.Parameters.zc_method
        )

    def on_execute(self):
        """Wait for the etalon data download, then calculate zC.

        The waiting stage may raise ``sdk2.WaitTask`` and is allowed to run
        twice, so it is re-entered once after the task wakes up. The
        calculation stage uses the default memoize_stage settings.
        """
        with self.memoize_stage.waiting_for_etalon_data(max_runs=2):
            self._check_etalon_data_task()

        with self.memoize_stage.calculating_zc:
            self._calculate_zc()

    def _check_etalon_data_task(self):
        """Fail, proceed or go to sleep depending on the download task status.

        :raises errors.TaskFailure: if the download task broke, failed or was deleted
        :raises sdk2.WaitTask: if the download task has not finished yet
        """
        download_task = self.Parameters.data_download_task
        logging.info("Checking data download task ({}) status".format(download_task.id))
        status = download_task.status

        # Failure-like statuses are checked before the FINISH test so a failed
        # or deleted download is never taken as completed.
        if status in ctt.Status.Group.BREAK or status in (ctt.Status.FAILURE, ctt.Status.DELETED):
            raise errors.TaskFailure("Etalon data download task failed, can't proceed")
        if status in ctt.Status.Group.FINISH:
            logging.info("Etalon data download is completed")
            return

        logging.info("Data downloading task in progress, waiting")
        raise sdk2.WaitTask(
            [download_task],
            (ctt.Status.Group.FINISH, ctt.Status.Group.BREAK),
            wait_all=True
        )

    def _calculate_zc(self):
        """Prepare exp data, run the calculator binary and store its result."""
        binary_path = get_binary_path(self.Parameters.binary_resource, AdsReachZcCalculatorBinaryV2)
        yt_client = get_yt_client(self.Parameters, config={"read_retries": {"enable": False}})
        exp_data_path = self._prepare_exp_data(yt_client)

        cmd = self._get_arguments(binary_path, exp_data_path)
        logging.info("Running zc calculation subroutine with command: {}".format(" ".join(cmd)))

        with sdk2.helpers.ProcessLog(self, logger="zc_calculator") as pl:
            sp.check_call(cmd, stdout=pl.stdout, stderr=pl.stderr)

        self._store_result()

    def _prepare_exp_data(self, yt_client):
        """Create and fill the temporary exp data resource.

        The file starts with a byte-for-byte copy of the ``dst_resource`` of
        ``data_download_task``. The rows of ``premapped_table`` are then
        appended to it without a header. The resource is marked ready before
        returning.

        :param yt_client: YT client used to read ``premapped_table``
        :return: path of the filled resource file, as a string
        """
        exp_data_resource = AdsCalculateReachZcIntermediateExpData(
            self, "Temporary exp data from reach zC", "exp_data", ttl=3
        )
        # Build the ResourceData once so the write path, the path given to the
        # binary and the ready() call all refer to the same data object.
        exp_data = sdk2.ResourceData(exp_data_resource)
        exp_data_path = str(exp_data.path)

        logging.info("Copying etalon data")
        etalon_path = str(sdk2.ResourceData(self.Parameters.data_download_task.Parameters.dst_resource).path)
        with open(etalon_path, "rb") as src, open(exp_data_path, "wb") as dst:
            shutil.copyfileobj(src, dst)

        logging.info("Appending exp data")
        read_yt_table_to_file(
            self.Parameters.premapped_table, exp_data_path, yt_client,
            header=False, append=True
        )
        logging.info("Successfully read table to file")
        exp_data.ready()
        return exp_data_path

    def _store_result(self):
        """Parse TMP_JSON written by the calculator into the output parameters.

        An empty (falsy) result is logged as an error and both outputs are set
        to None.
        """
        with open(TMP_JSON, "r") as fh:
            result = json.load(fh)

        if result:
            logging.info("Successfully parsed output: {}".format(result))
            self.Parameters.zc_mean, self.Parameters.zc_std = result["expectation"], result["stddev"]
        else:
            logging.error("Failed to get result from subprocess")
            self.Parameters.zc_mean, self.Parameters.zc_std = None, None

    @sdk2.header(title="zC calculations results")
    def report(self):
        """Render the zC outputs for the task page header.

        The outputs are compared against None explicitly, so a legitimate 0.0
        mean or stddev is still rendered instead of being reported as missing.
        """
        zc_mean, zc_std = self.Parameters.zc_mean, self.Parameters.zc_std
        if zc_mean is None or zc_std is None:
            return "<h2>Nothing calculated yet or calculations failed</h2>"
        return "\n".join((
            "<h2>zC</h2>",
            "<h3>{} +/- {}</h3>".format(zc_mean, zc_std)
        ))

    def _get_arguments(self, binary_path, log_file):
        """Build the calculator command line.

        :param binary_path: path to the zC calculator executable
        :param log_file: path to the exp data file to process
        :return: argv list; the binary is told to dump its result to TMP_JSON
        """
        cmd = [
            binary_path,
            "--exec-mode", MODE_RUN_APC_CHECK,
            "--filepath", log_file,
            "--method", self.Parameters.zc_method,
            "--threads", str(CORES_REQ),
            "--min-conversion-rate", str(self.Parameters.zc_min_conversion_rate),
            "--min-group-size", str(self.Parameters.zc_min_group_size),
            "--bootstrap-iterations", str(self.Parameters.zc_bootstrap_iterations),
            "--dump-zc-to", TMP_JSON
        ]
        return cmd
