# -*- coding: utf-8 -*-
import os
import shutil
import logging
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxStringParameter
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects import resource_types
from sandbox.sandboxsdk.process import run_process
from sandbox.projects.common import apihelpers


class MatrixnetExecutableParameter(ResourceSelector):
    name = 'matrixnet_resource_id'
    description = 'matrixnet binary'
    resource_type = 'MATRIXNET_EXECUTABLE'
    required = True


class MatrixnetTestPoolIdsParameter(SandboxStringParameter):
    name = 'matrixnet_test_pool_ids'
    description = 'comma separated list of MATRIXNET_TEST_POOL ids'
    required = True


class ValidateMatrixnetModels(SandboxTask):
    """
    **Описание**
    Сравение равенства результатов предсказания матрикснет моделей inc, bin и info на нескольких пулах, в разных режимах обучения.
    Сначала происходит сборка бинарников Matrixnet, MxOps, Infotest,
    затем эти программы используются для обучения, конвертации моделей
    и сравнения результатов обучения для каждого из указанных пулов.
    В случае успешного завершения дочерних тасков, результаты агрегируются и записываются в ресурс MATRIXNET_ALLTESTS_RESULTS

    **Необходимые ресурсы и параметры**

    * **comma separated list of MATRIXNET_TEST_POOL ids** - список ресурсов MATRIXNET_TEST_POOL, для каждого из которых будет запущен дочерний таск TEST_MATRIXNET
    * **Svn url for arcadia** - svn url из которого будут собираться программы

    **Создаваемые ресурсы**

    * MATRIXNET_ALLTESTS_RESULTS
    """
    type = 'VALIDATE_MATRIXNET_MODELS'

    input_parameters = (
        MatrixnetTestPoolIdsParameter,
        MatrixnetExecutableParameter
    )

    @staticmethod
    def _get_resource_id(task_id, resource_type, operation_name='Build'):
        task = channel.sandbox.get_task(task_id)
        if task.is_failure():
            raise SandboxTaskFailureError('{} in sub-task {} failed'.format(operation_name, task_id))

        resources = [r for r in apihelpers.list_task_resources(task_id) if r.type == resource_type]
        if len(resources) != 1:
            raise SandboxTaskFailureError("unknown error")

        return resources[0].id

    def _run_tests(self):
        pool_ids = self.ctx['matrixnet_test_pool_ids'].split(',')
        tests_data = dict()

        for poolId in pool_ids:
            test_key = 'matrixnet_test_with_pool_' + poolId

            tests_data[test_key] = (self.arch, poolId, test_key)
            logging.info('test_key: {}    tests_data: {}'.format(test_key, tests_data[test_key]))

        wait_tests = []
        for test_key in tests_data:
            if not self.ctx.get(test_key):
                logging.info('test_key: {}    tests_data: {}'.format(test_key, tests_data[test_key]))
                self.ctx[test_key] = self._schedule_test(*tests_data[test_key])
                wait_tests.append(self.ctx[test_key])
            else:
                task = channel.sandbox.get_task(self.ctx[test_key])
                if not task.is_done():
                    wait_tests.append(self.ctx[test_key])

        if wait_tests:
            self.wait_all_tasks_completed(wait_tests)

        os.mkdir('alltests')
        for test_key in tests_data:
            resource_id = self._get_resource_id(self.ctx[test_key], 'MATRIXNET_TESTING_PREDICTIONS', 'Matrixnet run')
            shutil.copy(self.sync_resource(resource_id), 'alltests/%s_predictions.tsv' % test_key)
        resource = channel.sandbox.get_resource(self.ctx['matrixnet_alltests_results_resource_id'])
        cmd = 'tar zcf %s alltests' % resource.path
        run_process(cmd, log_prefix='compress')

    def on_enqueue(self):
        SandboxTask.on_enqueue(self)
        self.ctx['matrixnet_alltests_results_resource_id'] = self._create_resource(
            'tests results',
            'alltests.tar.gz',
            resource_types.MATRIXNET_ALLTESTS_RESULTS, arch=self.arch
        ).id

    def on_execute(self):
        self._run_tests()

    def _schedule_test(self, arch, pool_id, description):
        logging.info('description ' + description)
        sub_ctx = {
            'test_pool_resource_id': pool_id,
            'matrixnet_resource_id': self.ctx['matrixnet_resource_id'],
        }
        task = self.create_subtask(task_type='TEST_MATRIXNET', description=description,
                                   input_parameters=sub_ctx, arch=arch, important=self.important)
        return task.id


__Task__ = ValidateMatrixnetModels
