package ru.yandex.crypta.lab.tables;

import java.util.Arrays;

import org.jooq.Configuration;
import org.jooq.Delete;
import org.jooq.Field;
import org.jooq.Insert;
import org.jooq.Record;
import org.jooq.Select;
import org.jooq.Table;
import org.jooq.impl.DSL;

import ru.yandex.crypta.common.data.GenericTable;
import ru.yandex.crypta.lib.python.custom_ml.proto.TMetrics;

/**
 * jOOQ-backed accessor for the {@code lab_training_sample_metrics} table, mapping its rows to and
 * from {@link TMetrics} messages. Rows are looked up and deleted by {@code sample_id}.
 *
 * <p>Storage notes:
 * <ul>
 *   <li>The repeated top-features field is persisted as a single string column whose elements are
 *       joined with {@link #DELIMITER}. NOTE(review): a feature name that itself contains the
 *       delimiter will not round-trip.</li>
 *   <li>The negative class ratio is not stored; it is derived on read as
 *       {@code 1 - positive_class_ratio}.</li>
 * </ul>
 */
public class MetricsTable extends GenericTable<TMetrics> {
    public static final Table<Record> TABLE = DSL.table("lab_training_sample_metrics");

    public static final Field<String> SAMPLE_ID = DSL.field(DSL.name(TABLE.getName(), "sample_id"), String.class);
    public static final Field<Double> ROC_AUC = DSL.field(DSL.name(TABLE.getName(), "roc_auc"), Double.class);
    public static final Field<Double> ACCURACY = DSL.field(DSL.name(TABLE.getName(), "accuracy"), Double.class);
    public static final Field<Double> POSITIVE_CLASS_RATIO = DSL.field(DSL.name(TABLE.getName(), "positive_class_ratio"), Double.class);
    public static final Field<Integer> TRAIN_SAMPLE_SIZE = DSL.field(DSL.name(TABLE.getName(), "train_sample_size"), Integer.class);
    public static final Field<Double> MATCHED_IDS_RATIO = DSL.field(DSL.name(TABLE.getName(), "matched_ids_ratio"), Double.class);
    public static final Field<String> TOP_FEATURES = DSL.field(DSL.name(TABLE.getName(), "top_features"), String.class);

    /** Separator used to join/split the top-features list stored in {@link #TOP_FEATURES}. */
    private static final String DELIMITER = ", ";

    /**
     * Creates a table accessor bound to the given jOOQ configuration.
     *
     * @param configuration jOOQ configuration passed through to {@link GenericTable}
     */
    public MetricsTable(Configuration configuration) {
        super(configuration, TMetrics.class);
    }

    /**
     * Converts one row of this table into a {@link TMetrics} message.
     *
     * <p>The negative class ratio is computed as {@code 1 - positive_class_ratio}. An empty or
     * {@code NULL} {@code top_features} value yields an empty top-features list.
     *
     * @param record a row selected from {@link #TABLE}
     * @return the metrics message built from the row
     */
    @Override
    protected TMetrics read(Record record) {
        // NOTE(review): numeric columns are unboxed here, so a NULL value would throw
        // NullPointerException -- confirm these columns are declared NOT NULL.
        double positiveClassRatio = record.get(POSITIVE_CLASS_RATIO);

        TMetrics.Builder builder = TMetrics.newBuilder()
                .setSampleId(record.get(SAMPLE_ID))
                .setRocAuc(record.get(ROC_AUC))
                .setAccuracy(record.get(ACCURACY))
                .setPositiveClassRatio(positiveClassRatio)
                .setNegativeClassRatio(1 - positiveClassRatio)
                .setTrainSampleSize(record.get(TRAIN_SAMPLE_SIZE))
                .setMatchedIdsRatio(record.get(MATCHED_IDS_RATIO));

        String topFeatures = record.get(TOP_FEATURES);
        // insertQuery stores an empty feature list as "", and "".split(...) returns [""], so
        // guard here to keep an empty list empty instead of producing one blank feature.
        if (topFeatures != null && !topFeatures.isEmpty()) {
            builder.addAllTopFeatures(Arrays.asList(topFeatures.split(DELIMITER)));
        }

        return builder.build();
    }

    /**
     * Builds a query selecting every row of this table.
     *
     * @return a {@code SELECT * FROM lab_training_sample_metrics} query
     */
    @Override
    public Select<Record> selectQuery() {
        return dsl.selectFrom(TABLE);
    }

    /**
     * Builds an INSERT of the given metrics. The negative class ratio is not persisted, and the top
     * features are joined with {@link #DELIMITER} into a single column.
     *
     * @param metrics the metrics to store
     * @return the insert query (not yet executed)
     */
    public Insert<Record> insertQuery(TMetrics metrics) {
        return dsl.insertInto(TABLE)
                .set(SAMPLE_ID, metrics.getSampleId())
                .set(ROC_AUC, metrics.getRocAuc())
                .set(ACCURACY, metrics.getAccuracy())
                .set(POSITIVE_CLASS_RATIO, metrics.getPositiveClassRatio())
                .set(TRAIN_SAMPLE_SIZE, metrics.getTrainSampleSize())
                .set(MATCHED_IDS_RATIO, metrics.getMatchedIdsRatio())
                .set(TOP_FEATURES, String.join(DELIMITER, metrics.getTopFeaturesList()));
    }

    /**
     * Builds a query selecting the rows whose {@code sample_id} equals {@code sampleId}.
     *
     * @param sampleId the sample id to match
     * @return the select query (not yet executed)
     */
    public Select<Record> selectByIdQuery(String sampleId) {
        return dsl.selectFrom(TABLE).where(SAMPLE_ID.eq(sampleId));
    }

    /**
     * Reports whether at least one metrics row exists for the given sample.
     *
     * <p>Uses an {@code EXISTS} check rather than streaming the rows. The previous
     * {@code stream().count()} fetched every matching row and left the lazily opened cursor
     * unclosed.
     *
     * @param sampleId the sample id to look up
     * @return {@code true} if a row with this {@code sample_id} exists
     */
    public boolean isMetricsReadyForSample(String sampleId) {
        return dsl.fetchExists(selectByIdQuery(sampleId));
    }

    /**
     * Builds a DELETE of all rows whose {@code sample_id} equals {@code sampleId}.
     *
     * @param sampleId the sample id to delete
     * @return the delete query (not yet executed)
     */
    public Delete<Record> deleteByIdQuery(String sampleId) {
        return dsl.deleteFrom(TABLE).where(SAMPLE_ID.eq(sampleId));
    }
}
