package ru.yandex.crypta.api.rest.resource.lab;

import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import javax.inject.Inject;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.MediaType;

import io.swagger.annotations.ApiImplicitParam;
import io.swagger.annotations.ApiImplicitParams;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import org.glassfish.jersey.media.multipart.FormDataParam;

import ru.yandex.crypta.common.exception.NotFoundException;
import ru.yandex.crypta.common.ws.jersey.JsonUtf8;
import ru.yandex.crypta.lab.LabService;
import ru.yandex.crypta.lab.proto.AccessLevel;
import ru.yandex.crypta.lab.proto.ETrainableSegmentOriginType;
import ru.yandex.crypta.lab.proto.ETrainableSegmentTargetType;
import ru.yandex.crypta.lab.proto.TTrainableSegment;
import ru.yandex.crypta.lab.proto.TTrainingSample;
import ru.yandex.crypta.lab.proto.TrainableSegments;
import ru.yandex.crypta.lib.python.custom_ml.proto.ETrainingSampleState;
import ru.yandex.crypta.lib.python.custom_ml.proto.TIndustry;
import ru.yandex.crypta.lib.python.custom_ml.proto.TMetrics;
import ru.yandex.crypta.lib.python.custom_ml.proto.TTrainingError;

/**
 * REST endpoints for working with training samples.
 *
 * <p>Every handler delegates to {@code lab().trainingSamples()}; this class only binds HTTP
 * parameters and, where needed, assembles the protobuf message handed to the service. Request and
 * response bodies use {@link JsonUtf8#MEDIA_TYPE} unless a method overrides it.
 *
 * <p>No class-level {@code @Path} is declared here, so the mount point is presumably supplied by a
 * parent resource locator — confirm against the caller.
 *
 * <p>Query parameters that would otherwise fail with a {@link NullPointerException} inside this
 * class (null unboxing, protobuf string setters, {@code Enum.valueOf}) are marked {@link NotNull},
 * mirroring the path parameters. NOTE(review): this only rejects requests up front if bean
 * validation is enabled for this application — confirm.
 */
@Produces(JsonUtf8.MEDIA_TYPE)
@Consumes(JsonUtf8.MEDIA_TYPE)
public class TrainingSampleResource extends CommonLabResource {

    /**
     * Creates the resource on top of the injected lab service.
     *
     * @param lab service facade passed to the {@link CommonLabResource} constructor
     */
    @Inject
    public TrainingSampleResource(LabService lab) {
        super(lab);
    }

    /**
     * Lists training samples.
     *
     * @return the result of {@code trainingSamples().getAll()}
     */
    @GET
    @ApiOperation(value = "Retrieve user training samples")
    public List<TTrainingSample> getTrainingSamples() {
        return lab().trainingSamples().getAll();
    }

    /**
     * Fetches a single training sample.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().getSample(id)}
     */
    @GET
    @Path("{id}")
    @ApiOperation(value = "Retrieve training sample")
    public TTrainingSample getTrainingSample(@PathParam("id") @NotNull String id) {
        return lab().trainingSamples().getSample(id);
    }

    /**
     * Creates a training sample from a table path or an audience segment.
     *
     * <p>All parameters are forwarded to {@code trainingSamples().createSample(...)} unchanged; the
     * uploaded-file argument (second position, see {@link #createTrainingSampleFromFile}) is passed
     * as {@code null}.
     *
     * @return the sample returned by the service
     */
    @POST
    @Path("create_sample")
    @ApiOperation(value = "Create training sample from table or audience")
    public TTrainingSample createTrainingSampleExtended(
            @ApiParam("Path to the sample, e.g. //tmp/sample") @QueryParam("path") String path,
            @ApiParam("Audience id with training sample") @QueryParam("audienceSegment") String audienceSegment,
            @ApiParam("Name of the sample") @QueryParam("name") String name,
            @ApiParam("Access level that controls scope of the sample") @QueryParam("accessLevel") AccessLevel accessLevel,
            @ApiParam("TTL of tables (in seconds)") @QueryParam("ttl") Long ttl,
            @ApiParam("Number of users in positive class segment") @QueryParam("positiveSegmentSize") Long positiveSegmentSize,
            @ApiParam("Number of users in negative class segment") @QueryParam("negativeSegmentSize") Long negativeSegmentSize,
            @ApiParam("Industry model name") @QueryParam("modelName") String modelName,
            @ApiParam("The partner who provided the data") @QueryParam("partner") String partner,
            @ApiParam("Logins to add access to segments to") @QueryParam("loginsToShare") String loginsToShare
    ) {
        return lab().trainingSamples().createSample(
                path, null, audienceSegment, name, accessLevel, ttl, positiveSegmentSize, negativeSegmentSize,
                modelName, partner, loginsToShare
        );
    }

    /**
     * Creates a training sample from an uploaded multipart file.
     *
     * <p>Forwards to the same {@code trainingSamples().createSample(...)} call as
     * {@link #createTrainingSampleExtended}, with the table path and audience segment passed as
     * {@code null} and the uploaded stream supplied instead.
     *
     * @param uploadedInputStream content of the {@code file} form part
     * @return the sample returned by the service
     */
    @POST
    @Path("create_sample_from_file")
    @Consumes(MediaType.MULTIPART_FORM_DATA)
    @ApiOperation(value = "Create training sample from file")
    @ApiImplicitParams({@ApiImplicitParam(name = "file", dataType = "java.io.File", paramType = "form")})
    public TTrainingSample createTrainingSampleFromFile(
            @ApiParam(hidden = true) @FormDataParam("file") InputStream uploadedInputStream,
            @ApiParam("Name of the sample") @QueryParam("name") String name,
            @ApiParam("Access level that controls scope of the sample") @QueryParam("accessLevel") AccessLevel accessLevel,
            @ApiParam("TTL of tables (in seconds)") @QueryParam("ttl") Long ttl,
            @ApiParam("Number of users in positive class segment") @QueryParam("positiveSegmentSize") Long positiveSegmentSize,
            @ApiParam("Number of users in negative class segment") @QueryParam("negativeSegmentSize") Long negativeSegmentSize,
            @ApiParam("Industry model name") @QueryParam("modelName") String modelName,
            @ApiParam("The partner who provided the data") @QueryParam("partner") String partner,
            @ApiParam("Logins to add access to segments to") @QueryParam("loginsToShare") String loginsToShare
    ) {
        return lab().trainingSamples().createSample(
                null, uploadedInputStream, null, name, accessLevel, ttl, positiveSegmentSize, negativeSegmentSize,
                modelName, partner, loginsToShare
        );
    }

    /**
     * Deletes a training sample.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().deleteSample(id)}
     */
    @DELETE
    @Path("{id}")
    @ApiOperation(value = "Delete training sample")
    public TTrainingSample deleteTrainingSample(@PathParam("id") @NotNull String id) {
        return lab().trainingSamples().deleteSample(id);
    }

    /**
     * Deletes outdated training samples.
     *
     * @return the result of {@code trainingSamples().deleteOutdatedSamples()}
     */
    @DELETE
    @Path("delete_outdated_training_samples")
    @ApiOperation(value = "Delete outdated training samples")
    public List<TTrainingSample> deleteOutdatedTrainingSample() {
        return lab().trainingSamples().deleteOutdatedSamples();
    }

    /**
     * Fetches model metrics for a training sample.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().getMetrics(id)}
     * @throws NotFoundException propagated from the service; presumably raised when no metrics are
     *         stored for the sample — confirm against the service implementation
     */
    @GET
    @Path("{id}/metrics")
    @ApiOperation(value = "Retrieve model metrics for a given training sample")
    public TMetrics getTrainingSampleMetrics(@PathParam("id") @NotNull String id) throws NotFoundException {
        return lab().trainingSamples().getMetrics(id);
    }

    /**
     * Fetches model metrics for a training sample via the service's YT-backed lookup.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().getSampleMetricsFromYt(id)}
     */
    @GET
    @Path("{id}/metrics_yt")
    @ApiOperation(value = "Retrieve model metrics for a given training sample from yt")
    public TMetrics getTrainingSampleMetricsFromYt(@PathParam("id") @NotNull String id) {
        return lab().trainingSamples().getSampleMetricsFromYt(id);
    }

    /**
     * Fetches detailed descriptions of the trainable segments of a training sample.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().getSegmentsDetails(id)}
     */
    @GET
    @Path("{id}/segments")
    @ApiOperation(value = "Retrieve trainable segments of a training sample")
    public List<TTrainableSegment> getTrainableSegments(@PathParam("id") @NotNull String id) {
        return lab().trainingSamples().getSegmentsDetails(id);
    }

    /**
     * Fetches the trainable segments record of a training sample.
     *
     * @param id training sample id taken from the path
     * @return the result of {@code trainingSamples().getSegments(id)}
     */
    @GET
    @Path("{id}/segments_ids")
    @ApiOperation(value = "Retrieve trainable segment ids of a training sample")
    public TrainableSegments getTrainableSegmentsIds(@PathParam("id") @NotNull String id) {
        return lab().trainingSamples().getSegments(id);
    }

    /**
     * Builds a {@link TMetrics} message from the query parameters and stores it through the service.
     *
     * @param sampleId training sample id
     * @param rocAuc ROC-AUC value
     * @param accuracy accuracy value
     * @param positiveClassRatio positive class ratio
     * @param negativeClassRatio negative class ratio; when omitted it is derived as
     *        {@code 1 - positiveClassRatio}
     * @param trainSampleSize training sample size
     * @param matchedIdsRatio matched ids ratio
     * @param topFeatures comma-separated feature names; whitespace around names is ignored and a
     *        missing or blank value yields no features
     */
    @POST
    @Path("write_metrics")
    @ApiOperation(value = "Write training sample metrics")
    public void writeTrainingSampleMetrics(
            @ApiParam("Training Sample Id") @QueryParam("sampleId") @NotNull String sampleId,
            @ApiParam("ROC-AUC") @QueryParam("rocAuc") @NotNull Double rocAuc,
            @ApiParam("Accuracy") @QueryParam("accuracy") @NotNull Double accuracy,
            @ApiParam("Positive class ratio in training sample") @QueryParam("positiveClassRatio") @NotNull Double positiveClassRatio,
            @ApiParam("Negative class ratio in training sample") @QueryParam("negativeClassRatio") Double negativeClassRatio,
            @ApiParam("Training sample size") @QueryParam("trainSampleSize") @NotNull Integer trainSampleSize,
            @ApiParam("Matched ids ratio") @QueryParam("matchedIdsRatio") @NotNull Double matchedIdsRatio,
            @ApiParam("Top features separated by ,") @QueryParam("topFeatures") String topFeatures
    ) {
        // Honor the caller-supplied ratio; earlier the parameter was accepted but silently ignored.
        // The complement is kept as a fallback so callers that never sent it see no change.
        double negativeRatio = negativeClassRatio != null ? negativeClassRatio : 1 - positiveClassRatio;

        TMetrics metrics = TMetrics.newBuilder()
                .setSampleId(sampleId)
                .setRocAuc(rocAuc)
                .setAccuracy(accuracy)
                .setPositiveClassRatio(positiveClassRatio)
                .setNegativeClassRatio(negativeRatio)
                .setTrainSampleSize(trainSampleSize)
                .setMatchedIdsRatio(matchedIdsRatio)
                .addAllTopFeatures(parseTopFeatures(topFeatures))
                .build();

        lab().trainingSamples().writeTrainingSampleMetrics(metrics);
    }

    /**
     * Builds a {@link TTrainingError} message from the query parameters and stores it through the
     * service.
     *
     * @param sampleId training sample id
     * @param message error message
     */
    @POST
    @Path("write_error")
    @ApiOperation(value = "Write training error")
    public void writeTrainingSampleError(
            @ApiParam("Training Sample Id") @QueryParam("sampleId") @NotNull String sampleId,
            @ApiParam("Error message") @QueryParam("message") @NotNull String message
    ) {
        TTrainingError trainingError = TTrainingError.newBuilder()
                .setSampleId(sampleId)
                .setMessage(message)
                .build();

        lab().trainingSamples().writeTrainingSampleError(trainingError);
    }

    /**
     * Lists industry descriptions.
     *
     * @return the result of {@code trainingSamples().getAllIndustries()}
     */
    @GET
    @Path("industries")
    @ApiOperation(value = "Retrieve industries descriptions")
    public List<TIndustry> getTrainingSamplesIndustries() {
        return lab().trainingSamples().getAllIndustries();
    }

    /**
     * Builds a {@link TIndustry} message from the query parameters and registers it through the
     * service.
     *
     * @param name industry name
     * @param modelName model name (identifier)
     * @param objective optimization objective description
     * @param positiveConversions positive conversions description
     * @param negativeConversions negative conversions description
     */
    @POST
    @Path("add_industry")
    @ApiOperation(value = "Add industry description")
    public void addNewIndustry(
            @ApiParam("Industry name") @QueryParam("name") @NotNull String name,
            @ApiParam("Model name (identifier)") @QueryParam("modelName") @NotNull String modelName,
            @ApiParam("Optimization objective description") @QueryParam("objective") @NotNull String objective,
            @ApiParam("Positive conversions description") @QueryParam("positiveConversions") @NotNull String positiveConversions,
            @ApiParam("Negative conversions description") @QueryParam("negativeConversions") @NotNull String negativeConversions
    ) {
        TIndustry industry = TIndustry.newBuilder()
                .setName(name)
                .setModelName(modelName)
                .setObjective(objective)
                .setPositiveConversions(positiveConversions)
                .setNegativeConversions(negativeConversions)
                .build();

        lab().trainingSamples().addIndustry(industry);
    }

    /**
     * Registers a trainable segment description with an explicit user set id.
     *
     * @param sampleId training sample id
     * @param targetType name of an {@link ETrainableSegmentTargetType} constant
     * @param originType name of an {@link ETrainableSegmentOriginType} constant
     * @param rowsCount segment ids count
     * @param userSetId user set id
     * @return the result of {@code trainingSamples().addSegmentDescription(...)}
     * @throws IllegalArgumentException if {@code targetType} or {@code originType} is not the name
     *         of a constant of the corresponding enum
     */
    @POST
    @Path("add_trainable_segment_user_set_id")
    @ApiOperation(value = "Add trainable segment user_set_id")
    public TTrainableSegment addSegmentDescription(
            @ApiParam("Training sample id") @QueryParam("sampleId") String sampleId,
            @ApiParam(value="Target type (positive/negative)", allowableValues="POSITIVE, NEGATIVE") @QueryParam("targetType") @NotNull String targetType,
            @ApiParam(value="Origin type (initial/modeled)", allowableValues="INITIAL, MODELED") @QueryParam("originType") @NotNull String originType,
            @ApiParam("Segment ids count") @QueryParam("rowsCount") Integer rowsCount,
            @ApiParam("UserSetId from Siberia") @QueryParam("userSetId") String userSetId
    ) {
        return lab().trainingSamples().addSegmentDescription(sampleId, ETrainableSegmentTargetType.valueOf(targetType),
                ETrainableSegmentOriginType.valueOf(originType), rowsCount, userSetId);
    }

    /**
     * Registers a trainable segment description computed by the service from a table.
     *
     * @param sampleId training sample id
     * @param targetType name of an {@link ETrainableSegmentTargetType} constant
     * @param originType name of an {@link ETrainableSegmentOriginType} constant
     * @return the result of {@code trainingSamples().addSegmentDescriptionFromTable(...)}
     * @throws IllegalArgumentException if {@code targetType} or {@code originType} is not the name
     *         of a constant of the corresponding enum
     */
    @POST
    @Path("compute_trainable_segment_description")
    @ApiOperation(value = "Add trainable segment description from table")
    public TTrainableSegment addSegmentDescriptionFromTable(
            @ApiParam("Training sample id") @QueryParam("sampleId") String sampleId,
            @ApiParam(value="Target type (positive/negative)", allowableValues="POSITIVE, NEGATIVE") @QueryParam("targetType") @NotNull String targetType,
            @ApiParam(value="Origin type (initial/modeled)", allowableValues="INITIAL, MODELED") @QueryParam("originType") @NotNull String originType
    ) {
        return lab().trainingSamples().addSegmentDescriptionFromTable(sampleId,
                ETrainableSegmentTargetType.valueOf(targetType), ETrainableSegmentOriginType.valueOf(originType));
    }

    /**
     * Updates the state of a training sample.
     *
     * @param sampleId training sample id
     * @param status name of an {@link ETrainingSampleState} constant
     * @throws IllegalArgumentException if {@code status} is not the name of an
     *         {@link ETrainingSampleState} constant
     */
    @POST
    @Path("report_status")
    @ApiOperation(value = "Report training status")
    public void reportTrainingStatus(
            @ApiParam("Training sample id") @QueryParam("sampleId") String sampleId,
            @ApiParam(value="Training task status", allowableValues="TRAINING, READY") @QueryParam("status") @NotNull String status
    ) {
        lab().trainingSamples().updateSampleState(sampleId, ETrainingSampleState.valueOf(status));
    }

    /**
     * Splits a comma-separated feature list into trimmed, non-empty names.
     *
     * <p>Accepts both {@code "a,b"} and {@code "a, b"}; a {@code null} or blank input produces an
     * empty list instead of a {@link NullPointerException} or a single empty feature.
     */
    private static List<String> parseTopFeatures(String topFeatures) {
        String raw = topFeatures == null ? "" : topFeatures;
        return Arrays.stream(raw.split(","))
                .map(String::trim)
                .filter(feature -> !feature.isEmpty())
                .collect(Collectors.toList());
    }
}
