from __future__ import print_function
import json
import logging
import os
import sys
import yaml
import time
import traceback

from sandbox import common, sdk2
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.common.vcs.arc import Arc
from sandbox.projects.mt.make.util import MountArc, run_mt_make_tool, run_mt_make_test, run_mt_make_vh_tool

import sandbox.common.types.client as ctc
import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


class PrepareNmtRun(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Sandbox task that updates the MT DB config for a direction and commits it to Arcadia.

    Mounts Arcadia at ``source_arcadia_branch``, checks out ``arcadia_branch``, optionally
    runs the ``update_db_config`` tool and the DB tests inside one ``arc.commit`` context,
    publishes Arcanum links for the result and optionally launches corpora scoring.
    """

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        description = "Update db config for direction unsing update_db_config tool."
        owner = "MT"

        notifications = [
            sdk2.Notification(
                statuses=[
                    ctt.Status.FAILURE,
                    ctt.Status.EXCEPTION,
                    ctt.Status.TIMEOUT
                ],
                recipients=["dronte@yandex-team.ru"],
                transport=ctn.Transport.EMAIL
            )
        ]

        direction = sdk2.parameters.String("Direction to update corpora config for", required=True)
        run_update_db_config_tool = sdk2.parameters.Bool("Run update_db_config tool", default=True)
        multilingual = sdk2.parameters.Bool("Direction as multilingual")
        update_kwyt_corpora = sdk2.parameters.Bool("Update kwyt_corpora")
        update_opus_corpora = sdk2.parameters.Bool("Update opus_corpora")
        update_commoncrawl_corpora = sdk2.parameters.Bool("Update commoncrawl_corpora")

        update_nmt_model = sdk2.parameters.Bool("Update nmt_model")

        train_corpora_scored = sdk2.parameters.Bool("Update train_corpora_scored")
        train_corpora_tags = sdk2.parameters.Bool("Update train_corpora_tags")
        laser_score_lower_bound = sdk2.parameters.Float("laser score lower bound", default=None)
        laser_score_upper_bound = sdk2.parameters.Float("laser score upper bound", default=None)
        dcce_score_lower_bound = sdk2.parameters.Float("dcce score lower bound", default=None)
        dcce_score_upper_bound = sdk2.parameters.Float("dcce score upper bound", default=None)
        labse_score_lower_bound = sdk2.parameters.Float("labse score lower bound", default=None)
        labse_score_upper_bound = sdk2.parameters.Float("labse score upper bound", default=None)

        lang_id_fast_text_min_src_prob = sdk2.parameters.Float("lang id filter min src prob", default=None)
        lang_id_fast_text_min_dst_prob = sdk2.parameters.Float("lang id filter min dst prob", default=None)

        bicleaner_min_score = sdk2.parameters.Float("Min score for BiCleaner filter", default=None)

        train_options = sdk2.parameters.String("Python expression for train options (better use single quotes if need)", default=None)

        sed_pattern = sdk2.parameters.String("Sed patthern. This string(or regexp) will be replaced by sed_repl.")
        sed_repl = sdk2.parameters.String("String to replace sed_pattern with. Works only if sed_pattern is specified.")

        source_arcadia_branch = sdk2.parameters.String("Arcadia branch to use_as_base", default='trunk')
        arcadia_branch = sdk2.parameters.String("Arcadia branch to commit changes to", required=True)
        commit_message = sdk2.parameters.String("Commit message", required=True)
        allow_empty_diff = sdk2.parameters.Bool("Allow empty diff on commit", default=False)

        run_tests = sdk2.parameters.Bool("Run DB tests after update", default=True)
        run_score_corpora = sdk2.parameters.Bool("Run score corpora on direction")

        secret = sdk2.parameters.YavSecret("YAV secret identifier (with optional version)", required=True)

        with sdk2.parameters.Output:
            branch_diff_arcanum_link = sdk2.parameters.String("Arcanum link to trunk-branch diff")
            branch_history_arcanum_link = sdk2.parameters.String("Arcanum link to branch history")

            commit_hash = sdk2.parameters.String("Arcadia commit hash")
            commit_arcanum_link = sdk2.parameters.String("Arcanum link to commit")

            scoring_workflow_id = sdk2.parameters.String("Nirvana workflow id")
            scoring_workflow_instance_id = sdk2.parameters.String("Nirvana workflow instance")
            scoring_workflow_link = sdk2.parameters.String("Nirvana workflow link to nirvana")

    class Requirements(task_env.TinyRequirements):
        pass

    def run_update_db_config_tool(self, arc_root):
        """Run the ``update_db_config`` tool against the DB content in the mounted Arcadia.

        Command-line options are built from the task parameters; only options that are
        enabled or explicitly set are forwarded.

        :param arc_root: path where Arcadia is mounted.
        """
        parameters = [
            self.Parameters.direction,
            "--db-content-path", os.path.join(arc_root, 'dict', 'mt', 'make', 'db', 'content'),
        ]

        if self.Parameters.multilingual:
            parameters.append('--multilingual')

        # Every enabled key adds its own `--key-to-update <key>` pair.
        keys_to_update = (
            (self.Parameters.update_kwyt_corpora, 'kwyt_corpora'),
            (self.Parameters.update_opus_corpora, 'opus_corpora'),
            (self.Parameters.update_commoncrawl_corpora, 'commoncrawl_corpora'),
            (self.Parameters.update_nmt_model, 'nmt_model'),
        )
        for enabled, key in keys_to_update:
            if enabled:
                parameters += ['--key-to-update', key]

        if self.Parameters.train_corpora_scored:
            parameters.append('--train-corpora-scored')

        if self.Parameters.train_corpora_tags:
            parameters.append('--train-corpora-tags')

        # Numeric thresholds are forwarded only when set; `is not None` keeps a
        # deliberate 0.0 threshold.
        numeric_options = (
            ('--laser-score-lower-bound', self.Parameters.laser_score_lower_bound),
            ('--laser-score-upper-bound', self.Parameters.laser_score_upper_bound),
            ('--dcce-score-lower-bound', self.Parameters.dcce_score_lower_bound),
            ('--dcce-score-upper-bound', self.Parameters.dcce_score_upper_bound),
            ('--labse-score-lower-bound', self.Parameters.labse_score_lower_bound),
            ('--labse-score-upper-bound', self.Parameters.labse_score_upper_bound),
            ('--lang-id-fast-text-min-src-prob', self.Parameters.lang_id_fast_text_min_src_prob),
            ('--lang-id-fast-text-min-dst-prob', self.Parameters.lang_id_fast_text_min_dst_prob),
            ('--bicleaner-min-score', self.Parameters.bicleaner_min_score),
        )
        for flag, value in numeric_options:
            if value is not None:
                parameters += [flag, str(value)]

        # The options below are independent of the BiCleaner threshold and must be
        # forwarded whenever they are set. An empty train-options expression is
        # treated as unset rather than passed through.
        if self.Parameters.train_options:
            parameters += ['--train-options', str(self.Parameters.train_options)]

        if self.Parameters.sed_pattern:
            parameters += ['--sed-pattern', self.Parameters.sed_pattern]
            # An unset replacement means "replace with nothing"; never put None into argv.
            parameters += ['--sed-repl', self.Parameters.sed_repl or '']

        run_mt_make_tool(
            "dict/mt/make/tools/update_db_config/update_db_config",
            parameters,
            arcadia_path=arc_root,
            secrets=self.Parameters.secret.data()
        )

    def on_execute(self):
        """Check out the target branch, commit the DB config update and publish result links."""
        with MountArc(
            branch=self.Parameters.source_arcadia_branch,
            create_branch_if_not_exists=False,
            arc_token=self.Parameters.secret.data()['arc-token']
        ) as arc:
            # When the target branch differs from the source one, checkout is asked to
            # fail if the target branch already exists.
            fail_if_branch_exists = self.Parameters.source_arcadia_branch != self.Parameters.arcadia_branch
            arc.checkout(self.Parameters.arcadia_branch, fail_if_branch_exists=fail_if_branch_exists)

            self.Parameters.branch_history_arcanum_link = 'https://a.yandex-team.ru/arc_vcs/history/?peg=' + self.Parameters.arcadia_branch

            arc_root = arc.mount_path
            # Changes made inside this context are committed with `commit_message` on exit;
            # if the commit itself raises, the exception propagates out of on_execute.
            with arc.commit(message=self.Parameters.commit_message,
                            allow_empty=self.Parameters.allow_empty_diff):

                if self.Parameters.run_update_db_config_tool:
                    self.run_update_db_config_tool(arc_root)

                if self.Parameters.run_tests:
                    run_mt_make_test(
                        "dict/mt/make/db/tests",
                        ['--test-size-timeout=small=300'],
                        arcadia_path=arc_root,
                        secrets=self.Parameters.secret.data()
                    )

            self._publish_commit_links(arc)

            # Scoring reads the mounted Arcadia, so it must stay inside MountArc.
            if self.Parameters.run_score_corpora:
                self._run_score_corpora(arc_root)

    def _publish_commit_links(self, arc):
        """Store the resulting commit hash and Arcanum links for the commit and branch diff."""
        commit_hash = arc.last_commit_hash
        self.Parameters.commit_hash = commit_hash

        # The hash may be empty (e.g. nothing was committed); the diff link falls back to
        # the mount head, and the commit link is only built when there is a hash.
        if commit_hash:
            self.Parameters.commit_arcanum_link = 'https://a.yandex-team.ru/arc_vcs/commit/' + commit_hash

        self.Parameters.branch_diff_arcanum_link = (
            'https://a.yandex-team.ru/arc_vcs/diff/?prevRev=' + arc.merge_base
            + '&rev=' + (commit_hash or arc.on_mount_head)
        )

    def _run_score_corpora(self, arc_root):
        """Launch the score_corpora workflow for the direction and publish its ids and link."""
        workflow_info = run_mt_make_vh_tool(
            "dict/mt/make/tools/score_corpora/score_corpora",
            [self.Parameters.direction, '--mtdata', '@arcadia'],
            arcadia_path=arc_root,
            secrets=self.Parameters.secret.data()
        )
        workflow_id = workflow_info['workflow_id']
        instance_id = workflow_info['workflow_instance_id']

        self.Parameters.scoring_workflow_id = workflow_id
        self.Parameters.scoring_workflow_instance_id = instance_id

        # Both ids are needed to build the link; avoid concatenating None.
        if workflow_id and instance_id:
            self.Parameters.scoring_workflow_link = (
                "https://nirvana.yandex-team.ru/flow/" + workflow_id + "/" + instance_id + "/graph")
