# -*- coding: utf-8 -*-

import os
import os.path
import logging

import sandbox.common.types.client as ctc
from sandbox.common.errors import SandboxException, TaskFailure
from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers
import sandbox.projects.common.build.parameters as build_params
import sandbox.projects.common.constants as consts
from sandbox.projects.common.utils import check_if_tasks_are_ok
from sandbox.projects.common import error_handlers as eh
from sandbox import sandboxsdk
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError, SandboxSubprocessError
from sandbox.sandboxsdk.parameters import SandboxArcadiaUrlParameter, SandboxBoolParameter, SandboxIntegerParameter, \
    LastReleasedResource, SandboxStringParameter, SandboxSelectParameter
from sandbox.sandboxsdk.paths import make_folder
from sandbox.sandboxsdk.process import run_process, check_process_return_code
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk.task import SandboxTask


# Parameter group names. GROUP_BASE and GROUP_MR are the default `group_name`
# of the get_base_params / get_mr_params factories, which assign it to the
# `group` attribute of every parameter class they build.
GROUP_MR = 'Map-reduce parameters'
GROUP_OUT = 'Output data parameters'
GROUP_IN = 'Source data parameters'
GROUP_BASE = 'Basic YANE parameters'


def run_once(f):
    """Decorate a task step so that it is executed at most once per task context.

    Completion is recorded in ``task.ctx`` under the key ``'step_' + f.__name__``.
    The flag is written only after ``f`` returns, so a step that raises is
    not marked as done and will run again on the next call.

    :param f: callable taking the task as its first positional argument
    :return: wrapper that returns the result of ``f`` on the first successful
        call and ``None`` on every later call (the step is skipped)
    """
    import functools  # local import keeps this block self-contained

    step_name = 'step_' + f.__name__

    # Preserve the name and docstring of the wrapped step on the wrapper.
    @functools.wraps(f)
    def g(task, *args, **kwargs):
        if task.ctx.get(step_name, False):
            logging.debug("Skipping already executed step: %s", step_name)
            return None

        with task.current_action("Running step: %s" % f.__name__):
            result = f(task, *args, **kwargs)
        task.ctx[step_name] = True
        return result

    return g


def check_param(task, ctx_name, condition):
    """Return the context value ``ctx_name`` after checking it with ``condition``.

    :param task: object exposing a ``ctx`` mapping
    :param ctx_name: key to look up in ``task.ctx``
    :param condition: callable applied to the value; its result is passed to
        ``eh.ensure`` together with an error message
    :return: the stored value, or ``None`` when the key is absent (in that
        case ``condition`` is never called)
    """
    if ctx_name in task.ctx:
        param_value = task.ctx[ctx_name]
        is_valid = condition(param_value)
        eh.ensure(is_valid, 'param %s is not valid: %s' % (ctx_name, param_value))
        return param_value
    return None


def validate_resource(res_id, descr, required, res_type, arch='any'):
    """Check that a resource id points to a usable resource.

    :param res_id: resource id; any falsy value means "not specified"
    :param descr: human-readable name of the resource used in error messages
    :param required: when true, a missing ``res_id`` is an error
    :param res_type: expected resource type, compared with ``res.type``
    :param arch: architecture the resource must be compatible with
    :return: ``None``
    :raises TaskFailure: if a required resource is not specified, cannot be
        fetched, does not exist, has an incompatible architecture or has a
        different type
    """
    if not res_id:
        if required:
            raise TaskFailure("Resource for %s must be specified" % descr)
        return

    try:
        res = channel.sandbox.get_resource(res_id)
    except Exception:
        # Deliberately not a bare ``except``: SystemExit / KeyboardInterrupt
        # must propagate instead of being reported as an invalid id.
        raise TaskFailure("Resource for %s has invalid id=%s" % (descr, res_id))
    if res is None:
        raise TaskFailure("Resource for %s with id=%s doesn't exist" % (descr, res_id))
    if not sandboxsdk.util.is_arch_compatible(res.arch, arch):
        raise TaskFailure("Incompatible %s architecture" % descr)
    if res.type != res_type:
        raise TaskFailure("Invalid resource type %s of %s" % (res.type, descr))


def get_base_params(group_name=GROUP_BASE, is_required=True):
    """Build a container class holding the basic YANE input parameters.

    :param group_name: value assigned to ``group`` of every parameter class
    :param is_required: value assigned to ``required`` of the svn url parameter
    :return: class exposing the parameter classes as attributes and listing
        them in its ``params`` attribute
    """
    class YaneBaseParams:
        # Arcadia svn url, stored in the context as 'svn_url'.
        class SvnPath(SandboxArcadiaUrlParameter):
            group = group_name
            name = 'svn_url'
            description = 'Svn url:'
            required = is_required
            default_value = 'arcadia:/arc/trunk'

        # Checkbox; the 'true' sub-field is the arcadia patch parameter.
        class DoApplyArcPatch(SandboxBoolParameter):
            group = group_name
            name = 'do_apply_patch'
            description = 'Apply patch'
            sub_fields = {'true': [consts.ARCADIA_PATCH_KEY]}
            default_value = False

        # Optional resource with prebuilt Yane tools.
        class YaneTools(LastReleasedResource):
            group = group_name
            name = 'tools'
            description = 'Yane tools (leave blank to build)'
            required = False
            resource_type = resource_types.YANE_TOOLS
            do_not_copy = True

        # Arcadia patch parameter placed into the same group.
        class YaneArcadiaPatch(build_params.ArcadiaPatch):
            group = group_name
            name = 'arcadia_patch'
            do_not_copy = True

        params = [SvnPath, DoApplyArcPatch, YaneArcadiaPatch, YaneTools]

    return YaneBaseParams


def get_text_params(group_name=None, is_required=True):
    """Build a container class holding the text-pool related input parameters.

    :param group_name: value assigned to ``group`` of every parameter class
    :param is_required: value assigned to ``required`` of every parameter class
    :return: class exposing the parameter classes as attributes and listing
        them in its ``params`` attribute
    """
    class YaneTextParams:
        # Text pool resource, stored in the context as 'texts'.
        class TextsPool(LastReleasedResource):
            group = group_name
            name = 'texts'
            description = 'Text pool'
            required = is_required
            resource_type = resource_types.YANE_TSV

        # Ideal markup resource, stored in the context as 'markup'.
        class IdealMarkup(LastReleasedResource):
            group = group_name
            name = 'markup'
            description = 'Ideal markup'
            required = is_required
            resource_type = resource_types.YANE_TSV

        # External ids trie resource, stored in the context as 'ids_trie'.
        class IdsTrie(LastReleasedResource):
            group = group_name
            name = 'ids_trie'
            description = 'External Ids Trie'
            required = is_required
            resource_type = resource_types.OTHER_RESOURCE

        params = [TextsPool, IdealMarkup, IdsTrie]

    return YaneTextParams


def get_mr_params(group_name=GROUP_MR):
    """Build a container class holding the map-reduce input parameters.

    :param group_name: value assigned to ``group`` of every parameter class
    :return: class exposing the parameter classes as attributes and listing
        them in its ``params`` attribute
    """
    class YaneMRParams:
        # Server name, stored in the context as 'mr_server'.
        class Server(SandboxStringParameter):
            group = group_name
            name = 'mr_server'
            description = 'Server'
            required = True
            default_value = 'hahn.yt.yandex.net'

        # Runtime selector ('YT' or 'MR'), stored in the context as 'mr_runtime'.
        class Runtime(SandboxSelectParameter):
            group = group_name
            name = 'mr_runtime'
            description = 'Runtime'
            required = True
            choices = [('YT', 'YT'), ('MR', 'MR')]
            default_value = 'YT'

        # User name, stored in the context as 'mr_user'.
        class User(SandboxStringParameter):
            group = group_name
            name = 'mr_user'
            description = 'User'
            required = True
            default_value = 'dict'

        params = [Server, Runtime, User]

    return YaneMRParams


class YaneTaskBase(SandboxTask):
    """
        Base Yane task
    """

    class Config(dict):
        """Dictionary filled from a file that contains a Python dict expression."""

        def __init__(self, src):
            """Read ``src``, evaluate its content and copy the result into ``self``.

            :param src: path to the config file
            :raises TaskFailure: if the content cannot be evaluated or the
                result cannot be used to update a dict
            """
            with open(src) as f:
                try:
                    # SECURITY NOTE(review): eval() executes arbitrary code from
                    # the file. Kept as is because configs may rely on full
                    # expression syntax; consider ast.literal_eval if they are
                    # plain literals.
                    cfg = eval(f.read())
                    dict.update(self, cfg)
                except Exception as e:
                    # Not BaseException: SystemExit / KeyboardInterrupt must not
                    # be reported as a malformed config.
                    raise TaskFailure('Bad format of YANE_CONFIG resource: {}'.format(e))

    def get_config(self, ctx_name):
        """Sync the resource stored in ``self.ctx[ctx_name]`` and load it as a Config.

        :param ctx_name: context key holding the config resource id
        :return: ``YaneTaskBase.Config`` built from the synced resource path
        """
        config_path = self.sync_resource(self.ctx[ctx_name])
        loaded = YaneTaskBase.Config(config_path)
        logging.debug("config: \n%s", loaded)
        return loaded

    def check_tasks(self, ctx_name, task_descr):
        """Verify the subtask(s) stored in the context with ``check_if_tasks_are_ok``.

        :param ctx_name: context key holding a task id or a list/tuple of ids
        :param task_descr: description used in the error message
        :return: the context value as stored (single id or sequence)
        :raises SandboxTaskFailureError: if ``ctx_name`` is not in the context
        """
        if ctx_name not in self.ctx:
            raise SandboxTaskFailureError('Cannot find subtask for %s' % task_descr)

        stored = self.ctx[ctx_name]
        # A single id is wrapped into a list; sequences are passed through as is.
        to_check = stored if isinstance(stored, (list, tuple)) else [stored]
        check_if_tasks_are_ok(to_check)
        return stored

    def get_svn_path(self, subpath):
        """Return the task's svn url with ``subpath`` appended to its path.

        The revision parsed from ``self.ctx['svn_url']`` is passed on to the
        resulting url.

        :param subpath: path relative to the path component of the svn url
        :return: url produced by ``Arcadia.replace``
        """
        base_url = self.ctx['svn_url']
        parsed = Arcadia.parse_url(base_url)
        logging.debug("svn_url parts: %s", parsed)

        full_path = os.path.join(parsed.path, subpath)
        result_url = Arcadia.replace(base_url, path=full_path, revision=parsed.revision)
        logging.debug("target_url: %s", result_url)
        return result_url

    def is_resource_selected(self, ctx_name):
        """Tell whether the context holds a truthy value under ``ctx_name``."""
        selected = self.ctx.get(ctx_name)
        return True if selected else False

    def get_tool(self, executable, tools_param='tools'):
        """Return the path to ``executable`` inside the synced tools resource.

        :param executable: file name looked up in the resource directory
        :param tools_param: context key holding the tools resource id
        :return: path to the executable
        :raises SandboxTaskFailureError: if the resource is not selected or
            does not contain ``executable``
        """
        if not self.is_resource_selected(tools_param):
            raise SandboxTaskFailureError("Resource '%s' for Yane tools is not specified" % tools_param)

        resource_dir = self.sync_resource(self.ctx[tools_param])
        candidate = os.path.join(resource_dir, executable)
        if os.path.exists(candidate):
            return candidate
        raise SandboxTaskFailureError("Resource #%s does not have '%s'" % (self.ctx[tools_param], executable))

    def run_tool(self, executable, args, env=None, wait=True, tools_param='tools'):
        """Run a tool from the tools resource through ``run_process``.

        :param executable: tool file name, also used as the log prefix
        :param args: list of command line arguments
        :param env: environment passed to ``run_process``
        :param wait: passed to ``run_process`` as ``wait``
        :param tools_param: context key holding the tools resource id
        :return: whatever ``run_process`` returns
        :raises SandboxException: if ``run_process`` raises SandboxSubprocessError
        """
        tool_path = self.get_tool(executable, tools_param)
        command = [tool_path] + args
        logging.debug("Command to run: %s", command)
        try:
            process = run_process(command, log_prefix=executable, wait=wait, environment=env)
        except SandboxSubprocessError as error:
            self.set_info(error.get_task_info(), do_escape=False)
            # Raise exception to allow task restart
            raise SandboxException(error.message)
        return process

    def build_matrixnet(self, param_name='matrixnet'):
        """Make sure ``self.ctx[param_name]`` refers to a matrixnet resource.

        If no resource is selected, a BUILD_MATRIXNET subtask is created (once,
        its id is kept in ``self.ctx['matrixnet_task_id']``) and the first
        MATRIXNET_EXECUTABLE resource of that subtask is stored in
        ``self.ctx[param_name]``. If a resource is selected, its architecture
        is checked against the architecture of the current client.

        :param param_name: context key holding the matrixnet resource id
        :raises SandboxException: if the selected resource was built for a
            different architecture than ``self.client_info['arch']``
        """
        task_arch = self.client_info['arch']

        if not self.is_resource_selected(param_name):
            if 'matrixnet_task_id' not in self.ctx:
                subtask = self.create_subtask('BUILD_MATRIXNET', 'build_matrixnet', arch=task_arch)
                self.ctx['matrixnet_task_id'] = subtask.id
                self.wait_task_completed(subtask)
            resources = apihelpers.list_task_resources(self.ctx['matrixnet_task_id'],
                                                       resource_types.MATRIXNET_EXECUTABLE)
            self.ctx[param_name] = resources[0].id
            return

        # Look the resource up by ``param_name``; a hard-coded 'matrixnet' key
        # would check the wrong resource when a custom key is passed.
        matrixnet_arch = channel.sandbox.get_resource(self.ctx[param_name]).arch
        if matrixnet_arch != task_arch:
            raise SandboxException(
                'matrixnet platform %s differs from task arch %s' % (matrixnet_arch, task_arch))

    def wait_processes(self, processes):
        """Wait for every process in ``processes`` and check its return code.

        :param processes: iterable of process objects providing ``poll``,
            ``wait`` and ``saved_cmd``
        :raises SandboxException: if ``check_process_return_code`` raises
            SandboxSubprocessError for one of the processes
        """
        for proc in processes:
            still_running = proc.poll() is None
            if still_running:
                with self.current_action('Waiting process "{}"'.format(proc.saved_cmd)):
                    proc.wait()
            try:
                check_process_return_code(proc)
            except SandboxSubprocessError as error:
                self.set_info(error.get_task_info(), do_escape=False)
                # Raise exception to allow task restart
                raise SandboxException(error.message)

    def get_mr_env(self):
        """Build the environment for map-reduce tools.

        Starts from a copy of ``os.environ`` and adds MR_RUNTIME, MR_USER (if
        'mr_user' is in the context) and, for the 'YT' runtime, YT_TOKEN from
        the vault plus YT_PROXY (if 'mr_server' is in the context).

        :return: dict with the environment variables
        :raises SandboxException: if reading the vault raises TaskFailure
        """
        env = os.environ.copy()
        runtime = self.ctx['mr_runtime']
        env['MR_RUNTIME'] = runtime
        if 'mr_user' in self.ctx:
            env['MR_USER'] = self.ctx['mr_user']

        if runtime != 'YT':
            return env

        try:
            env['YT_TOKEN'] = self.get_vault_data('YANE', 'robot_yane_yt_token')
            if 'mr_server' in self.ctx:
                env['YT_PROXY'] = self.ctx['mr_server']
        except TaskFailure as error:
            # Raise exception to allow task restart
            raise SandboxException(str(error))
        return env

    def compile_gzt(self, gzt_header, src_table, gzt_name, local_additional_gzt=None):
        """Dump a table with ``mr_dump`` and compile it with ``gztcompiler``.

        Exports the gzt header into the task directory and the nerlib proto
        directories into ``proto``, then produces ``<gzt_name>.gzt.gz`` and
        ``<gzt_name>.bin``.

        :param gzt_header: svn sub-path of the gzt header file
        :param src_table: table passed to ``mr_dump`` with ``-t``
        :param gzt_name: base name of the output files
        :param local_additional_gzt: optional value passed to ``mr_dump`` with ``-a``
        """
        Arcadia.export(self.get_svn_path(gzt_header), path=self.abs_path())
        for proto_subpath in ('arcadia/dict/nerlib/proto', 'arcadia/dict/nerlib/gztproto'):
            Arcadia.export(self.get_svn_path(proto_subpath), path='proto')

        archive_name = gzt_name + '.gzt.gz'
        dump_args = [
            '-s', self.ctx['mr_server'],
            '-t', src_table,
            '-f', '%v\\n',
            '-h', os.path.basename(gzt_header),
            '-o', archive_name,
            '-c',
        ]
        if local_additional_gzt:
            dump_args += ['-a', local_additional_gzt]
        self.run_tool('mr_dump', dump_args, self.get_mr_env())

        compiler_args = ['-I', 'proto', archive_name, gzt_name + '.bin']
        self.run_tool('gztcompiler', compiler_args)

    def compile_trie(self, src_table, trie_name):
        """Run ``idstriecreator`` to build ``trie_name`` from ``src_table``.

        :param src_table: value passed with ``-i``
        :param trie_name: value passed with ``-o``
        """
        creator_args = [
            '-s', self.ctx['mr_server'],
            '-i', src_table,
            '-o', trie_name,
        ]
        self.run_tool('idstriecreator', creator_args, self.get_mr_env())

    def get_ontodb_version(self):
        """Run the ``ontodb_ver`` tool and return the first line of its output.

        The tool's stdout is written to ``version_dir/ontobd_ver.txt``.

        :return: first output line with surrounding whitespace stripped, or
            ``None`` if the tool printed nothing
        """
        ver_dir = make_folder('version_dir')
        ontobd_ver_out = os.path.join(ver_dir, 'ontobd_ver.txt')
        ontobd_ver_tool = self.get_tool('ontodb_ver')
        env = self.get_mr_env()

        # Close the output file deterministically instead of leaking the handle.
        with open(ontobd_ver_out, 'w') as stdout_file:
            run_process(ontobd_ver_tool,
                        log_prefix='ontobd_ver',
                        environment=env,
                        wait=True,
                        stdout=stdout_file,
                        outputs_to_one_file=False)

        with open(ontobd_ver_out) as ontobd_ver_out_file:
            for line in ontobd_ver_out_file:
                return line.strip()
        return None

    def download_proto_data(self, target, proto_type, sources, useQueryFreq=True, cache=None,
                            max_related_obj_count=None,
                            prefix='', human_readable_input=False, wait=True):
        """Write the table list for ``proto_type`` and run ``protostoragecreator``.

        The list is written to ``tmp/<proto_type>.tables``, one line per source:
        ``<src[0]>/<prefix><proto_type>.data<TAB><src[1]>``.

        :param target: value passed with ``-o``
        :param proto_type: value passed with ``-t``; also used in file names
        :param sources: iterable of indexable items (directory, second column)
        :param useQueryFreq: when false, ``-D`` is added
        :param cache: passed as ``-C repr(cache)`` unless ``proto_type`` is 'metaobject'
        :param max_related_obj_count: passed as ``-R repr(...)`` for ``proto_type`` 'object'
        :param prefix: prefix of the ``.data`` file name
        :param human_readable_input: when true, ``-H`` is added
        :param wait: passed on to ``run_tool``
        :return: result of ``run_tool``
        """
        table_list = os.path.join('tmp', '{}.tables'.format(proto_type))
        data_file_name = '{}{}.data'.format(prefix, proto_type)
        with open(table_list, 'w') as tables_file:
            for source in sources:
                entry = os.path.join(source[0], '{}\t{}'.format(data_file_name, source[1]))
                tables_file.write(entry)
                tables_file.write('\n')

        creator_args = [
            '-s', self.ctx['mr_server'],
            '-i', table_list,
            '-o', target,
            '-t', proto_type,
        ]
        if cache and proto_type != 'metaobject':
            creator_args += ['-C', repr(cache)]
        if not useQueryFreq:
            creator_args.append('-D')
        if max_related_obj_count is not None and proto_type == 'object':
            creator_args += ['-R', repr(max_related_obj_count)]
        if human_readable_input:
            creator_args.append('-H')

        return self.run_tool('protostoragecreator', creator_args, self.get_mr_env(), wait)

    # Client tags of the task: the LINUX tag group
    # (presumably restricts execution to Linux hosts -- confirm against Sandbox docs).
    client_tags = ctc.Tag.Group.LINUX

    def do_execute(self):
        """Task-specific work invoked at the end of ``on_execute``.

        :raises NotImplementedError: always, in this base implementation
        """
        message = 'Method do_execute is not implemented in Sandbox task %s' % self.type
        raise NotImplementedError(message)

    def on_execute(self):
        """Prepare the Yane tools resource and then call ``do_execute``.

        When no 'tools' resource is selected:
        * if no build subtask was created yet, a YANE_BUILD_TOOLS subtask is
          created, its id is stored in ``self.ctx['tools_task_id']`` and the
          task waits for it via ``wait_all_tasks_stop_executing``;
        * otherwise the subtask is checked and its YANE_TOOLS resource id is
          stored in ``self.ctx['tools']``.
        """
        with self.memoize_stage.first_run:
            self._on_first_run()

        if not self.is_resource_selected('tools'):
            if 'tools_task_id' not in self.ctx:
                revision = Arcadia.get_revision(self.ctx['svn_url'])
                build_descr = 'Yane tools'
                if revision:
                    build_descr = build_descr + "@{}".format(revision)

                build_ctx = {
                    'kill_timeout': 2 * 60 * 60,
                    consts.ARCADIA_URL_KEY: self.get_svn_path('arcadia'),
                    consts.ARCADIA_PATCH_KEY: self.ctx.get(consts.ARCADIA_PATCH_KEY),
                    consts.BUILD_SYSTEM_KEY: consts.YMAKE_BUILD_SYSTEM,
                    'notify_via': '',
                    'notify_if_finished': '',
                    'notify_if_failed': self.author
                }

                build_task = self.create_subtask(
                    'YANE_BUILD_TOOLS',
                    build_descr,
                    input_parameters=build_ctx,
                    arch=self.client_info['arch'],
                )
                self.ctx['tools_task_id'] = build_task.id
                logging.debug("Tools are not specified. Creating %s subtask to build them", build_task.id)
                self.wait_all_tasks_stop_executing([build_task.id])
            else:
                built_by = self.check_tasks('tools_task_id', 'Build Yane tools')
                self.ctx['tools'] = apihelpers.get_task_resource_id(
                    built_by, resource_types.YANE_TOOLS, arch=self.client_info['arch'])
                logging.debug("Using %s tools resource from %s subtask", self.ctx['tools'], built_by)

        self.do_execute()

    def create_subtask(
        self,
        task_type,
        description,
        input_parameters=None,
        host=None,
        model=None,
        arch=None,
        priority=None,
        important=False,
        execution_space=None,
        inherit_notifications=False,
        tags=None,
        se_tag=None,
        enqueue=True,
        ram=None
    ):
        """
        Mimics trunk/arcadia/sandbox/sandboxsdk/task.py create_subtask logic,
        except ignoring inherit_notifications and instead removing SUCCESS notifications
        from subtasks.

        :param task_type: type of the subtask to create
        :param description: subtask description
        :param input_parameters: subtask context; ``self.ctx`` is used when None.
            Note that 'GSID' and 'tasks_archive_resource' defaults are written
            into this dict (or into ``self.ctx``) in place.
        :param host: required host, defaults to ``self.required_host``
        :param model: client model, defaults to ``self.model``
        :param arch: architecture, defaults to ``self.arch``
        :param priority: ``self.Priority`` instance or its state as tuple/list;
            anything else, or a value above ``self.priority``, falls back to
            ``self.priority``
        :param important: passed to ``create_task``
        :param execution_space: passed to ``create_task``
        :param inherit_notifications: ignored
        :param tags: optional tags put into the task parameters
        :param se_tag: optional se_tag put into the task parameters
        :param enqueue: passed to ``create_task``
        :param ram: optional ram put into the task parameters
        :return: task object returned by ``channel.sandbox.create_task``
        """
        import copy  # local import keeps this block self-contained

        context = self.ctx if input_parameters is None else input_parameters
        context.setdefault('GSID', self.ctx.get('__GSID', ''))
        context.setdefault('tasks_archive_resource', self.tasks_archive_resource)
        if arch is None:
            arch = self.arch
        if model is None:
            model = self.model
        if host is None:
            host = self.required_host
        if isinstance(priority, (tuple, list)):
            # NOTE(review): relies on Priority.__setstate__ returning the
            # instance; if it returns None the check below falls back to
            # self.priority.
            priority = self.Priority().__setstate__(priority)
        if not isinstance(priority, self.Priority) or priority > self.priority:
            priority = self.priority

        params = {}
        if hasattr(self, 'notifications'):
            # Strip SUCCESS from a deep copy: editing the rules in place could
            # also remove SUCCESS notifications from the parent task itself.
            notifications = copy.deepcopy(self.notifications)
            for notif_rule in notifications:
                statuses = notif_rule['statuses']
                while u'SUCCESS' in statuses:
                    statuses.remove(u'SUCCESS')
            params['notifications'] = notifications
        if tags is not None:
            params['tags'] = tags
        if se_tag is not None:
            params['se_tag'] = se_tag
        if ram is not None:
            params['ram'] = ram
        params = params or None
        task = channel.sandbox.create_task(
            task_type=task_type, description=description, owner=self.owner,
            context=context, parent_task_id=self.id, model=model,
            host=host, arch=arch, priority=priority.__getstate__(), important=important,
            execution_space=execution_space, parameters=params, enqueue=enqueue
        )
        logging.info(
            "Sub-task #%d (type: '%s', host: '%s', model: '%s', arch: '%s', priority: %s) created. Success notifications are turned off because of YaneTaskBase code.",
            task.id, task_type, host, model, arch, priority
        )
        return task

    def on_enqueue(self):
        """Run the base ``on_enqueue`` and validate the optional 'tools' resource."""
        SandboxTask.on_enqueue(self)
        tools_id = self.ctx.get('tools')
        validate_resource(
            tools_id,
            'Yane tools',
            False,
            resource_types.YANE_TOOLS,
            self.arch,
        )

    def _on_first_run(self):
        """Store the frozen svn url and the client architecture in the context."""
        # Freeze revision
        frozen_url = Arcadia.freeze_url_revision(self.ctx['svn_url'])
        self.ctx['svn_url'] = frozen_url
        logging.debug("svn_url after freeze: %s", self.ctx['svn_url'])

        # Remember assigned arch
        client_arch = self.client_info['arch']
        self.ctx['assigned_arch'] = client_arch
        logging.debug("assigned_arch: %s", self.ctx['assigned_arch'])


class YaneLearnTaskBase(YaneTaskBase):
    """
        Base Yane learn task
    """

    class NumberOfIterations(SandboxIntegerParameter):
        # Required integer parameter stored in the context as 'iters'.
        name = 'iters'
        description = 'Number of iterations'
        required = True
        default_value = 8000

        @classmethod
        def cast(cls, value):
            """Cast with the parent implementation and reject negative results.

            :raises ValueError: if the cast value is negative
            """
            casted = super(YaneLearnTaskBase.NumberOfIterations, cls).cast(value)
            if casted is None or not casted < 0:
                return casted
            raise ValueError("Negative value {!r}".format(casted))

    class NumberOfCPU(SandboxIntegerParameter):
        # Required integer parameter stored in the context as 'min_ncpu'.
        name = 'min_ncpu'
        description = 'Minimum number of CPUs'
        required = True
        default_value = 0

        @classmethod
        def cast(cls, value):
            """Cast with the parent implementation and reject negative results.

            :raises ValueError: if the cast value is negative
            """
            casted = super(YaneLearnTaskBase.NumberOfCPU, cls).cast(value)
            if casted is None or not casted < 0:
                return casted
            raise ValueError("Negative value {!r}".format(casted))

    class AdditionalMXOptions(SandboxStringParameter):
        # Optional string parameter stored in the context as 'matrixnet_options'.
        required = False
        name = 'matrixnet_options'
        description = 'Additional matrixnet options'

    def __checkout_or_sync_data(self, data_name, svn_path, _depth='files', force_trunk=False):
        """Return a local path to data taken from a resource or from svn.

        :param data_name: context key of the resource; also the name of the
            checkout directory inside the task directory
        :param svn_path: svn sub-path checked out when no resource is used
        :param _depth: ``depth`` passed to ``Arcadia.checkout``
        :param force_trunk: when true, the resource is ignored
        :return: synced resource path or the checkout directory
        """
        use_resource = self.is_resource_selected(data_name) and not force_trunk
        if use_resource:
            return self.sync_resource(self.ctx[data_name])

        checkout_dir = self.abs_path(data_name)
        if os.path.exists(checkout_dir):
            # Repeated call: the directory already exists, skip the checkout.
            return checkout_dir

        checkout_dir = make_folder(checkout_dir)
        Arcadia.checkout(self.get_svn_path(svn_path), path=checkout_dir, depth=_depth)
        return checkout_dir

    def get_data_path(self, dir_suffix, force_trunk=False):
        """Return the local directory with ner data, fetching it if needed.

        :param dir_suffix: sub-directory of ``data/dict/ner`` in svn
        :param force_trunk: when true, data is taken from svn even if the
            'data' resource is selected
        :return: path to the data directory
        """
        data_name = 'data'
        svn_path = os.path.join('data/dict/ner', dir_suffix)
        data_dir = self.__checkout_or_sync_data(data_name, svn_path, 'immediates', force_trunk)

        # If the data came from svn, wizdata has to be fetched from there as well.
        if force_trunk or not self.is_resource_selected(data_name):
            rev = Arcadia.get_revision(self.get_svn_path(svn_path))
            Arcadia.update(os.path.join(data_dir, 'wizdata'), set_depth='infinity', revision=rev)

        # for testing lda
        self.__checkout_or_sync_data("lda", "arcadia_tests_data/wizard/reqtopicsclassifier", force_trunk=force_trunk)

        return data_dir

    def get_text_pool(self, dir_suffix):
        """Return the path to the text pool file for the current formula.

        :param dir_suffix: sub-directory of ``data/dict/ner/texts`` in svn
        :return: the synced resource path when it is a file, otherwise the
            formula-specific file inside the checkout directory
        """
        texts_svn_path = os.path.join('data/dict/ner/texts', dir_suffix)
        texts_root = self.__checkout_or_sync_data('texts', texts_svn_path, 'infinity')
        # Data from svn arrives as a directory, data from a resource is a file.
        if os.path.isfile(texts_root):
            return texts_root

        formula = self.ctx.get('formula')
        if formula == 'query':
            relative_parts = ('query.texts.tsv',)
        elif formula == 'video':
            relative_parts = ('video', self.ctx['extraction_languages'], 'texts.tsv')
        elif formula == 'music' or formula == 'yobject':
            relative_parts = ('texts.tsv',)
        else:
            relative_parts = ('doc.texts.tsv',)
        return os.path.join(texts_root, *relative_parts)

    def get_markup(self, dir_suffix):
        """Return the path to the ideal markup file for the current formula.

        :param dir_suffix: sub-directory of ``data/dict/ner/pools`` in svn
        :return: the synced resource path when it is a file, otherwise
            ``idealmarkup.tsv`` inside the checkout directory
        """
        formula = self.ctx.get('formula')
        pools_path = os.path.join('data/dict/ner/pools', dir_suffix)
        if formula == 'query':
            pools_path = os.path.join(pools_path, 'query')
        elif formula == 'video':
            pools_path = os.path.join(pools_path, 'video', self.ctx['extraction_languages'])
        elif formula != 'music' and formula != 'yobject':
            pools_path = os.path.join(pools_path, 'texts')

        markup_root = self.__checkout_or_sync_data('markup', pools_path)
        # Data from svn arrives as a directory, data from a resource is a file.
        if os.path.isfile(markup_root):
            return markup_root
        return os.path.join(markup_root, 'idealmarkup.tsv')

    def get_ids_trie(self, dir_suffix):
        """Return the path to ``object.ids.trie`` inside the data directory."""
        return os.path.join(self.get_data_path(dir_suffix), 'object.ids.trie')

    def get_extractor_cmd(self, force_trunk=False):
        """Build the command line for the ``objectsextractor`` tool.

        Also stores the chosen data sub-directory in ``self.ctx['dir_suffix']``.

        :param force_trunk: passed on to ``get_data_path``
        :return: list with the tool path followed by its arguments
        """
        formula = self.ctx.get('formula')
        extraction_languages = self.ctx.get('extraction_languages')

        self.ctx['dir_suffix'] = ''
        if formula == 'music':
            self.ctx['dir_suffix'] = 'music'  # the language choice is ignored for the music formula
        elif formula == 'yobject':
            self.ctx['dir_suffix'] = 'yobject'
        elif extraction_languages == "tr":
            self.ctx['dir_suffix'] = 'tr'

        extractor_cmd = [self.get_tool('objectsextractor'), '-d', self.get_data_path(self.ctx['dir_suffix'], force_trunk)]
        if formula == 'query':
            extractor_cmd.append('-Q')
        elif formula == 'video':
            extractor_cmd.append('--video-mode')
        elif formula == 'music':
            extractor_cmd.append('--music-mode')

        disabled_features = self.ctx.get('disable_features', '')
        if disabled_features:
            extractor_cmd.extend(['-D', disabled_features])

        # Uses the same ``.get`` lookup as the checks above, so a missing
        # 'formula' key takes the default branch instead of raising KeyError.
        if formula != 'music':
            extractor_cmd.extend(['-l', ",".join([extraction_languages, "en"])])

        if self.ctx.get('with_borsches'):
            extractor_cmd.append('--borsch')
        return extractor_cmd
