import logging

from sandbox.projects.factorization.tools.params import add_common_params_to_command
from sandbox.projects.factorization.tools.params import AlsCommonParams
from sandbox.projects.factorization.tools.params import FactorizationAlsColumnDescriptor
from sandbox.projects.factorization.tools.params import FactorizationAlsData
from sandbox.projects.factorization.tools.params import FactorizationAlsFeatureNames
from sandbox.projects.factorization.tools.params import FactorizationAlsModel

from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp


class FactorizationAlsLocalExecutable(sdk2.Resource):
    # Resource type for the ALS local binary: FactorizationAlsLocal in this
    # file takes a resource of this type and runs the file at its path as
    # the main command.

    # Logins listed as releasers of this resource type.
    releasers = ['sokursky', 'roizner']

    # Flags read by the sdk2 resource framework (exact semantics are
    # defined by sdk2.Resource, not by this file).
    releasable = True
    executable = True


class FactorizationAlsLocal(sdk2.Task):
    ''' Runner tool for ALS local factorization

    Builds a command line for the ALS local executable from the task
    parameters, runs it in the current working directory and packs the
    resulting ``./model``, ``./lkeys`` and ``./rkeys`` into a gzipped tar
    published as a FactorizationAlsModel resource. Depending on the
    parameters it also publishes test predictions and feature names.
    '''

    class Parameters(AlsCommonParams):
        # Binary to run; its resource path becomes the first command element.
        als_local_executable = sdk2.parameters.Resource(
            'ALS local executable',
            resource_type=FactorizationAlsLocalExecutable,
            required=True,
        )

        # Optional inputs: each one only adds its flags to the command line
        # when it is set.
        train_data = sdk2.parameters.Resource('Train data', resource_type=FactorizationAlsData)
        test_data = sdk2.parameters.Resource('Test data', resource_type=FactorizationAlsData)
        input_model = sdk2.parameters.Resource('Input model', resource_type=FactorizationAlsModel)
        # Passed as ``--cd`` together with the train data.
        column_descriptor = sdk2.parameters.Resource('Column descriptor', resource_type=FactorizationAlsColumnDescriptor)

    def _run_logged(self, command, logger_name):
        ''' Log ``command`` and run it with output captured in a process log.

        :param command: list of string arguments, executable first.
        :param logger_name: name given to the ProcessLog that receives the
            child's stdout and stderr.

        Any error raised by ``check_call`` (e.g. on a non-zero exit status)
        propagates to the caller.
        '''
        logging.info('Run command: %s', ' '.join(command))
        with sdk2.helpers.ProcessLog(self, logger=logger_name) as process_log:
            sp.check_call(command, stdout=process_log.stdout, stderr=process_log.stderr)

    def on_execute(self):
        ''' Unpack the input model (if any), run ALS and publish the outputs.

        :raises ValueError: if train data is given without a column
            descriptor.
        '''
        params = self.Parameters

        # Training always passes ``--cd``, so a missing descriptor can never
        # yield a valid command line; fail early with an explicit message.
        if params.train_data and not params.column_descriptor:
            raise ValueError('Column descriptor is required when train data is provided')

        if params.input_model:
            # The input model is a tarball; it is extracted into the working
            # directory, where ./model, ./lkeys and ./rkeys are later read
            # via the --mi/--lki/--rki flags.
            input_model_path = str(sdk2.ResourceData(params.input_model).path)
            self._run_logged(['tar', '-xzvf', input_model_path], 'unpack_executor')

        # Handles of output resources, in the order they must be marked ready.
        outputs = []

        model_resource = FactorizationAlsModel(self, params.description + ' learned model', 'model_out')
        model_resource_data = sdk2.ResourceData(model_resource)
        outputs.append(model_resource_data)

        command = [
            str(sdk2.ResourceData(params.als_local_executable).path),
            'general',
            '--mo', './model',
            '--lko', './lkeys',
            '--rko', './rkeys',
        ]

        if params.train_data:
            command += [
                '-f', str(sdk2.ResourceData(params.train_data).path),
                '--cd', str(sdk2.ResourceData(params.column_descriptor).path),
            ]

        if params.test_data:
            test_resource = FactorizationAlsData(self, params.description + ' predictions', 'test_out')
            test_resource_data = sdk2.ResourceData(test_resource)
            outputs.append(test_resource_data)

            command += [
                '-t', str(sdk2.ResourceData(params.test_data).path),
                '-o', str(test_resource_data.path),
            ]

        if params.input_model:
            # Same relative paths as the outputs above: the files extracted
            # from the input tarball are read from, and the new model is
            # written to, the same locations.
            command += [
                '--mi', './model',
                '--lki', './lkeys',
                '--rki', './rkeys',
            ]

        # ``features`` / ``feature_prefix`` are not declared in this class;
        # presumably they come from AlsCommonParams -- verify there.
        if params.features:
            fnames_resource = FactorizationAlsFeatureNames(self, params.description + ' feature names', 'feature_names.tsv')
            fnames_resource_data = sdk2.ResourceData(fnames_resource)
            outputs.append(fnames_resource_data)

            command += ['--features', '--fnames', str(fnames_resource_data.path)]
            if params.feature_prefix:
                command += ['--fprefix', params.feature_prefix]

        command = add_common_params_to_command(self, command)
        self._run_logged(command, 'als_local_executor')

        # Pack the three model files into the model resource's path.
        pack_command = ['tar', '-czvf', str(model_resource_data.path), './model', './lkeys', './rkeys']
        self._run_logged(pack_command, 'pack_executor')

        # Mark outputs ready only after every command has succeeded.
        for resource_data in outputs:
            resource_data.ready()
