import logging
import os

from sandbox.projects.factorization.tools.params import add_common_params_to_command
from sandbox.projects.factorization.tools.params import AlsCommonParams
from sandbox.projects.factorization.tools.params import FactorizationAlsFeatureNames

from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp


class FactorizationAlsYtExecutable(sdk2.Resource):
    """Sandbox resource type for the ALS YT factorization executable.

    Marked executable and releasable; only the listed users may release it.
    """

    releasers = [
        'sokursky',
        'roizner',
    ]
    releasable = True
    executable = True


class FactorizationAlsYt(sdk2.Task):
    ''' Runner tool for ALS yt factorization '''

    class Parameters(AlsCommonParams):
        # Binary launched in on_execute; its local path becomes argv[0].
        als_yt_executable = sdk2.parameters.Resource(
            'ALS YT executable',
            resource_type=FactorizationAlsYtExecutable,
        )
        yt_proxy = sdk2.parameters.String('YT proxy', default='hahn.yt.yandex.net')
        # Name of the vault secret whose contents are exported as YT_TOKEN.
        yt_token_secret_name = sdk2.parameters.String('YT token vault secret', required=True)

        # Optional data tables (passed as -f / -t / -o when set).
        train_data = sdk2.parameters.String('Train data table')
        test_data = sdk2.parameters.String('Test data table')
        test_out = sdk2.parameters.String('Test output table')

        # Optional input model/keys tables (suffix "i" = input).
        modeli = sdk2.parameters.String('Input model table')
        lkeysi = sdk2.parameters.String('Input left keys table')
        rkeysi = sdk2.parameters.String('Input right keys table')

        # Mandatory output model/keys tables (suffix "o" = output).
        modelo = sdk2.parameters.String('Output model table', required=True)
        lkeyso = sdk2.parameters.String('Output left keys table', required=True)
        rkeyso = sdk2.parameters.String('Output right keys table', required=True)

        # Optional "freeze" keys tables (suffix "f").
        lkeysf = sdk2.parameters.String('Freeze left keys table')
        rkeysf = sdk2.parameters.String('Freeze right keys table')

        # Optional tuning knobs; only forwarded when non-zero.
        max_data_size_per_job = sdk2.parameters.Integer("Max data size per job")
        skip_key_size_threshold = sdk2.parameters.Integer("Skip key size threshold")

    def on_execute(self):
        """Run the ALS YT executable in 'general' mode with the task parameters.

        Exports MR_RUNTIME=YT and YT_TOKEN (read from the vault secret) into
        this process's environment, builds the command line, and runs it with
        stdout/stderr captured in the 'als_yt_executor' process log. If feature
        output is enabled, the feature-names resource is marked ready after the
        binary exits successfully.

        Raises:
            RuntimeError: the vault secret returned no data.
            subprocess.CalledProcessError: the executable exited non-zero.
        """
        secret_name = self.Parameters.yt_token_secret_name
        yt_token = sdk2.Vault.data(secret_name)
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if yt_token is None:
            raise RuntimeError('YT token vault secret {!r} returned no data'.format(secret_name))
        os.environ['MR_RUNTIME'] = 'YT'
        os.environ['YT_TOKEN'] = yt_token

        command, fnames_resource_data = self._build_command()

        # The token travels via the environment, so logging argv does not leak it.
        logging.info('Run command: {}'.format(' '.join(command)))
        with sdk2.helpers.ProcessLog(self, logger='als_yt_executor') as process_log:
            sp.check_call(command, stdout=process_log.stdout, stderr=process_log.stderr)

        # Only publish the feature-names file once the binary has succeeded.
        if fnames_resource_data is not None:
            fnames_resource_data.ready()

    def _build_command(self):
        """Assemble the argv list for the ALS YT executable.

        Returns:
            A ``(command, fnames_resource_data)`` pair. ``command`` is the list
            of arguments. ``fnames_resource_data`` is the ResourceData of the
            feature-names resource created when ``features`` is set, or None
            otherwise.
        """
        params = self.Parameters

        command = [
            str(sdk2.ResourceData(params.als_yt_executable).path),
            'general',
            '--server', params.yt_proxy,
            '--mo', params.modelo,
            '--lko', params.lkeyso,
            '--rko', params.rkeyso,
        ]

        # (parameter name, CLI flag) for optional table arguments; each pair is
        # appended only when the parameter is set, in this fixed order.
        optional_table_flags = (
            ('train_data', '-f'),
            ('test_data', '-t'),
            ('test_out', '-o'),
            ('modeli', '--mi'),
            ('lkeysi', '--lki'),
            ('rkeysi', '--rki'),
            ('lkeysf', '--lkf'),
            ('rkeysf', '--rkf'),
        )
        for name, flag in optional_table_flags:
            value = getattr(params, name)
            if value:
                command += [flag, value]

        fnames_resource_data = None
        if params.features:
            description = (params.description or '') + ' feature names'
            fnames_resource = FactorizationAlsFeatureNames(self, description, 'feature_names.tsv')
            fnames_resource_data = sdk2.ResourceData(fnames_resource)

            command += ['--features', '--fnames', str(fnames_resource_data.path)]
            if params.feature_prefix:
                command += ['--fprefix', params.feature_prefix]

        command = add_common_params_to_command(self, command)

        # Integer knobs go after the common params, as before.
        optional_int_flags = (
            ('max_data_size_per_job', '--max-data-size-per-job'),
            ('skip_key_size_threshold', '--skip-key-size-threshold'),
        )
        for name, flag in optional_int_flags:
            value = getattr(params, name)
            if value:
                command += [flag, str(value)]

        return command, fnames_resource_data
