from sandbox import sdk2
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.mt.make.util import mount_arc_with_retries, apply_mtdata_updates, run_mt_make_vh_tool

import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


class TrainNmt(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Launch an NMT training workflow on Nirvana via the train_nmt tool from Arcadia."""

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        description = "Train neural machine translation model"
        owner = "MT"

        # Email the maintainer whenever the task ends in a failed state.
        notifications = [
            sdk2.Notification(
                statuses=[
                    ctt.Status.FAILURE,
                    ctt.Status.EXCEPTION,
                    ctt.Status.TIMEOUT
                ],
                recipients=["alexeynoskov@yandex-team.ru"],
                transport=ctn.Transport.EMAIL
            )
        ]

        arcadia_branch = sdk2.parameters.String("Arcadia branch to run code from (empty for trunk)")
        mtdata_updates = sdk2.parameters.String("MTData updates in yaml format")

        direction = sdk2.parameters.String("Direction to train", required=True)
        quota = sdk2.parameters.String("Name of nirvana quota", default="mt")
        # The secret payload must contain an 'arc-token' key (used to mount Arcadia);
        # the whole payload is also forwarded to the train_nmt tool.
        secret = sdk2.parameters.YavSecret("YAV secret identifier (with optional version)", required=True)

        label = sdk2.parameters.String("Workflow label")
        ticket = sdk2.parameters.StrictString("Startrek ticket", regexp='[A-Z]{2,}-[0-9]+')
        build_vocs = sdk2.parameters.Bool("Build new vocabularies")
        eval_quality = sdk2.parameters.Bool("Evaluate NMT quality after train")
        early_stopping = sdk2.parameters.Bool("Use stricter autostop profile")

        with sdk2.parameters.Output():
            workflow_id = sdk2.parameters.String("Id of started workflow on Nirvana")
            workflow_instance_id = sdk2.parameters.String("Id of started workflow instance on Nirvana")

    class Requirements(task_env.BuildLinuxRequirements):
        pass

    def on_execute(self):
        """Start the training workflow and publish its Nirvana ids as output parameters."""
        res = self.build_and_run_graph()
        self.Parameters.workflow_id = res['workflow_id']
        self.Parameters.workflow_instance_id = res['workflow_instance_id']

    def _build_train_args(self):
        """Translate the task parameters into train_nmt command-line arguments.

        Returns:
            list of str: the positional direction, the fixed mtdata/mode options,
            then one option per enabled parameter, in a fixed order.
        """
        params = self.Parameters
        args = [params.direction, "--mtdata", "@arcadia", "--mode", "nirvana"]
        if params.label:
            args += ['--workflow-label', params.label]
        if params.ticket:
            args += ['--ticket', params.ticket]
        if params.build_vocs:
            args += ['--build-vocs']
        if params.eval_quality:
            args += ['--eval-nmt']
        if params.early_stopping:
            args += ['--early-stopping']
        return args

    def build_and_run_graph(self):
        """Mount Arcadia, apply optional MTData updates and run train_nmt.

        Arcadia is mounted at ``arcadia_branch`` (passed as the changeset).
        If ``mtdata_updates`` is set, it is applied inside the mounted tree
        before the tool runs.

        Returns:
            The result of ``run_mt_make_vh_tool``. ``on_execute`` reads its
            'workflow_id' and 'workflow_instance_id' keys.
        """
        train_args = self._build_train_args()

        # Read the secret payload once. This avoids a second fetch and guarantees
        # that the arc token and the secrets given to the tool come from the same read.
        secrets = self.Parameters.secret.data()

        with mount_arc_with_retries(
            arc_token=secrets['arc-token'],
            changeset=self.Parameters.arcadia_branch,
        ) as arc_root:
            if self.Parameters.mtdata_updates:
                apply_mtdata_updates(arc_root, self.Parameters.mtdata_updates)

            # Run while the mount is still active: the tool is resolved from arc_root.
            return run_mt_make_vh_tool(
                "dict/mt/make/tools/train_nmt/train_nmt",
                train_args,
                arcadia_path=arc_root,
                secrets=secrets,
                quota=self.Parameters.quota,
            )
