# -*- coding: utf-8 -*-

import logging
import os
import os.path

from sandbox.projects import resource_types
import sandbox.projects.yane.common as yane
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.parameters import LastReleasedResource, SandboxStringParameter
from sandbox.sandboxsdk.paths import make_folder, copy_path
from sandbox.sandboxsdk.process import run_process
from sandbox.projects.yane.common import check_param
from sandbox.sandboxsdk.svn import Arcadia


# Repository-relative paths handed to get_svn_path() and checked out with
# Arcadia when significance models are learned. Each checkout is expected to
# provide 'texts.tsv' and 'markup.tsv' (those are the files read from it).
SIGNIFICANCE_TEXTS_AND_MARKUP_PATH = 'data/dict/ner/pools/significance/news_markup'
# Same layout for the '_ukr' variant (presumably Ukrainian news -- confirm).
SIGNIFICANCE_TEXTS_AND_MARKUP_UKR_PATH = 'data/dict/ner/pools/significance/news_markup_ukr'


class YaneLearnModel(yane.YaneLearnTaskBase):
    """
        Learn matrixnet model of object extractor. Expects input YANE_POOL resource of the following format:

        tar+gzip file/
            <folder with any name>/
                learn.tsv (required)

                test.tsv (optional)

        Outputs YANE_MODEL resource with the following content:

        model.tgz/
            model/
                yane.bin

                yane.fpair

                yane.fstr

                yane.inc

                yane.inc.factors

                yane.info

                learn.tsv.matrixnet

                test.tsv.test.matrixnet (present if source pool has test.tsv)
    """
    type = 'YANE_LEARN_MODEL'

    # NOTE(review): presumably megabytes of disk reserved for the task -- confirm
    # against the Sandbox SDK.
    execution_space = 30 * 1024

    class MatrixnetBinary(LastReleasedResource):
        name = 'matrixnet'
        description = 'Matrixnet binary'
        resource_type = resource_types.MATRIXNET_EXECUTABLE

    class LearnPool(LastReleasedResource):
        name = 'pool'
        description = 'Learn pool'
        resource_type = resource_types.YANE_POOL

    # NOTE(review): this parameter is declared but is not listed in
    # input_parameters below and 'model_type' is never read in this class.
    class ModelType(SandboxStringParameter):
        name = 'model_type'
        description = 'Model type'
        required = True
        default_value = 'classification'
        choices = [
            ('classification', 'classification'),
            ('regression', 'regression')
        ]

    input_parameters = [
        MatrixnetBinary,
        LearnPool,
        yane.YaneLearnTaskBase.NumberOfIterations,
        yane.YaneLearnTaskBase.NumberOfCPU,
        yane.YaneLearnTaskBase.AdditionalMXOptions] + \
        yane.get_text_params('Pools (optional)', False).params

    def __run_matrixnet(self, mn_path, model_dir, pool_dir, prefix='', iter_count=None, classification=True):
        """
            Run the matrixnet binary on '<prefix>learn.tsv' from pool_dir.

            mn_path -- path to the matrixnet executable
            model_dir -- directory where output files named '<prefix>yane*' are written
            pool_dir -- directory containing '<prefix>learn.tsv' and, optionally, '<prefix>test.tsv'
            prefix -- file name prefix selecting the pool / model ('', 'filter_', 'significance_', ...)
            iter_count -- value for the '-i' option; the option is omitted when None
            classification -- pass '-c' to matrixnet when true

            Raises SandboxTaskFailureError if the learn file is missing.
        """
        learn_path = os.path.join(pool_dir, prefix + 'learn.tsv')
        if not os.path.exists(learn_path):
            raise SandboxTaskFailureError(
                'Invalid pool resource. Missing ' + prefix + 'learn.tsv file'
            )

        mn_cmd = [mn_path, '-f', learn_path]
        test_path = os.path.join(pool_dir, prefix + 'test.tsv')
        if os.path.exists(test_path):
            mn_cmd.extend(['-t', test_path])
        mn_cmd.extend(['-o', os.path.join(model_dir, prefix + 'yane')])

        if classification:
            mn_cmd.append('-c')

        if iter_count is not None:
            mn_cmd.extend(['-i', repr(iter_count)])
        threads = int(self.client_info['ncpu'])
        # 8 is default number of threads in matrixnet; on bigger hosts use
        # all CPUs but two.
        if threads > 8:
            mn_cmd.extend(['-T', repr(threads - 2)])
        # Extra user-supplied options are looked up per prefix.
        if (prefix + 'matrixnet_options') in self.ctx:
            mn_cmd.extend(self.ctx[prefix + 'matrixnet_options'].split())
        mn_cmd.append('--minEntropy')

        run_process(mn_cmd, log_prefix='matrixnet')

    def __run_extractor(self, extractor_cmd, texts_path, features_path):
        """
            Run the object extractor, feeding texts_path to stdin and writing stdout to features_path.

            Both files are closed as soon as the process has finished.
        """
        logging.debug("extractor_cmd: %s", extractor_cmd)
        with open(texts_path, 'r') as texts_file, open(features_path, 'w') as features_file:
            run_process(extractor_cmd,
                        log_prefix='objectsextractor',
                        wait=True,
                        stdin=texts_file,
                        stdout=features_file,
                        outputs_to_one_file=False)

    def __run_pooltools(self, markup_path, features_path, output_path, filter_hypos=None, significance=False):
        """
            Run 'pooltools' and store its stdout in output_path.

            markup_path -- value for '-m'
            features_path -- value for '-f'
            output_path -- file that receives the tool's stdout (closed afterwards)
            filter_hypos -- value for '--filter-hypos'; the option is omitted when None
            significance -- pass '--sign' when true
        """
        create_pool_cmd = [self.get_tool('pooltools')]
        if significance:
            create_pool_cmd.append('--sign')
        create_pool_cmd.extend(['-m', markup_path,
                                '-f', features_path,
                                '-t', self.get_ids_trie(self.ctx['dir_suffix'])])
        if filter_hypos is not None:
            create_pool_cmd.extend(['--filter-hypos', filter_hypos])

        with open(output_path, 'w') as output_file:
            run_process(create_pool_cmd,
                        log_prefix='pooltools',
                        wait=True,
                        stdout=output_file,
                        outputs_to_one_file=False)

    def __filter_hypos(self, mn_path, model_dir, pool_dir):
        """
            Learn the 'filter_' model and rebuild learn.tsv / test.tsv from filtered features.

            Steps: run matrixnet on 'filter_learn.tsv', run the extractor with the
            resulting filter model to produce features, then run pooltools to
            write 'learn.tsv' (and 'test.tsv' when 'filter_test.tsv' exists)
            into pool_dir.
        """
        self.__run_matrixnet(mn_path, model_dir, pool_dir, 'filter_')

        extractor_cmd = self.get_extractor_cmd()

        texts_path = self.get_text_pool(self.ctx['dir_suffix'])
        markup_path = self.get_markup(self.ctx['dir_suffix'])
        oe_dir = make_folder('oe')
        features = os.path.join(oe_dir, 'features.txt')
        filter_model_file_name = os.path.join(model_dir, 'filter_yane.info')

        extractor_cmd.append('--filter-features')
        extractor_cmd.extend(['--filter-model', filter_model_file_name])
        # NOTE(review): the threshold passed here is the configured one minus 1.0;
        # the reason is not evident from this file.
        extractor_cmd.extend(['--filter-threshold', repr(float(self.ctx['filter_threshold'])-1.0)])

        self.__run_extractor(extractor_cmd, texts_path, features)

        with self.current_action('Creating pool from features'):
            self.__run_pooltools(markup_path, features,
                                 os.path.join(pool_dir, 'learn.tsv'),
                                 filter_hypos=os.path.join(pool_dir, 'filter_learn.tsv'))

            filter_test = os.path.join(pool_dir, 'filter_test.tsv')
            if os.path.exists(filter_test):
                self.__run_pooltools(markup_path, features,
                                     os.path.join(pool_dir, 'test.tsv'),
                                     filter_hypos=filter_test)

    def __learn_significance(self, mn_path, model_dir, pool_dir, prefix, svn_data_path):
        """
            Build '<prefix>' pools from checked-out texts/markup and learn a model on them.

            mn_path -- path to the matrixnet executable
            model_dir -- directory with already learned 'yane.info' (and 'filter_yane.info'
                         when hypotheses filtering is enabled); receives the new model
            pool_dir -- directory where the '<prefix>*.tsv' pools are written
            prefix -- file name prefix for the produced pools and model
            svn_data_path -- path passed to get_svn_path(); the checkout must contain
                             'texts.tsv' and 'markup.tsv'
        """
        extractor_cmd = self.get_extractor_cmd()

        data_path = make_folder(prefix + 'data')
        Arcadia.checkout(self.get_svn_path(svn_data_path), path=data_path)
        texts_path = os.path.join(data_path, 'texts.tsv')
        markup_path = os.path.join(data_path, 'markup.tsv')

        oe_dir = make_folder('oe')
        features = os.path.join(oe_dir, prefix + 'features.txt')
        filter_model_file_name = os.path.join(model_dir, 'filter_yane.info')
        model_file_name = os.path.join(model_dir, 'yane.info')

        extractor_cmd.append('--print-significance-features')
        extractor_cmd.extend(['--model', model_file_name])
        if self.ctx.get('filter_hypos'):
            extractor_cmd.extend(['--filter-model', filter_model_file_name])
            extractor_cmd.extend(['--filter-threshold', repr(self.ctx['filter_threshold'])])

        self.__run_extractor(extractor_cmd, texts_path, features)

        with self.current_action('Creating pool from features'):
            self.__run_pooltools(markup_path, features,
                                 os.path.join(pool_dir, prefix + 'learn.tsv'),
                                 filter_hypos=os.path.join(pool_dir, 'learn.tsv'),
                                 significance=True)

            main_test = os.path.join(pool_dir, 'test.tsv')
            if os.path.exists(main_test):
                self.__run_pooltools(markup_path, features,
                                     os.path.join(pool_dir, prefix + 'test.tsv'),
                                     filter_hypos=main_test,
                                     significance=True)

            filter_test = os.path.join(pool_dir, 'filter_test.tsv')
            if os.path.exists(filter_test):
                # Here 'filter_test.tsv' itself is the '-f' input and no
                # '--filter-hypos' option is passed.
                self.__run_pooltools(markup_path, filter_test,
                                     os.path.join(pool_dir, prefix + 'filter_test.tsv'),
                                     significance=True)

        # Significance models are learned without the '-c' (classification) flag.
        self.__run_matrixnet(mn_path, model_dir, pool_dir, prefix, self.ctx['significance_iters'], False)

    def on_execute(self):
        """
            Task entry point: unpack the pool, learn the model(s) and publish model.tgz.

            Optional stages are controlled by the 'filter_hypos' and
            'learn_significance_model' context values.
        """
        mn_path = self.sync_resource(self.ctx['matrixnet'])
        pool_dir = make_folder(self.abs_path('pool'))
        model_dir = make_folder(self.abs_path('model'))
        # '--strip-components=1' drops the arbitrary top-level folder of the pool archive.
        run_process(['tar', '-C', pool_dir, '-zxf', self.sync_resource(self.ctx['pool']), '--strip-components=1'],
                    log_prefix='extract_pool')

        if self.ctx.get('filter_hypos'):
            self.__filter_hypos(mn_path, model_dir, pool_dir)

        iter_count = int(check_param(self, 'iters', lambda param: int(param) > 0))
        self.__run_matrixnet(mn_path, model_dir, pool_dir, '', iter_count)

        if self.ctx.get('learn_significance_model'):
            self.__learn_significance(mn_path, model_dir, pool_dir,
                                      'significance_', SIGNIFICANCE_TEXTS_AND_MARKUP_PATH)
            self.__learn_significance(mn_path, model_dir, pool_dir,
                                      'significance_ukr_', SIGNIFICANCE_TEXTS_AND_MARKUP_UKR_PATH)

        # Ship the '*.matrixnet' files and every '*filter_test.tsv' pool
        # together with the model.
        for f in os.listdir(pool_dir):
            if f.endswith(('.matrixnet', 'filter_test.tsv')):
                copy_path(os.path.join(pool_dir, f), os.path.join(model_dir, f))

        run_process(['tar', '-czf', 'model.tgz', 'model'], log_prefix='tar', work_dir=self.abs_path())

        self.create_resource(
            self.descr + ' (model)',
            resource_path='model.tgz',
            resource_type=resource_types.YANE_MODEL,
        )


# Module-level alias for the task class defined in this file
# (presumably the name the Sandbox task loader looks up -- confirm).
__Task__ = YaneLearnModel
