# -*- coding: utf-8 -*-
import os
import shutil

import logging

import sandbox.common.types.client as ctc

from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxStringParameter
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects import resource_types
from sandbox.sandboxsdk.process import run_process

from sandbox.projects.common import apihelpers


# Value of the ``group`` attribute shared by the pool-ids and file-pairs
# parameters below (presumably the heading they are shown under in the task
# form -- confirm against the Sandbox SDK).
MATRIXNET_PARAMS_GROUP = "Matrixnet testing params"


class MatrixnetExecutableParameter(ResourceSelector):
    """Required selector of the MATRIXNET_EXECUTABLE resource to test.

    The chosen value is read from the task context under
    ``matrixnet_resource_id`` and forwarded to every scheduled
    CALCULATE_MATRIXNET_PREDICTIONS_MEAN sub-task.
    """
    # Context key under which the selected resource id is stored.
    name = 'matrixnet_resource_id'
    description = 'matrixnet binary'
    resource_type = 'MATRIXNET_EXECUTABLE'
    required = True


class MatrixnetTestPoolIdsParameter(SandboxStringParameter):
    """Required string parameter holding the ids of the test pools.

    The task splits the value on commas and schedules one
    CALCULATE_MATRIXNET_PREDICTIONS_MEAN sub-task per id.
    """
    # Context key under which the raw comma separated string is stored.
    name = 'matrixnet_test_pool_ids'
    description = 'comma separated list of MATRIXNET_TEST_POOL ids'
    required = True
    group = MATRIXNET_PARAMS_GROUP


class MatrixnetTestInternalFilesPairs(SandboxStringParameter):
    """Optional multi-line parameter listing file pairs to compare.

    When the value is non-empty the task creates a
    CHECK_MATRIXNET_PREDICTIONS_DIFF sub-task and passes this string to it
    unchanged, together with the id of the results archive. The expected
    format (one ``filename1;filename2`` pair per line) is the one stated in
    ``description``; it is not parsed or validated in this file.
    """
    # Context key under which the raw text is stored.
    name = 'matrixnet_test_files_pairs'
    description = '''string containing list of pairs of files to test inside first archive in format filename1;filename2\nfilename3;filename4'''
    # Empty by default, which disables the diff check.
    default_value = ''
    multiline = True
    required = False
    group = MATRIXNET_PARAMS_GROUP


class CalculateMatrixnetModelsAndPredictions(SandboxTask):
    """
    **Description**
    Head task for checking that model training results converge in distributed mode.
    For every test pool a CALCULATE_MATRIXNET_PREDICTIONS_MEAN child task is scheduled with the
    given Matrixnet binary (according to the original description each child trains N models and
    computes the mean and variance of their predictions on the test set).
    The MATRIXNET_TESTING_PREDICTIONS resource of every child is then packed into a single archive
    published as MATRIXNET_ALLTESTS_RESULTS.
    If pairs of files are given, a CHECK_MATRIXNET_PREDICTIONS_DIFF child task is run on that
    archive and this task fails when the child reports a difference.

    **Required resources and parameters**

    * **matrixnet binary** - MATRIXNET_EXECUTABLE resource passed to every child task
    * **comma separated list of MATRIXNET_TEST_POOL ids** - list of MATRIXNET_TEST_POOL resources, a child task is started for each of them
    * **pairs of files** - optional, enables the CHECK_MATRIXNET_PREDICTIONS_DIFF step

    **Created resources**

    * MATRIXNET_ALLTESTS_RESULTS
    """
    type = 'CALCULATE_MATRIXNET_MODELS_AND_PREDICTIONS'
    client_tags = ctc.Tag.Group.LINUX
    input_parameters = (
        MatrixnetExecutableParameter,
        MatrixnetTestPoolIdsParameter,
        MatrixnetTestInternalFilesPairs,
    )

    # Context keys used to keep state between executions of this task.
    _RESULTS_RESOURCE_KEY = 'matrixnet_alltests_results_resource_id'
    _RESULTS_READY_KEY = 'matrixnet_alltests_results_ready'
    _DIFF_TASK_KEY = 'internal_diff_test_key'
    # Prefix of the per-pool context key; it is also used as the child task
    # description and as the prefix of the file name inside the archive.
    _TEST_KEY_PREFIX = 'matrixnet_test_with_pool_'
    # Directory (relative to the working directory) that gets archived.
    _RESULTS_DIR = 'alltests'

    @staticmethod
    def _get_resource_id(task_id, resource_type):
        """Return the id of the single resource of ``resource_type`` made by a task.

        :param task_id: id of the sub-task to inspect
        :param resource_type: resource type name compared against ``resource.type``
        :return: id of the matching resource
        :raises SandboxTaskFailureError: if the sub-task failed or it does not
            own exactly one resource of the requested type
        """
        task = channel.sandbox.get_task(task_id)
        if task.is_failure():
            raise SandboxTaskFailureError('Build in sub-task %s failed' % task_id)

        resources = [r for r in apihelpers.list_task_resources(task_id) if r.type == resource_type]
        if len(resources) != 1:
            # Say what was expected and what was found instead of a bare
            # "unknown error", so the failure can be diagnosed from the message.
            raise SandboxTaskFailureError(
                'expected exactly one %s resource in sub-task %s, found %d'
                % (resource_type, task_id, len(resources))
            )

        return resources[0].id

    @staticmethod
    def _parse_pool_ids(raw_ids):
        """Split a comma separated string into a list of unique pool ids.

        Surrounding whitespace is stripped, empty items (e.g. produced by a
        trailing comma) are dropped and duplicates are removed while the
        original order is kept.

        :param raw_ids: raw parameter value, may be empty or ``None``
        :return: list of pool id strings, possibly empty
        """
        pool_ids = []
        for chunk in (raw_ids or '').split(','):
            pool_id = chunk.strip()
            if pool_id and pool_id not in pool_ids:
                pool_ids.append(pool_id)
        return pool_ids

    def _run_tests(self):
        """Run one child task per test pool and archive their predictions.

        Ids of the scheduled child tasks are stored in the task context, so a
        repeated call only schedules what is missing and only waits for what
        is not done yet. (Presumably ``wait_all_tasks_completed`` suspends the
        task and ``on_execute`` is entered again later -- confirm against the
        Sandbox SDK.)

        :raises SandboxTaskFailureError: if no pool ids were given, or a child
            task failed or did not produce exactly one
            MATRIXNET_TESTING_PREDICTIONS resource
        """
        if self.ctx.get(self._RESULTS_READY_KEY):
            logging.info('data is ready')
            return

        pool_ids = self._parse_pool_ids(self.ctx[MatrixnetTestPoolIdsParameter.name])
        if not pool_ids:
            raise SandboxTaskFailureError(
                'no test pool ids given in %s' % MatrixnetTestPoolIdsParameter.name
            )

        test_keys = []
        wait_tests = []
        for pool_id in pool_ids:
            test_key = self._TEST_KEY_PREFIX + pool_id
            test_keys.append(test_key)
            if not self.ctx.get(test_key):
                logging.info('test_key: %s    pool_id: %s', test_key, pool_id)
                self.ctx[test_key] = self._schedule_test(self.arch, pool_id, test_key)
                wait_tests.append(self.ctx[test_key])
            elif not channel.sandbox.get_task(self.ctx[test_key]).is_done():
                wait_tests.append(self.ctx[test_key])

        if wait_tests:
            self.wait_all_tasks_completed(wait_tests)

        # Start from an empty directory: a leftover from a previous attempt
        # would make os.mkdir raise and could leak stale files into the archive.
        if os.path.isdir(self._RESULTS_DIR):
            shutil.rmtree(self._RESULTS_DIR)
        os.mkdir(self._RESULTS_DIR)

        for test_key in test_keys:
            resource_id = self._get_resource_id(self.ctx[test_key], 'MATRIXNET_TESTING_PREDICTIONS')
            shutil.copy(
                self.sync_resource(resource_id),
                os.path.join(self._RESULTS_DIR, '%s_predictions.tsv' % test_key)
            )

        results_resource_id = self.ctx[self._RESULTS_RESOURCE_KEY]
        resource = channel.sandbox.get_resource(results_resource_id)
        run_process(['tar', 'zcf', resource.path, self._RESULTS_DIR], log_prefix='compress')
        self.mark_resource_ready(results_resource_id)
        self.ctx[self._RESULTS_READY_KEY] = True

    def on_enqueue(self):
        """Create the MATRIXNET_ALLTESTS_RESULTS resource and remember its id in the context."""
        SandboxTask.on_enqueue(self)
        self.ctx[self._RESULTS_RESOURCE_KEY] = self._create_resource(
            'matrixnet tests results',
            'alltests.tar.gz',
            resource_types.MATRIXNET_ALLTESTS_RESULTS,
            arch=self.arch).id

    def _run_internal_diff_test(self):
        """Run the optional CHECK_MATRIXNET_PREDICTIONS_DIFF child task.

        Does nothing when the file pairs parameter is empty. Otherwise the
        child task is created once (its id is stored in the context), waited
        for if it is not done, and its ``mx_diff`` context value is examined.

        :raises SandboxTaskFailureError: if the child reports a difference or
            did not publish ``mx_diff`` at all
        """
        files_pairs = self.ctx.get(MatrixnetTestInternalFilesPairs.name)
        if not files_pairs:
            return

        diff_task_id = self.ctx.get(self._DIFF_TASK_KEY)
        if not diff_task_id:
            params = {
                'matrixnet_first_test_resource_id': self.ctx[self._RESULTS_RESOURCE_KEY],
                'matrixnet_test_files_pairs': files_pairs,
            }
            diff_task_id = self.create_subtask(
                task_type='CHECK_MATRIXNET_PREDICTIONS_DIFF',
                description='check cross-modes predictions difference',
                input_parameters=params,
                arch=self.arch,
                important=self.important).id
            self.ctx[self._DIFF_TASK_KEY] = diff_task_id
            needs_wait = True
        else:
            needs_wait = not channel.sandbox.get_task(diff_task_id).is_done()

        if needs_wait:
            self.wait_task_completed(diff_task_id)

        task = channel.sandbox.get_task(diff_task_id)
        try:
            has_diff = task.ctx['mx_diff']
        except KeyError:
            # The child never published its verdict (e.g. it broke before the
            # comparison); report that explicitly instead of a bare KeyError.
            raise SandboxTaskFailureError(
                'sub-task %s did not report mx_diff' % diff_task_id
            )
        if has_diff:
            raise SandboxTaskFailureError("Has diff in pair files")

    def on_execute(self):
        """Collect predictions for all pools, then run the optional diff check."""
        self._run_tests()
        self._run_internal_diff_test()

    def _schedule_test(self, arch, pool_id, description):
        """Create a CALCULATE_MATRIXNET_PREDICTIONS_MEAN child task for one pool.

        :param arch: architecture passed to ``create_subtask``
        :param pool_id: id of the test pool resource handed to the child
        :param description: description of the child task
        :return: id of the created child task
        """
        logging.info('description %s', description)
        sub_ctx = {
            'test_pool_resource_id': pool_id,
            'matrixnet_resource_id': self.ctx[MatrixnetExecutableParameter.name],
            'train_matrixnet_slaves_count': 1
        }
        task = self.create_subtask(task_type='CALCULATE_MATRIXNET_PREDICTIONS_MEAN', description=description,
                                   input_parameters=sub_ctx, arch=arch, important=self.important)
        return task.id


# Module-level alias of the task class (presumably the name the Sandbox task
# loader looks up when importing this module -- confirm against the SDK).
__Task__ = CalculateMatrixnetModelsAndPredictions
