package ru.yandex.crypta.lab.job;

import java.time.Duration;

import javax.inject.Inject;

import org.quartz.JobExecutionContext;

import ru.yandex.crypta.clients.sandbox.SandboxClient;
import ru.yandex.crypta.clients.sandbox.entity.CustomField;
import ru.yandex.crypta.clients.sandbox.entity.Description;
import ru.yandex.crypta.clients.sandbox.entity.Owner;
import ru.yandex.crypta.clients.sandbox.entity.Priority;
import ru.yandex.crypta.clients.sandbox.entity.Tag;
import ru.yandex.crypta.clients.sandbox.entity.TaskDraft;
import ru.yandex.crypta.clients.sandbox.entity.Type;
import ru.yandex.crypta.common.ws.auth.ServiceSecurityContext;
import ru.yandex.crypta.lab.LabService;
import ru.yandex.crypta.lib.python.custom_ml.proto.ETrainingSampleState;

public class ModelTrainingJob extends SandboxJob {
    public static final String SAMPLE_ID = "sample_id";
    public static final String TABLE_PATH = "sample_table_path";
    public static final String FILE_PATH = "sample_file_path";
    public static final String AUDIENCE_ID = "audience_id";
    public static final String POSITIVE_SEGMENT_SIZE = "positive_segment_size";
    public static final String NEGATIVE_SEGMENT_SIZE = "negative_segment_size";
    public static final String TTL = "ttl";
    public static final String OUTPUT_DIR_PATH = "output_dir_path";
    public static final String MODEL_NAME = "model_name";
    public static final String PARTNER = "partner";
    public static final String LOGIN = "login";
    public static final String LOGINS_TO_SHARE = "logins_to_share";

    @Inject
    public ModelTrainingJob(SandboxClient sandbox) {
        super(sandbox);
    }

    @Inject
    private LabService lab;

    protected LabService lab() {
        return lab.withSecurityContext(new ServiceSecurityContext());
    }

    @Override
    protected int getInitialRetriesCount() {
        return 2;
    }

    @Override
    protected TaskDraft taskDraft(JobExecutionContext context) {
        String sampleId = getArgument(context, SAMPLE_ID);
        String rawSampleTablePath = getArgument(context, TABLE_PATH);
        String rawSampleFilePath = getArgument(context, FILE_PATH);
        String audienceId = getArgument(context, AUDIENCE_ID);
        String outputDirPath = getArgument(context, OUTPUT_DIR_PATH);
        String positiveSegmentSize = getArgument(context, POSITIVE_SEGMENT_SIZE);
        String negativeSegmentSize = getArgument(context, NEGATIVE_SEGMENT_SIZE);
        String modelName = getArgument(context, MODEL_NAME);
        String partner = getArgument(context, PARTNER);
        String login = getArgument(context, LOGIN);
        String loginsToShare = getArgument(context, LOGINS_TO_SHARE);
        String environment = getEnvironment(context);

        TaskDraft draft = createTaskDraft();
        draft.setType(new Type("CRYPTA_ML_TRAIN_CUSTOM_MODEL_TASK"));
        draft.setOwner(new Owner("CRYPTA"));
        draft.setPriority(new Priority("SERVICE", "NORMAL"));
        draft.setKillTimeout(Duration.ofHours(12).getSeconds());

        draft.addCustomField(new CustomField("environment", environment));
        draft.addCustomField(new CustomField(TABLE_PATH, rawSampleTablePath));
        if (rawSampleFilePath != null) {
            draft.addCustomField(new CustomField(FILE_PATH, rawSampleFilePath));
        }
        if (audienceId != null) {
            draft.addCustomField(new CustomField(AUDIENCE_ID, audienceId));
        }
        draft.addCustomField(new CustomField(POSITIVE_SEGMENT_SIZE,  positiveSegmentSize));
        draft.addCustomField(new CustomField(NEGATIVE_SEGMENT_SIZE,  negativeSegmentSize));
        draft.addCustomField(new CustomField(OUTPUT_DIR_PATH,  outputDirPath));
        if (modelName != null && partner != null) {
            draft.addCustomField(new CustomField(MODEL_NAME,  modelName));
            draft.addCustomField(new CustomField(PARTNER,  partner));
            draft.addCustomField(new CustomField("if_make_decision", true));
            draft.addCustomField(new CustomField(LOGIN, login));
            if (loginsToShare != null) {
                draft.addCustomField(new CustomField(LOGINS_TO_SHARE, loginsToShare));
            }
        }
        draft.addCustomField(new CustomField("send_results_to_api",  true));

        draft.addTag(new Tag(sampleId));
        draft.setDescription(new Description("Train custom model in lab"));

        return draft;
    }

    @Override
    protected void onFailure(JobExecutionContext context) {
        super.onFailure(context);
        lab().trainingSamples().updateSampleState(getArgument(context, SAMPLE_ID), ETrainingSampleState.ERROR);
    }
}
