from sandbox import sdk2

from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import utils
from sandbox.projects.prs_ops import resources
from sandbox.common.types.client import Tag
from sandbox.projects.common.search import bugbanner2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox import common
import sandbox.common.types.task as ctt
from sandbox.projects.prs_ops.CompareYtTablesPrsOps import CompareYtTablesPrsOps
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine.components.configs.prs_ops import PrsOpsCfg
from sandbox.projects.release_machine import rm_notify as rm_notify
from sandbox.projects.common.sdk_compat import task_helper


@rm_notify.notify2()
class PrsOpsFactorDiff(bugbanner2.BugBannerTask):
    """Diff prs_ops output of a candidate binary against the last released one.

    Flow:
      1. Spawn two ``RUN_PRS_OPS`` children on the same queries/ratings
         resources. The "new" child runs ``prs_ops_binary_resource``. The "old"
         child runs the resource returned by
         ``utils.last_resource_with_released_attribute(PRS_OPS_EXECUTABLE)``.
      2. If both succeed, spawn ``COMPARE_YT_TABLES_PRS_OPS`` on their
         ``Context.save_to`` paths.
      3. Fail if the comparison task does not succeed or reports
         ``Context.has_diff``.
    """

    class Requirements(sdk2.Task.Requirements):
        ram = 40 * 1024
        disk_space = 5 * 1024
        # clients in LXC containers don't have host names, only v6 address
        client_tags = Tag.GENERIC & Tag.Group.LINUX & ~Tag.LXC

        environments = [
            PipEnvironment('yandex-yt', version="0.8.29.post0", use_wheel=True),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version="0.3.7.post1", use_wheel=True)
        ]

    class Parameters(sdk2.Task.Parameters):
        # NOTE(review): substitute_ranks, is_queries_full, gdtsf, mode, mr_server,
        # disable_samohod and write_mode are declared here but are not forwarded
        # to the RUN_PRS_OPS children by on_execute -- confirm whether intended.
        kill_timeout = 5 * 60 * 60
        substitute_ranks = sdk2.parameters.Bool("substitute-ranks", default=True)
        is_queries_full = sdk2.parameters.Bool("is queries-format-full", default=True)
        gdtsf = sdk2.parameters.Bool("enable gather-dynamic-top-sizes-factors", default=False)
        customCGI = sdk2.parameters.String("customCGI")
        selected_slices = sdk2.parameters.String("selected-slices")
        mode = sdk2.parameters.String("Mode to run", default="COMBO", required=True)
        mr_server = sdk2.parameters.String("mr-server", default="hahn.yt.yandex.net")
        save_to = sdk2.parameters.String(
            "result folder on server (dir in //home/prs_ops/)",
            default="111111"
        )
        disable_samohod = sdk2.parameters.Bool("disable samohod on middle", default=True)
        write_mode = sdk2.parameters.String("write mode", default="mr-tsv")
        success = sdk2.parameters.Float("success threshold", default=0.99)
        args = sdk2.parameters.String("additional command line parameters:")

        prs_ops_binary_resource = sdk2.parameters.Resource(
            "prs_ops executable",
            # resourse_type=resources.PRS_OPS_EXECUTABLE,  # FIXME: invalid argument (SANDBOX-6404)
            required=True,
        )
        prs_ops_queries_resource = sdk2.parameters.Resource(
            "prs_ops queries",
            # resourse_type=resources.PRS_OPS_QUERIES,  # FIXME: invalid argument (SANDBOX-6404)
            required=True,
        )
        prs_ops_ratings_resource = sdk2.parameters.Resource(
            "prs_ops ratings",
            # resourse_type=resources.PRS_OPS_RATINGS,  # FIXME: invalid argument (SANDBOX-6404)
            required=True,
        )
        check_diff = sdk2.parameters.Resource(
            "check_diff executable",
            # resourse_type=resources.PRS_OPS_FACTOR_DIFF,  # FIXME: invalid argument (SANDBOX-6404)
        )

    def on_enqueue(self):
        # Make the result folder unique per task run.
        self.Context.save_to = "{}_id={}".format(self.Parameters.save_to, self.id)
        task_helper.ctx_field_set(self, rm_const.COMPONENT_CTX_KEY, PrsOpsCfg.name)

    def on_execute(self):
        """Run both prs_ops children, then diff their outputs.

        Each stage raises ``sdk2.WaitTask``; the memoize stages make sure each
        set of children is created only once across re-executions.
        """
        self.add_bugbanner(bugbanner2.Banners.PrsOps)
        waited_statuses = set(common.utils.chain(ctt.Status.Group.FINISH, ctt.Status.Group.BREAK))

        with self.memoize_stage.create_tables:
            new_child, old_child = self._create_run_prs_ops_children()
            # Remember the ids explicitly: ``self.find()`` gives no ordering
            # guarantee and later also returns the compare task, so positional
            # access into its result is unreliable.
            self.Context.new_child_id = new_child.id
            self.Context.old_child_id = old_child.id
            raise sdk2.WaitTask([new_child, old_child], waited_statuses)

        new_child = sdk2.Task[self.Context.new_child_id]
        old_child = sdk2.Task[self.Context.old_child_id]
        if not all(child.status in ctt.Status.Group.SUCCEED for child in (new_child, old_child)):
            eh.check_failed("At least one of the task failed")

        with self.memoize_stage.compare_tables:
            # NOTE(review): the original picked path_1/path_2 from find() order,
            # which is unspecified; here path_1 is the released baseline and
            # path_2 the candidate -- confirm against COMPARE_YT_TABLES_PRS_OPS.
            compare_child = self._create_compare_task(
                old_child.Context.save_to,
                new_child.Context.save_to,
            )
            self.Context.compare_task_id = compare_child.id
            raise sdk2.WaitTask([compare_child], waited_statuses)

        compare_child = sdk2.Task[self.Context.compare_task_id]
        if compare_child.status not in ctt.Status.Group.SUCCEED:
            eh.check_failed("Diff task has a problem")
        if compare_child.Context.has_diff:
            eh.check_failed("Diff task has diff")

    def _create_run_prs_ops_children(self):
        """Create and enqueue the two RUN_PRS_OPS children.

        The released binary is looked up first so that a missing release fails
        the task before any child is spawned.

        :return: ``(new_child, old_child)`` -- the first runs
            ``prs_ops_binary_resource``, the second the released
            ``PRS_OPS_EXECUTABLE``.
        """
        released_binary = utils.last_resource_with_released_attribute(resources.PRS_OPS_EXECUTABLE)
        if released_binary is None:
            eh.check_failed("No released PRS_OPS_EXECUTABLE resource found")

        run_prs_ops = sdk2.Task["RUN_PRS_OPS"]
        # Parameters shared by both children; only the binary, the output
        # suffix and the description differ.
        common_params = dict(
            owner=self.owner,
            customCGI=self.Parameters.customCGI,
            selected_slices=self.Parameters.selected_slices,
            success=self.Parameters.success,
            args=self.Parameters.args,
            prs_ops_queries_resource=self.Parameters.prs_ops_queries_resource,
            prs_ops_ratings_resource=self.Parameters.prs_ops_ratings_resource,
        )
        # Output suffixes now match the child descriptions (they were swapped).
        new_child = run_prs_ops(
            self,
            description="new child",
            save_to=self.Parameters.save_to + "_new",
            prs_ops_binary_resource=self.Parameters.prs_ops_binary_resource,
            **common_params
        ).enqueue()
        old_child = run_prs_ops(
            self,
            description="old child",
            save_to=self.Parameters.save_to + "_old",
            prs_ops_binary_resource=released_binary.id,
            **common_params
        ).enqueue()
        return new_child, old_child

    def _create_compare_task(self, path_1, path_2):
        """Create and enqueue COMPARE_YT_TABLES_PRS_OPS on ``path_1`` vs ``path_2``."""
        compare_task = sdk2.Task["COMPARE_YT_TABLES_PRS_OPS"]
        return compare_task(
            self,
            description="diff from tasks {}".format(self.id),
            owner=self.owner,
            path_1=path_1,
            path_2=path_2,
            check_diff=self.Parameters.check_diff,
        ).enqueue()
