from sandbox import sdk2
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.mt.make.util import mount_arc_with_retries, apply_mtdata_updates, run_mt_make_vh_tool

import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


class EvalNmt(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Evaluate a neural machine translation model by launching a Nirvana workflow.

    The task mounts Arcadia (optionally at ``arcadia_branch``), applies any
    requested MTData updates to the mounted tree, and runs the
    ``eval_nmt`` tool through ``run_mt_make_vh_tool`` in Nirvana mode with
    BLEU as the metric. The ids of the launched workflow and workflow
    instance are published as output parameters.
    """

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        description = "Eval neural machine translation model"
        owner = "MT"

        # E-mail the maintainer whenever the task ends in an unsuccessful state.
        notifications = [
            sdk2.Notification(
                statuses=[
                    ctt.Status.FAILURE,
                    ctt.Status.EXCEPTION,
                    ctt.Status.TIMEOUT
                ],
                recipients=["alexeynoskov@yandex-team.ru"],
                transport=ctn.Transport.EMAIL
            )
        ]

        arcadia_branch = sdk2.parameters.String("Arcadia branch to run code from (empty for trunk)")
        mtdata_updates = sdk2.parameters.String("MTData updates in yaml format")

        direction = sdk2.parameters.String("Direction to evaluate quality", required=True)
        quota = sdk2.parameters.String("Name of nirvana quota", default="mt-eval")
        # The secret payload must contain at least an 'arc-token' key; the
        # whole payload is also forwarded to run_mt_make_vh_tool.
        secret = sdk2.parameters.YavSecret("YAV secret identifier (with optional version)", required=True)

        with sdk2.parameters.Output():
            workflow_id = sdk2.parameters.String("Id of started workflow on Nirvana")
            workflow_instance_id = sdk2.parameters.String("Id of started workflow instance on Nirvana")

    class Requirements(task_env.BuildLinuxRequirements):
        pass

    def on_execute(self):
        """Launch the evaluation graph and publish the resulting Nirvana ids."""
        res = self.build_and_run_graph()
        self.Parameters.workflow_id = res['workflow_id']
        self.Parameters.workflow_instance_id = res['workflow_instance_id']

    def build_and_run_graph(self):
        """Mount Arcadia, apply MTData updates and start the eval_nmt workflow.

        Returns:
            Whatever ``run_mt_make_vh_tool`` returns; ``on_execute`` expects a
            mapping with ``'workflow_id'`` and ``'workflow_instance_id'`` keys.
        """
        # Fetch the secret payload once so the Arc mount and the Nirvana
        # launch use the same snapshot, and YAV is not queried twice.
        secrets = self.Parameters.secret.data()

        with mount_arc_with_retries(
            arc_token=secrets['arc-token'],
            # NOTE(review): an empty value is documented as "trunk"; presumably
            # mount_arc_with_retries handles that — confirm.
            changeset=self.Parameters.arcadia_branch,
        ) as arc_root:
            if self.Parameters.mtdata_updates:
                apply_mtdata_updates(arc_root, self.Parameters.mtdata_updates)

            tool_args = [
                self.Parameters.direction,
                '--mtdata', '@arcadia',
                '--mode', 'nirvana',
                '--directory', '.',
                '--use-local-model',
                '--use-local-vocs',
                '--nmt',
                '--calc-metric', 'BLEU',
            ]
            return run_mt_make_vh_tool(
                "dict/mt/make/tools/eval_nmt/eval_nmt",
                tool_args,
                arcadia_path=arc_root,
                secrets=secrets,
                quota=self.Parameters.quota,
            )
