# -*- coding: utf-8 -*-

import logging
import shutil

from sandbox.projects import resource_types
from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxIntegerParameter, SandboxStringParameter


class MyMatrixnetExecutableParameter(ResourceSelector):
    """Mandatory selector of the MATRIXNET_EXECUTABLE resource (the matrixnet binary).

    The chosen resource id is stored under ctx['matrixnet_resource_id'].
    """

    required = True
    resource_type = 'MATRIXNET_EXECUTABLE'
    name = 'matrixnet_resource_id'
    description = 'matrixnet binary'


class MyMatrixnetTestPoolParameter(ResourceSelector):
    """Mandatory selector of the MATRIXNET_TEST_POOL resource (packed test pool).

    The chosen resource id is stored under ctx['test_pool_resource_id'].
    """

    required = True
    resource_type = 'MATRIXNET_TEST_POOL'
    name = 'test_pool_resource_id'
    description = 'matrixnet test pool'


class MyMatrixnetSlavesCount(SandboxIntegerParameter):
    """Mandatory integer: how many local matrixnet slave processes to start.

    Stored under ctx['train_matrixnet_slaves_count']; defaults to 1.
    A value of 0 means no slaves are started.
    """
    name = 'train_matrixnet_slaves_count'
    # Fixed typo in the user-facing description ("numbert" -> "number").
    description = 'number of slaves to run'
    default_value = 1
    required = True


class MatrixnetMasterCLIParams(SandboxStringParameter):
    """Free-form string stored under ctx['master_additional_params'].

    Per its description, it holds extra arguments for the master command line.
    """

    description = 'Params to append to master cli command'
    name = 'master_additional_params'


class MatrixnetSlaveCLIParams(SandboxStringParameter):
    """Free-form string stored under ctx['slave_additional_params'].

    Per its description, it holds extra arguments for the slave command line.
    """

    description = 'Params to append to slave cli command'
    name = 'slave_additional_params'


def init_config(f):
    """Parse tab-separated ``key<TAB>value`` lines into a dict.

    Trailing whitespace is stripped from each line before parsing. Each line is
    then split at its first tab: the part before the tab is the key, and the
    whitespace-stripped part after it is the value.

    A line with no tab is accepted only when it starts with ``'learn command'``.
    In that case the key is ``'learn command'`` and the value is the rest of
    the line after one separator character, whitespace-stripped. Every other
    tab-less line is skipped. If a key repeats, the later line wins.

    :param f: iterable of text lines (e.g. an open file); it is not closed here.
    :return: dict mapping keys to values.
    """
    prefix = 'learn command'
    parsed = {}
    for raw_line in f:
        text = raw_line.rstrip()
        key, separator, value = text.partition('\t')
        if not separator:
            # Only the learn command may use a non-tab separator after its key.
            if not text.startswith(prefix):
                continue
            key, value = prefix, text[len(prefix) + 1:]
        parsed[key] = value.strip()
    return parsed


class CalculateMatrixnetModelAndPredictions(SandboxTask):
    """Train matrixnet on a test pool and publish the resulting predictions file.

    Steps performed by on_execute:
      1. Sync the matrixnet binary to ./matrixnet and unpack the pool archive.
      2. Read the learn command from ./config.tsv.
      3. Optionally start local slaves (see run_slaves) and add ' -M -r tsc'
         to the learn command.
      4. Run the learn command, then ``./matrixnet -A -f learn.tsv -t test.tsv``.
      5. Register ``test.tsv.matrixnet`` as a MATRIXNET_TESTING_PREDICTIONS
         resource and store its id in ctx['matrixnet_predictions_resource_id'].
    """
    type = 'CALCULATE_MATRIXNET_MODEL_AND_PREDICTIONS'
    # Port of the first local slave; slave number i listens on start_port + i.
    start_port = 11215

    input_parameters = (
        MyMatrixnetExecutableParameter,
        MyMatrixnetTestPoolParameter,
        MyMatrixnetSlavesCount,
        # Previously declared in this module but never wired in; empty or unset
        # values add nothing to the command lines.
        MatrixnetMasterCLIParams,
        MatrixnetSlaveCLIParams,
    )

    def _additional_params(self, ctx_key):
        """Return ctx[ctx_key] as a ' '-prefixed command-line suffix, or '' if unset/blank."""
        value = (self.ctx.get(ctx_key) or '').strip()
        return ' ' + value if value else ''

    def run_slaves(self):
        """Start ctx['train_matrixnet_slaves_count'] matrixnet slaves in the background.

        Each slave listens on its own port, starting at ``start_port``. Each one
        is listed as ``localhost:<port>`` in ./hosts.txt. The process handles
        are kept in ``self.slaves_processes`` so that _stop_slaves can stop them.
        """
        slave_params = self._additional_params('slave_additional_params')
        self.slaves_processes = []
        with open('hosts.txt', 'w') as hosts:
            for offset in xrange(self.ctx['train_matrixnet_slaves_count']):
                port = self.start_port + offset
                command = './matrixnet -N  -r tsc -p {}'.format(port) + slave_params
                slave = run_process(command, log_prefix='learn', wait=False)
                self.slaves_processes.append(slave)
                hosts.write('localhost:{}\n'.format(port))

    def _stop_slaves(self):
        """Kill and reap every slave started by run_slaves; no-op if none were started."""
        # NOTE(review): assumes run_process(wait=False) returns a subprocess.Popen-like
        # object (poll/kill/wait) — confirm against sandboxsdk.
        for slave in getattr(self, 'slaves_processes', []):
            if slave.poll() is None:
                try:
                    slave.kill()
                except OSError:
                    # The slave exited between poll() and kill(); nothing left to kill.
                    pass
            slave.wait()

    def on_execute(self):
        """Sync inputs, train (optionally with local slaves) and publish predictions."""
        logging.info('on_execute started...')
        shutil.copy(self.sync_resource(self.ctx['matrixnet_resource_id']), 'matrixnet')
        pool_archive_path = self.sync_resource(self.ctx['test_pool_resource_id'])
        # Strip the archive's top-level directory: the files read below
        # (config.tsv, learn.tsv, test.tsv) are expected in the working directory.
        run_process('tar zxf %s --strip-components=1' % pool_archive_path, log_prefix='extract')
        # Close config.tsv deterministically instead of leaking the handle.
        with open('config.tsv', 'r') as config_file:
            config = init_config(config_file)
        learn_command = config['learn command']
        try:
            if self.ctx['train_matrixnet_slaves_count'] > 0:
                self.run_slaves()
                learn_command += ' -M -r tsc'
            learn_command += self._additional_params('master_additional_params')
            run_process(learn_command, log_prefix='learn')
        finally:
            # Slaves are only needed during learning; without this they would keep
            # running after the task, and also leak when learning fails.
            self._stop_slaves()
        run_process('./matrixnet -A -f learn.tsv -t test.tsv', log_prefix='test')
        matrixnet_predictions = self.create_resource(learn_command, 'test.tsv.matrixnet', resource_types.MATRIXNET_TESTING_PREDICTIONS)
        self.ctx['matrixnet_predictions_resource_id'] = matrixnet_predictions.id


# Module-level alias for this file's task class; presumably the Sandbox task
# loader looks up `__Task__` to find which class to run — confirm.
__Task__ = CalculateMatrixnetModelAndPredictions
