# -*- coding: utf-8 -*-

import os.path
import json

from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers
import sandbox.projects.yane.common as yane
from sandbox.projects.common.utils import check_subtasks_fails
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.paths import make_folder, copy_path, remove_path, add_write_permissions_for_path
from sandbox.sandboxsdk.parameters import LastReleasedResource, SandboxStringParameter, SandboxBoolGroupParameter, SandboxBoolParameter
from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.svn import Arcadia
from sandbox import common
import sandbox.common.types.misc as ctm

# Default number of splits; not referenced in this file, presumably used by importers (TODO confirm).
NUMBER_OF_SPLITS_DEFAULT = 10

# Every model kind this task can learn; offered as choices of the LearnedModels parameter.
MODELS = (
    'query',
    'text',
    'video',
    'yobject',
)

# Models selected by default: everything except 'yobject'.
MODELS_DEFAULT = tuple(model for model in MODELS if model != 'yobject')


class YaneTotalLearn(yane.YaneLearnTaskBase):
    """
        Learns all selected YANE models.

        Spawns one YANE_CALC_QUALITY subtask per selected (language, formula) pair,
        waits for them, then packs the learned models into a 'linked_in_models'
        OTHER_RESOURCE and into a YANE_DATA resource (data directory + models),
        and renders a quality / disk-usage summary into the task footer.
    """
    type = 'YANE_TOTAL_LEARN'

    execution_space = 60 * 1024

    # Model-file name prefix per formula (file naming used in nerlib/models and data/dict/ner/models).
    _MODEL_FILE_PREFIXES = {'query': 'query', 'text': 'doc', 'video': 'videoquery', 'yobject': 'yobject'}

    class RussianYaneData(LastReleasedResource):
        name = 'russian_data'
        description = 'Russian data (leave blank to get from SVN)'
        resource_type = resource_types.YANE_DATA
        required = False

    class TurkishYaneData(LastReleasedResource):
        name = 'turkish_data'
        description = 'Turkish data (leave blank to get from SVN)'
        resource_type = resource_types.YANE_DATA
        required = False

    # NOTE: the class name is misspelled ("Laguages"); kept as-is because it is part of the class's public surface.
    class LearnedLaguages(SandboxBoolGroupParameter):
        LANGUAGES = ('ru',)
        name = 'learned_languages'
        description = "Learned Languages"
        choices = [(l, l) for l in LANGUAGES]
        default_value = ' '.join(LANGUAGES)

    class LearnedModels(SandboxBoolGroupParameter):
        name = 'learned_models'
        description = "Learned Models"
        choices = [(m, m) for m in MODELS]
        default_value = ' '.join(MODELS_DEFAULT)

    class MatrixnetBinary(LastReleasedResource):
        name = 'matrixnet'
        description = 'Matrixnet binary (optional, leave blank to build from trunk)'
        resource_type = resource_types.MATRIXNET_EXECUTABLE
        required = False

    class EnableBorsches(SandboxBoolParameter):
        name = 'with_borsches'
        description = 'Turn on Borsches'
        default_value = False

    class DaysToKeepDataAsResource(SandboxStringParameter):
        name = "days_to_keep_data_as_resource"
        description = "Days to keep data as resource (inf for infinity)"
        default_value = "inf"
        required = True

    input_parameters = [RussianYaneData, TurkishYaneData, MatrixnetBinary, EnableBorsches, DaysToKeepDataAsResource, LearnedModels, LearnedLaguages] + \
        yane.get_base_params().params + yane.get_mr_params().params

    def get_selected_langs(self):
        """Return the languages ticked in the LearnedLaguages parameter (space-separated in ctx)."""
        # split() (no argument) ignores repeated/leading/trailing spaces instead of yielding '' entries.
        return self.ctx[self.LearnedLaguages.name].split()

    def get_selected_models(self):
        """Return the model kinds ticked in the LearnedModels parameter (space-separated in ctx)."""
        return self.ctx[self.LearnedModels.name].split()

    def __init__(self, task_id=0):
        # NOTE(review): this calls YaneTaskBase.__init__ directly although the class derives from
        # YaneLearnTaskBase, so any YaneLearnTaskBase.__init__ is skipped — confirm this is intended.
        yane.YaneTaskBase.__init__(self, task_id)
        self.ctx['kill_timeout'] = 5 * 60 * 60  # 5 hours

    def make_result_html(self):
        """Build the HTML results section shown in the task footer.

        Contains one row per child task listed in ctx['child_tasks'] (languages and
        formula, precision, recall, F-measure, profiling title) and, when
        save_data_as_resource has recorded it, a per-file size comparison of the
        Russian data against trunk.
        """
        rows = ['<tr><td><b>Language, mode</b></td><td><b>Precision</b></td><td><b>Recall</b></td><td><b>F-measure</b></td><td><b>Profiling results</b></td></tr>\n']

        for child_id in self.ctx['child_tasks']:
            task = channel.sandbox.get_task(child_id)
            rows.append(
                '<tr>'
                '<td>%s, %s</td>'
                '<td>%s</td>'
                '<td>%s</td>'
                '<td>%s</td>'
                '<td>%s</td>'
                '</tr>\n' % (
                    task.ctx['extraction_languages'], task.ctx['formula'],
                    task.ctx['precision'],
                    task.ctx['recall'],
                    task.ctx['fmeasure'],
                    task.ctx.get('profile_title') or '',
                )
            )

        if 'russian_data_disk_usage' in self.ctx:
            ru_disk_usage = json.loads(self.ctx['russian_data_disk_usage'])
            ru_disk_usage_trunk = json.loads(self.ctx['russian_data_trunk_disk_usage'])
            rows.append('<tr><td> </td></tr>')
            rows.append('<tr><td><b>russian_data_disk_usage</b></td><td><b>current</b></td><td><b>trunk</b></td></tr>')
            # Sorted so the table order is stable between runs.
            for name in sorted(ru_disk_usage):
                row = '<tr><td>%s</td><td>%s</td>' % (name, self.sizeof_fmt(ru_disk_usage[name]))
                if name in ru_disk_usage_trunk:
                    row += '<td>%s</td>' % self.sizeof_fmt(ru_disk_usage_trunk[name])
                rows.append(row + '</tr>')
            rows.append('<tr><td>%s</td><td>%s</td><td>%s</td></tr>' % (
                'TOTAL:',
                self.sizeof_fmt(sum(ru_disk_usage.values())),
                self.sizeof_fmt(sum(ru_disk_usage_trunk.values())),
            ))

        return (
            '<h4 class="task-section-header">Results </h4>' +
            '<table  class="main fixed-titles" style="width:100%;"><tbody>' +
            ''.join(rows) +
            '</tbody></table>'
        )

    @property
    def footer(self):
        """HTML footer stored by do_execute in ctx['result_html'] (None until then)."""
        return self.ctx.get('result_html')

    @staticmethod
    def __get_filter_threshold(lang, formula):
        """Return the hypothesis-filter threshold for (lang, formula), or None for unfiltered formulas."""
        if formula == 'query':
            return -4.0 if lang == 'ru' else -5.0
        if formula == 'text':
            return -3.0 if lang == 'ru' else -4.0
        if formula == 'yobject':
            return -5.0
        return None

    def __get_ttl_for_resource(self):
        """Return the resource TTL: the string "inf" unchanged, otherwise the value as int days.

        Raises ValueError if the parameter is neither "inf" nor an integer.
        """
        days = self.ctx[self.DaysToKeepDataAsResource.name]
        return days if days == "inf" else int(days)

    def __get_filter_params(self, lang, formula):
        """Return subtask params controlling hypothesis filtering for (lang, formula)."""
        if formula not in ('query', 'text', 'yobject'):
            return {'filter_hypos': False}
        iteration_num = 2000 if formula == 'yobject' else 500
        return {
            'filter_hypos': True,
            'filter_matrixnet_options': '-w 0.1 -S 0.75 -i %d' % iteration_num,
            'filter_threshold': self.__get_filter_threshold(lang, formula),
        }

    @yane.run_once
    def learn_models(self):
        """Spawn one YANE_CALC_QUALITY subtask per selected (language, formula) and wait for them.

        Child ids are stored in ctx['child_tasks']; on re-entry after the wait the
        method returns immediately.
        """
        if 'child_tasks' in self.ctx:  # woke up after waiting for child_tasks
            return

        common_params = {
            'kill_timeout': max(4 * 60 * 60, self.ctx['kill_timeout'] - 60 * 60),
            'matrixnet': self.ctx['matrixnet'],
            'svn_url': self.ctx['svn_url'],
            'arcadia_patch': self.ctx.get('arcadia_patch'),
            'tools': self.ctx['tools'],
            'notify_via': '',
            'notify_if_finished': '',
            'notify_if_failed': self.author,
            'with_borsches': self.ctx['with_borsches']
        }

        languages = {
            'ru': {'data': self.ctx['russian_data']}
        }
        formulas = (
            {'formula': 'query', 'iters': 3000, 'matrixnet_options': '-w 0.05 -S 0.75'},
            {'formula': 'text', 'iters': 6000, 'matrixnet_options': '-w 0.05 -S 0.75'},
            {'formula': 'video', 'iters': 4000, 'matrixnet_options': '-w 0.05 -S 0.75', 'extraction_languages': 'ru'},
        )

        extra_params = {
            ('ru', 'text'): {'learn_significance_model': True},
            ('ru', 'yobject'): {'learn_significance_model': False, 'with_borsches': True}
        }

        # Loop-invariant lookups, done once.
        selected_langs = set(self.get_selected_langs())
        selected_models = set(self.get_selected_models())
        if common.config.Registry().common.installation == ctm.Installation.PRODUCTION:
            subtasks_cpu_model = self.cpu_model_filter or 'e5-2650'  # for performance testing
        else:
            subtasks_cpu_model = self.cpu_model_filter

        child_tasks = []
        for lang, lang_params in languages.iteritems():
            if lang not in selected_langs:
                continue

            for formula in formulas:
                formula_name = formula['formula']
                if formula_name not in selected_models:
                    continue
                # A formula pinned to one extraction language is skipped for the others.
                if formula.get('extraction_languages', lang) != lang:
                    continue

                # Later updates override earlier ones: common < language < formula < filter < extra.
                params = common_params.copy()
                params['extraction_languages'] = lang
                params.update(lang_params)
                params.update(formula)
                params.update(self.__get_filter_params(lang, formula_name))
                params.update(extra_params.get((lang, formula_name), {}))

                subtask = self.create_subtask('YANE_CALC_QUALITY',
                                              '%s, %s, %s' % (self.descr, lang, formula_name),
                                              params,
                                              model=subtasks_cpu_model)
                child_tasks.append(subtask.id)

        self.ctx['child_tasks'] = child_tasks
        self.wait_tasks(child_tasks, tuple(self.Status.Group.FINISH), True)

    @staticmethod
    def get_model_file_name(calc_quality_task):
        """Return the nerlib/models file name for a child's model, e.g. 'querymodelru.info'."""
        model_prefix = YaneTotalLearn._MODEL_FILE_PREFIXES[calc_quality_task.ctx['formula']]
        lang_suffix = calc_quality_task.ctx['extraction_languages']
        return '%smodel%s.info' % (model_prefix, lang_suffix)

    @staticmethod
    def get_data_model_file_name(calc_quality_task):
        """Return the data/dict/ner/models file name for a child's model, e.g. 'docmodel.info'."""
        model_prefix = YaneTotalLearn._MODEL_FILE_PREFIXES[calc_quality_task.ctx['formula']]
        return '%smodel.info' % model_prefix

    def _list_calc_quality_tasks(self):
        """Return this task's YANE_CALC_QUALITY children, each re-fetched via get_task.

        NOTE(review): re-fetching mirrors what save_data_as_resource originally did before
        reading ctx; presumably list_tasks results may not carry the full ctx — confirm.
        """
        tasks = channel.sandbox.list_tasks(task_type='YANE_CALC_QUALITY', parent_id=self.id)
        return [channel.sandbox.get_task(task.id) for task in tasks]

    def _unpack_model(self, task):
        """Extract the YANE_MODEL resource of *task* into a fresh 'temp' folder; return its path."""
        model_id = apihelpers.get_task_resource_id(task.id, resource_types.YANE_MODEL)
        temp_folder = make_folder('temp')
        run_process(['tar', '-C', temp_folder, '-zxf', self.sync_resource(model_id), '--strip-components=1'], log_prefix='extract_pool')
        return temp_folder

    @staticmethod
    def _copy_if_exists(src, dst):
        """Copy *src* to *dst* only when *src* is an existing regular file."""
        if os.path.isfile(src):
            copy_path(src, dst)

    @staticmethod
    def _data_lang_suffix(lang):
        """Return the data sub-directory suffix for *lang*: 'tr' for Turkish, '' otherwise."""
        return lang if lang == 'tr' else ''

    def save_models_as_resource(self, res_name_prefix):
        """Pack every child's models into '<res_name_prefix>.tgz' and publish it as OTHER_RESOURCE.

        The resource id is stored in ctx[res_name_prefix].
        """
        # NOTE(review): the return value is discarded; if check_subtasks_fails is a decorator
        # factory, this call has no effect — confirm against sandbox.projects.common.utils.
        check_subtasks_fails()
        models_folder = make_folder(res_name_prefix)
        for task in self._list_calc_quality_tasks():
            temp_folder = self._unpack_model(task)
            dest_file = self.get_model_file_name(task)
            copy_path(os.path.join(temp_folder, 'yane.info'), os.path.join(models_folder, dest_file))
            self._copy_if_exists(os.path.join(temp_folder, 'filter_yane.info'),
                                 os.path.join(models_folder, 'filter' + dest_file))
            self._copy_if_exists(os.path.join(temp_folder, 'significance_yane.info'),
                                 os.path.join(models_folder, 'significance' + dest_file))
            self._copy_if_exists(os.path.join(temp_folder, 'significance_ukr_yane.info'),
                                 os.path.join(models_folder, 'significanceukr' + dest_file))
            remove_path(temp_folder)

        models_tgz_name = '%s.tgz' % res_name_prefix
        run_process(['tar', '-czf', models_tgz_name, res_name_prefix], log_prefix='tar', work_dir=self.abs_path())

        res = self.create_resource(
            res_name_prefix + ': ' + self.descr,
            resource_path=models_tgz_name,
            resource_type=resource_types.OTHER_RESOURCE,
            attributes={'ttl': self.__get_ttl_for_resource(), 'backup_task': True}
        )
        self.mark_resource_ready(res.id)
        self.ctx[res_name_prefix] = res.id

    def sizeof_fmt(self, num, suffix='B'):
        """Format a byte count with binary prefixes, e.g. 1536 -> '1.5KiB'."""
        for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
            if abs(num) < 1024.0:
                return "%3.1f%s%s" % (num, unit, suffix)
            num /= 1024.0
        return "%.1f%s%s" % (num, 'Yi', suffix)

    def get_data_disk_usage(self, data_dir):
        """Return a JSON object string mapping each entry of *data_dir* to os.path.getsize of it."""
        # os.listdir already yields bare entry names, so no basename() is needed.
        return json.dumps(dict(
            (name, os.path.getsize(os.path.join(data_dir, name)))
            for name in os.listdir(data_dir)
        ))

    def save_data_as_resource(self, lang, data_name):
        """Publish a YANE_DATA resource: the *lang* data directory plus the learned models.

        The data comes from the selected *data_name* resource or, if none, from an
        SVN checkout. Stores the resource id in ctx[data_name + '_model'] and JSON
        disk-usage maps in ctx[data_name + '_disk_usage'] and
        ctx[data_name + '_trunk_disk_usage'].
        """
        # NOTE(review): the return value is discarded; if check_subtasks_fails is a decorator
        # factory, this call has no effect — confirm against sandbox.projects.common.utils.
        check_subtasks_fails()
        lang_suffix = self._data_lang_suffix(lang)

        # download data
        data_dir = self.abs_path(data_name)
        if self.is_resource_selected(data_name):
            res_dir = self.sync_resource(self.ctx[data_name])
            # The resource belongs to another task, so res_dir cannot be used in create_resource: copy it.
            copy_path(res_dir, data_dir)
        else:
            make_folder(data_dir)
            svn_path = os.path.join('data/dict/ner', lang_suffix)
            Arcadia.checkout(self.get_svn_path(svn_path), path=data_dir, depth='files')

        # Unknown failure (root cause not investigated): without this, copying files into data_dir is denied.
        add_write_permissions_for_path(data_dir)

        # get models
        for task in self._list_calc_quality_tasks():
            if task.ctx['extraction_languages'] != lang:
                continue
            temp_folder = self._unpack_model(task)
            dest_file = self.get_data_model_file_name(task)
            copy_path(os.path.join(temp_folder, 'yane.info'), os.path.join(data_dir, dest_file))
            self._copy_if_exists(os.path.join(temp_folder, 'filter_yane.info'),
                                 os.path.join(data_dir, 'filter' + dest_file))
            self._copy_if_exists(os.path.join(temp_folder, 'significance_yane.info'),
                                 os.path.join(data_dir, 'significancenewsmodel.info'))
            self._copy_if_exists(os.path.join(temp_folder, 'significance_ukr_yane.info'),
                                 os.path.join(data_dir, 'significancenewsukrmodel.info'))
            remove_path(temp_folder)

        # Provenance files: written once, whether or not any child matched this language.
        ontodb_version = self.ctx.get('ontodb_version')
        if ontodb_version is not None:
            with open(os.path.join(data_dir, 'ontodb_version.txt'), 'w') as version_file:
                version_file.write(ontodb_version + '\n')
        with open(os.path.join(data_dir, 'total_learn_task_id.txt'), 'w') as task_id_file:
            task_id_file.write(str(self.id))

        data_model_tgz_name = '%s_data_model.tgz' % data_name
        run_process(['tar', '-czf', data_model_tgz_name, '-C', data_dir, '.'], log_prefix='tar', work_dir=self.abs_path())

        # create data-model resource
        res = self.create_resource(
            data_name + '_model: ' + self.descr,
            resource_path=data_model_tgz_name,
            resource_type=resource_types.YANE_DATA,
            arch='any',
            attributes={'ttl': self.__get_ttl_for_resource(), 'backup_task': True}
        )
        self.mark_resource_ready(res.id)
        self.ctx[data_name + '_model'] = res.id

        # disk usage
        self.ctx[data_name + '_disk_usage'] = self.get_data_disk_usage(data_dir)
        remove_path(data_dir)
        if self.is_resource_selected(data_name):
            trunk_dir = self.get_data_path(lang_suffix, True)  # checkout trunk data
            self.ctx[data_name + '_trunk_disk_usage'] = self.get_data_disk_usage(trunk_dir)
            remove_path(trunk_dir)
        else:
            # The data was checked out from SVN above, so it already is the trunk data.
            self.ctx[data_name + '_trunk_disk_usage'] = self.ctx[data_name + '_disk_usage']

    def do_execute(self):
        """Run the full pipeline: build matrixnet, learn models, publish resources, render the footer."""
        self.ctx['ontodb_version'] = self.get_ontodb_version()

        self.build_matrixnet()
        self.learn_models()
        self.save_models_as_resource('linked_in_models')
        self.save_data_as_resource('ru', 'russian_data')

        self.ctx['result_html'] = self.make_result_html()


# Module-level handle to the task class; presumably the name the Sandbox task loader looks up (TODO confirm).
__Task__ = YaneTotalLearn
