from sandbox import sdk2
from sandbox.sdk2.parameters import LastReleasedResource, String, RadioGroup, Bool, Integer
import sandbox.common.types.resource as ctr

from sandbox.projects.autobudget.ml import AutobudgetMlBinary, AutobudgetMlTask


class AutobudgetMlPredictBinary(AutobudgetMlBinary):
    """Resource type for the autobudget ML 'predict' binary.

    Adds nothing on top of ``AutobudgetMlBinary``; it exists as a distinct
    type so the predict task can select its binary via ``resource_type``.
    """


class AutobudgetMlPredict(AutobudgetMlTask):
    """Task that launches the autobudget ML 'predict' binary.

    The task parameters are converted into command-line arguments and handed
    to ``run_binary_with_args`` (not defined in this file; presumably
    inherited from ``AutobudgetMlTask``).
    """

    class Parameters(sdk2.Parameters):
        # Binary to run: a READY resource of type AutobudgetMlPredictBinary.
        binary_id = LastReleasedResource(
            "autobudget ml: 'predict' binary",
            resource_type=AutobudgetMlPredictBinary,
            state=(ctr.State.READY,),
            required=True,
        )

        # Input / output tables.
        src_features_table = String("Source table with features", required=True)
        dst_predictions_table = String("Destination table with predictions", required=True)

        # APC model: name is mandatory, type is one of 'vw' (default) / 'mx'.
        apc_model_name = String("APC model name", required=True)
        with RadioGroup("APC model type") as apc_model_type:
            apc_model_type.values["vw"] = apc_model_type.Value("vw", default=True)
            apc_model_type.values["mx"] = apc_model_type.Value("mx")

        # PPC model: name is not marked required, type is 'vw' (default) / 'mx'.
        ppc_model_name = String("PPC model name")
        with RadioGroup("PPC model type") as ppc_model_type:
            ppc_model_type.values["vw"] = ppc_model_type.Value("vw", default=True)
            ppc_model_type.values["mx"] = ppc_model_type.Value("mx")

        # Prediction log options; the directory and TTL are only forwarded
        # to the binary when keep_prediction_log is set.
        keep_prediction_log = Bool("Keep prediction log")
        prediction_log_dir = String("Prediction log directory")
        prediction_log_ttl_hours = Integer("Prediction log TTL in hours", default=24 * 7)

        yt_cluster = String("YT cluster", default="hahn")

    def on_execute(self):
        """Build the argument list from the parameters and run the binary."""
        params = self.Parameters

        # Arguments that are always passed, as (flag, value) pairs.
        # NOTE(review): ppc_model_name is optional but is forwarded
        # unconditionally -- confirm the binary accepts an empty value.
        cli_args = [
            ("--src-features-table", params.src_features_table),
            ("--dst-predictions-table", params.dst_predictions_table),
            ("--apc-model-name", params.apc_model_name),
            ("--apc-model-type", params.apc_model_type),
            ("--ppc-model-name", params.ppc_model_name),
            ("--ppc-model-type", params.ppc_model_type),
            ("--yt-cluster", params.yt_cluster),
        ]

        if params.keep_prediction_log:
            # NOTE(review): the first entry is a bare flag string while all
            # other entries are tuples -- confirm run_binary_with_args
            # accepts both shapes.
            cli_args.extend([
                "--keep-prediction-log",
                ("--prediction-log-dir", params.prediction_log_dir),
                ("--prediction-log-ttl-hours", params.prediction_log_ttl_hours),
            ])

        self.run_binary_with_args(cli_args)
