from sandbox import sdk2
from sandbox.projects.common import constants as sandbox_constants
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.resource_types import ARCADIA_PROJECT, OTHER_RESOURCE
from sandbox.projects.voicetech.common.nirvana import Workflow as NirvanaWorkflow
from sandbox.sandboxsdk.errors import SandboxTaskFailureError


import json
import logging
import os
import sandbox.projects.voicetech.common.asr_utils as asr_utils
import sandbox.projects.voicetech.resource_types as voicetech_resource_types
import subprocess as sp


class RegularMetricsRunAsrCli(sdk2.Task):
    """Build the ``run_metrics`` tool, launch it and wait for its Nirvana graph.

    Execution goes through three steps:

    1. ``build_stage``: a ``KOSHER_YA_MAKE`` subtask builds
       ``voicetech/asr/tools/regular_metrics_run/run/vh`` and the task waits for it.
    2. ``run_metrics``: the built binary is started with the selected config and
       lingware; the workflow instance info is parsed from its stderr and stored
       in ``self.Context.workflow_instance_info``.
    3. The Nirvana workflow instance is polled via
       ``asr_utils.wait_for_nirvana_results``; on success the output of the
       ``final_metrics`` block is saved as an ``OTHER_RESOURCE`` (``metrics.json``).
    """

    # Value of the ``lingware_type`` parameter -> lingware resource type.
    # The resource type's ``arcadia_build_path`` is passed to the tool via ``--lingware``.
    _RESOURCE_MAP = {
        'RU_DIALOG_GENERAL_E2E': voicetech_resource_types.VOICETECH_ASR_RU_RU_DIALOGENERALGPU,
        'RU_QUASAR_GENERAL_E2E': voicetech_resource_types.VOICETECH_ASR_RU_RU_QUASARGENERALGPU,
        'RU_TV_GENERAL_E2E': voicetech_resource_types.VOICETECH_ASR_RU_RU_TVGENERALGPU,
        'MULTITOPIC_RU_QUASAR_TV_E2E': voicetech_resource_types.VOICETECH_ASR_MULTITOPIC_RU_RU_QUASARGENERALGPU_RU_RU_TVGENERALGPU,
        'RU_CHATS_E2E': voicetech_resource_types.VOICETECH_ASR_RU_RU_CHATSGPU,
        'RU_CHATS_E2E_V2': voicetech_resource_types.VOICETECH_ASR_RU_RU_CHATSGPU_V2,
        'MULTITOPIC_RU_TR_DIALOG_MAPS_E2E': voicetech_resource_types.VOICETECH_ASR_MULTITOPIC_RU_RU_DIALOGMAPSGPU_TR_TR_DIALOGMAPSGPU,
        'MULTITOPIC_CHATS_E2E': voicetech_resource_types.VOICETECH_ASR_MULTITOPIC_CHATS,
        'BIO_QUASAR': voicetech_resource_types.VOICETECH_BIO_QUASAR,
        'TEST_LINGWARE': voicetech_resource_types.VOICETECH_ASR_TEST_LINGWARE,
        'INTL_QUASAR_GENERAL_E2E': voicetech_resource_types.VOICETECH_ASR_INTL_QUASARGENERALGPU,
        'MULTILANG_VIDEO_ZEN': voicetech_resource_types.VOICETECH_ASR_ZEN_MULTILANG_VIDEO,
    }

    class Parameters(sdk2.Task.Parameters):
        # NOTE(review): presumably seconds (5 hours) -- confirm against the sdk2 docs.
        kill_timeout = 18000

        # Arcadia source to build the tool from; also forwarded to the tool itself.
        checkout_arcadia_from_url = sdk2.parameters.String(sandbox_constants.ARCADIA_URL_KEY, required=True)
        arcadia_patch = sdk2.parameters.String(sandbox_constants.ARCADIA_PATCH_KEY, required=False)

        # Polling settings handed to asr_utils.wait_for_nirvana_results.
        checks_limit = sdk2.parameters.Integer('Check nirvana graph times (limit)', default=30, required=True)
        checks_period = sdk2.parameters.Integer('Check nirvana graph period (seconds)', default=10 * 60, required=True)

        # The config comes either from Arcadia (config_path, takes precedence)
        # or inline (config_data); at least one of them must be set.
        config_path = sdk2.parameters.String(
            'Arcadia path to config',
            default='voicetech/asr/tools/regular_metrics_run/configs/ru_quasar_general_e2e-accept.json',
            required=False
        )

        config_data = sdk2.parameters.JSON(
            'Config as json',
            default={},
            required=False
        )

        # Must be one of the keys of _RESOURCE_MAP.
        lingware_type = sdk2.parameters.String(
            'Lingware type',
            default='RU_QUASAR_GENERAL_E2E',
            required=True
        )

        yt_proxy = sdk2.parameters.String('Yt proxy', default='hahn', required=True)

        nirvana_token_vault = sdk2.parameters.String(
            'Nirvana oauth token vault name',
            default='robot-acoustic-team-nirvana-token',
            required=True,
        )

    def _run_metrics_tool(self, nirvana_token, arcadia_url, arcadia_patch, lingware_resource_type):
        """Run the built ``run_metrics`` binary and return everything it wrote to stderr.

        :param nirvana_token: token exported to the tool as ``NIRVANA_TOKEN``.
        :param arcadia_url: value for ``--checkout-arcadia-from-url``; also the source
            of the config file when ``config_path`` is set.
        :param arcadia_patch: optional value for ``--arcadia-patch``.
        :param lingware_resource_type: resource type taken from ``_RESOURCE_MAP``.
        :return: stderr text of the tool.
        :raises SandboxTaskFailureError: if the binary resource is missing, the tool
            cannot be started, or it exits with a non-zero code.
        """
        build_task_id = self.Context.build_run_metrics_bin_sub_task_id
        run_metrics_resource = sdk2.Resource.find(
            type=ARCADIA_PROJECT.name, task_id=build_task_id).first()
        if run_metrics_resource is None:
            # E.g. the build subtask failed; fail explicitly instead of passing
            # None to sdk2.ResourceData.
            raise SandboxTaskFailureError(
                'Build subtask #{} produced no {} resource'.format(build_task_id, ARCADIA_PROJECT.name))
        run_metrics_bin = sdk2.ResourceData(run_metrics_resource)

        run_metrics_stderr = "run_metrics_stderr"
        os.environ['NIRVANA_TOKEN'] = nirvana_token
        os.environ['YT_PROXY'] = self.Parameters.yt_proxy

        config_path = 'config.json'

        if self.Parameters.config_path:
            asr_utils.svn_export_file_or_dir(arcadia_url, self.Parameters.config_path, config_path)
        else:
            with open(config_path, 'w') as fp:
                json.dump(self.Parameters.config_data, fp)

        with sdk2.helpers.ProcessLog(self, logger=logging.getLogger("run_metrics")) as pl:
            args = [
                str(run_metrics_bin.path),
                config_path,
                '--lingware', 'inplace_lingware:' + lingware_resource_type.arcadia_build_path,
                '--checkout-arcadia-from-url', arcadia_url,
            ]
            if arcadia_patch:
                args.extend(['--arcadia-patch', arcadia_patch])
            if getattr(lingware_resource_type, 'zen', False):
                args.extend(['--zen'])

            # The joined string is for the log only; the process gets the argv list.
            logging.info('Run: {cmd}'.format(cmd=' '.join(args)))

            # No shell: user-supplied URL / patch values reach the tool verbatim, so
            # characters such as '&', ';' or spaces cannot break or hijack the command.
            # stderr goes to a file because the workflow info is parsed from it later.
            with open(run_metrics_stderr, 'w') as stderr_fp:
                try:
                    process = sp.Popen(args, stdout=pl.stdout, stderr=stderr_fp)
                except OSError:
                    logging.exception('Failed to start run_metrics binary')
                    raise SandboxTaskFailureError("got internal error on regular metrics run")
                result_code = process.wait()

            with open(run_metrics_stderr) as fp:
                stderr_text = fp.read()

            if result_code != 0:
                logging.error(stderr_text)
                raise SandboxTaskFailureError("got internal error on regular metrics run")

        return stderr_text

    def on_execute(self):
        """Task entry point: build the tool, run it, then poll the Nirvana workflow.

        :raises SandboxTaskFailureError: on invalid parameters, vault errors, tool
            failure, unparsable tool output, or a failed Nirvana graph.
        """
        if not self.Parameters.config_path and not self.Parameters.config_data:
            raise SandboxTaskFailureError('One of config_path and config_data should be defined')

        # Validate before spending time on the build; a plain dict lookup later on
        # would surface as a bare KeyError.
        lingware_type = self.Parameters.lingware_type
        if lingware_type not in self._RESOURCE_MAP:
            raise SandboxTaskFailureError(
                'Unknown lingware type {!r}, expected one of: {}'.format(
                    lingware_type, ', '.join(sorted(self._RESOURCE_MAP))))
        lingware_resource_type = self._RESOURCE_MAP[lingware_type]

        arcadia_url = self.Parameters.checkout_arcadia_from_url
        arcadia_patch = self.Parameters.arcadia_patch

        try:
            nirvana_token = sdk2.Vault.data(self.Parameters.nirvana_token_vault)
        except Exception as exc:
            eh.log_exception('Failed to get nirvana token from vault', exc)
            raise SandboxTaskFailureError('Fail on get token from vault storage: ' + str(exc))

        with self.memoize_stage.build_stage:
            subtasks = []

            ya_make_task_class = sdk2.Task["KOSHER_YA_MAKE"]
            build_run_metrics_bin_sub_task = ya_make_task_class(
                self,
                checkout_arcadia_from_url=arcadia_url,
                arcadia_patch=arcadia_patch,
                description='Build run/vh bin',
                result_rt=ARCADIA_PROJECT.name,
                targets='voicetech/asr/tools/regular_metrics_run/run/vh',
                arts='voicetech/asr/tools/regular_metrics_run/run/vh/run_metrics',
                build_type='release',
                result_single_file=True,
                checkout=True
            )
            # Remembered in the context so the binary can be located by the later stage.
            self.Context.build_run_metrics_bin_sub_task_id = build_run_metrics_bin_sub_task.id
            build_run_metrics_bin_sub_task.enqueue()
            subtasks.append(build_run_metrics_bin_sub_task)

            raise sdk2.WaitTask(subtasks, asr_utils.DEFAULT_SUBTASK_WAIT_STATUS, wait_all=True)

        with self.memoize_stage.run_metrics:
            stderr_text = self._run_metrics_tool(
                nirvana_token, arcadia_url, arcadia_patch, lingware_resource_type)

            workflow_instance_info = asr_utils.parse_workflow_instance_info(stderr_text)

            if workflow_instance_info is None:
                # Keep the raw output in the log: it is the only clue why parsing failed.
                logging.error(stderr_text)
                raise SandboxTaskFailureError("got internal error on regular metrics run")
            logging.info(workflow_instance_info)
            # Stored as a plain dict and rebuilt below on every execution.
            self.Context.workflow_instance_info = workflow_instance_info._asdict()

        workflow_instance_info = asr_utils.WorkflowInstanceInfo(**self.Context.workflow_instance_info)
        logging.info(workflow_instance_info)
        workflow = NirvanaWorkflow(nirvana_token, workflow_instance_info.workflow_id)
        limit_runs = self.Parameters.checks_limit
        wait_time = self.Parameters.checks_period

        def _success_callback():
            # Save the output of the 'final_metrics' block as the task's result resource.
            results = workflow.get_block_results(workflow_instance_info.instance_id, 'final_metrics', 'output')
            logging.info(results)
            result_resource = OTHER_RESOURCE(
                self,
                "Run metrics result",
                "metrics.json",
                ttl="inf",
            )
            with open(str(result_resource.path), "w") as fp:
                fp.write(results)

        def _fail_callback():
            logging.info("Graph failed: {}".format(workflow.gui_url(workflow_instance_info.instance_id)))
            raise SandboxTaskFailureError("got error on running regular metrics run")

        def _wait_callback(current_status):
            # Mirror the current graph status into both the log and the task info.
            logging.info(current_status)
            self.set_info(current_status)

        asr_utils.wait_for_nirvana_results(
            self, workflow, workflow_instance_info.instance_id,
            limit_runs, wait_time,
            success_cb=_success_callback,
            fail_cb=_fail_callback,
            wait_cb=_wait_callback
        )
