# -*- coding: utf-8 -*-

import logging
import os
import shutil

from sandbox.projects import resource_types
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxIntegerParameter
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects.common import apihelpers


class MyMatrixnetExecutableParameter(ResourceSelector):
    """Input parameter selecting the matrixnet binary resource."""

    # Only resources of this type are offered for selection.
    resource_type = 'MATRIXNET_EXECUTABLE'
    required = True

    # Context key under which the selected resource id is stored; the task
    # forwards ctx['matrixnet_resource_id'] to every sub-task it creates.
    name = 'matrixnet_resource_id'
    description = 'matrixnet binary'


class MyMatrixnetTestPoolParameter(ResourceSelector):
    """Input parameter selecting the matrixnet test pool resource."""

    # Only resources of this type are offered for selection.
    resource_type = 'MATRIXNET_TEST_POOL'
    required = True

    # Context key under which the selected resource id is stored; the task
    # forwards ctx['test_pool_resource_id'] to every sub-task it creates.
    name = 'test_pool_resource_id'
    description = 'matrixnet test pool'


class MyMatrixnetSlavesCount(SandboxIntegerParameter):
    """Integer input parameter: slaves count passed through to each sub-task."""

    required = True
    default_value = 1

    # Context key; the task copies ctx['train_matrixnet_slaves_count'] into
    # the input parameters of every sub-task.
    name = 'train_matrixnet_slaves_count'
    # NOTE(review): 'numbert' looks like a typo for 'number'; left untouched
    # here because the description is a runtime string.
    description = 'numbert of slaves to run'


class MyMatrixnetTestRuns(SandboxIntegerParameter):
    """Integer input parameter: how many sub-task runs are averaged."""

    default_value = 5

    # Context key; the task starts ctx['train_matrixnet_test_runs'] sub-tasks.
    name = 'train_matrixnet_test_runs'
    # NOTE(review): 'numbert' looks like a typo for 'number'; left untouched
    # here because the description is a runtime string.
    description = 'numbert of test runs'


class CalculateMatrixnetPredictionsMean(SandboxTask):
    """Average matrixnet predictions over several sub-task runs.

    The task starts ``train_matrixnet_test_runs`` sub-tasks of type
    ``CALCULATE_MATRIXNET_MODEL_AND_PREDICTIONS``, waits for them, copies the
    ``MATRIXNET_TESTING_PREDICTIONS`` resource of each one into ``results/``
    and writes ``test.tsv.matrixnet``: for every row the last tab-separated
    column is replaced by its mean across the runs, followed by the sample
    standard deviation (``ddof=1``).  The id of the published resource is
    stored in ``ctx['matrixnet_predictions_resource_id']``.
    """

    type = 'CALCULATE_MATRIXNET_PREDICTIONS_MEAN'
    input_parameters = (
        MyMatrixnetExecutableParameter,
        MyMatrixnetTestPoolParameter,
        MyMatrixnetSlavesCount,
        MyMatrixnetTestRuns,
    )

    @staticmethod
    def _get_resource_id(task_id, resource_type):
        """Return the id of the single resource of ``resource_type`` in a task.

        :param task_id: id of the sub-task to inspect
        :param resource_type: resource type name compared against ``r.type``
        :raises SandboxTaskFailureError: if the sub-task failed, or if it does
            not own exactly one resource of the requested type
        """
        task = channel.sandbox.get_task(task_id)
        if task.is_failure():
            raise SandboxTaskFailureError('Build in sub-task %s failed' % task_id)

        resources = [r for r in apihelpers.list_task_resources(task_id) if r.type == resource_type]
        if len(resources) != 1:
            # Say what was expected and what was found instead of a bare
            # "unknown error", so the failure can be diagnosed from the log.
            raise SandboxTaskFailureError(
                'Expected exactly one %s resource in sub-task %s, found %d'
                % (resource_type, task_id, len(resources))
            )

        return resources[0].id

    def _start_missing_subtasks(self):
        """Create sub-tasks that are not registered in ``ctx`` yet.

        :return: ``(tests_keys, wait_tests)`` - the ctx keys of all runs and
            the ids of the sub-tasks that still have to be waited for
        """
        wait_tests = []
        tests_keys = []
        for i in xrange(self.ctx['train_matrixnet_test_runs']):
            test_key = "subtask_" + str(i)
            tests_keys.append(test_key)
            if not self.ctx.get(test_key):
                logging.info('starting subtask # {}'.format(i))
                sub_ctx = {
                    'train_matrixnet_slaves_count': self.ctx['train_matrixnet_slaves_count'],
                    'test_pool_resource_id': self.ctx['test_pool_resource_id'],
                    'matrixnet_resource_id': self.ctx['matrixnet_resource_id'],
                }
                task = self.create_subtask(task_type='CALCULATE_MATRIXNET_MODEL_AND_PREDICTIONS',
                                           description="matrixnet calc subtask #{} for test pool {}".format(i, self.ctx['test_pool_resource_id']),
                                           input_parameters=sub_ctx,
                                           arch=self.arch,
                                           important=self.important)
                # The id is remembered in ctx so that the sub-task is not
                # created a second time when this key is seen again.
                self.ctx[test_key] = task.id
                wait_tests.append(self.ctx[test_key])
            else:
                task = channel.sandbox.get_task(self.ctx[test_key])
                if not task.is_done():
                    wait_tests.append(self.ctx[test_key])
        return tests_keys, wait_tests

    def _copy_subtask_predictions(self, tests_keys, results_dir):
        """Copy the predictions resource of every sub-task into ``results_dir``.

        :return: list of the copied file paths, in ``tests_keys`` order
        """
        # os.mkdir raises OSError when the directory is already there (e.g. a
        # re-run in the same working directory), so create it only if missing.
        if not os.path.isdir(results_dir):
            os.mkdir(results_dir)

        paths = []
        for test_key in tests_keys:
            resource_id = self._get_resource_id(self.ctx[test_key], 'MATRIXNET_TESTING_PREDICTIONS')
            path = os.path.join(results_dir, '%s_predictions.tsv' % test_key)
            shutil.copy(self.sync_resource(resource_id), path)
            paths.append(path)
        return paths

    @staticmethod
    def _write_mean_predictions(prediction_paths, out_path):
        """Merge prediction files row by row into ``out_path``.

        Each output row consists of the leading columns of the last input file
        that still had a row, then the mean and the sample standard deviation
        of the last column over all files that still had a row.
        """
        # Imported here rather than at module level, as in the original code.
        import numpy

        predictions_files = []
        incomplete_rows = 0
        try:
            for path in prediction_paths:
                predictions_files.append(open(path, 'r'))
            with open(out_path, 'w') as out:
                while True:
                    predictions = []
                    leading_columns = None
                    for pred_f in predictions_files:
                        line = pred_f.readline().strip()
                        if line:
                            sp = line.split('\t')
                            predictions.append(float(sp[-1]))
                            leading_columns = sp[:-1]
                    # No file produced a row: all inputs are exhausted.
                    if not predictions:
                        break
                    if len(predictions) != len(predictions_files):
                        incomplete_rows += 1

                    # NOTE: with a single prediction numpy.std(ddof=1) is nan.
                    out.write('\t'.join(leading_columns) + '\t' + str(numpy.mean(predictions)) + '\t' + str(numpy.std(predictions, ddof=1)) + '\n')
        finally:
            for f in predictions_files:
                f.close()

        if incomplete_rows:
            # Output is unchanged, but rows averaged over a subset of the runs
            # usually mean the inputs have different lengths - make it visible.
            logging.warning('%d rows were averaged over fewer than %d prediction files',
                            incomplete_rows, len(prediction_paths))

    def on_execute(self):
        """Run the sub-tasks, average their predictions and publish the result."""
        logging.info('on_execute started...')

        tests_keys, wait_tests = self._start_missing_subtasks()

        # NOTE(review): presumably this suspends the task and on_execute is
        # entered again once the sub-tasks finish - confirm against the SDK.
        if wait_tests:
            self.wait_all_tasks_completed(wait_tests)

        prediction_paths = self._copy_subtask_predictions(tests_keys, 'results')
        self._write_mean_predictions(prediction_paths, 'test.tsv.matrixnet')

        matrixnet_predictions = self.create_resource('mean .test values', 'test.tsv.matrixnet',
                                                     resource_types.MATRIXNET_TESTING_PREDICTIONS)
        self.ctx['matrixnet_predictions_resource_id'] = matrixnet_predictions.id


# Module-level alias of the task class defined in this file.
# NOTE(review): presumably the name the Sandbox task loader looks up - confirm.
__Task__ = CalculateMatrixnetPredictionsMean
