package ru.yandex.crypta.lab.tables;

import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.jooq.Configuration;
import org.jooq.DeleteConditionStep;
import org.jooq.Field;
import org.jooq.InsertSetMoreStep;
import org.jooq.Record;
import org.jooq.Select;
import org.jooq.SelectConditionStep;
import org.jooq.Table;
import org.jooq.impl.DSL;

import ru.yandex.crypta.common.data.GenericTable;
import ru.yandex.crypta.lab.proto.ModelSegmentRelation;
import ru.yandex.crypta.lab.proto.Thresholds;
import ru.yandex.crypta.lab.proto.Timestamps;

/**
 * jOOQ mapping for the {@code api_model_segment_relations} table, which links models to segments
 * together with optional low/high thresholds and created/modified timestamps.
 *
 * <p>Besides plain row mapping ({@link #read(Record)}), the class exposes {@code array_agg}
 * projections ({@link #aggregateFields()}) and a decoder for them ({@link #readAggregated(Record)}),
 * so that relations can be collected into a single row of an aggregating query.
 */
public class ModelSegmentRelationsTable extends GenericTable<ModelSegmentRelation> {

    public static final Table<Record> TABLE = DSL.table("api_model_segment_relations");
    public static final Field<String> MODEL_ID = DSL.field(DSL.name(TABLE.getName(), "model_id"), String.class);
    public static final Field<String> SEGMENT_ID = DSL.field(DSL.name(TABLE.getName(), "segment_id"), String.class);
    public static final Field<Double> LOW_THRESHOLD =
            DSL.field(DSL.name(TABLE.getName(), "low_threshold"), Double.class);
    public static final Field<Double> HIGH_THRESHOLD =
            DSL.field(DSL.name(TABLE.getName(), "high_threshold"), Double.class);
    public static final Field<Long> CREATED = DSL.field(DSL.name(TABLE.getName(), "created"), Long.class);
    public static final Field<Long> MODIFIED = DSL.field(DSL.name(TABLE.getName(), "modified"), Long.class);
    private static final Field<String[]> MODEL_IDS_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.MODEL_ID);
    private static final Field<String[]> SEGMENT_IDS_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.SEGMENT_ID);
    private static final Field<Long[]> CREATED_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.CREATED);
    private static final Field<Long[]> MODIFIED_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.MODIFIED);
    private static final Field<Double[]> LOW_THRESHOLD_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.LOW_THRESHOLD);
    private static final Field<Double[]> HIGH_THRESHOLD_ARRAY = DSL.arrayAgg(ModelSegmentRelationsTable.HIGH_THRESHOLD);

    public ModelSegmentRelationsTable(Configuration configuration) {
        super(configuration, ModelSegmentRelation.class);
    }

    /**
     * Decodes the {@code array_agg} columns produced by {@link #aggregateFields()} back into relations.
     *
     * <p>Positions where either the model id or the segment id is null are skipped. Missing
     * thresholds default to {@code 0.0} (low) and {@code 1.0} (high), and missing timestamps
     * default to {@code 0}. Duplicate relations are collapsed.
     *
     * @param record a record that contains every field returned by {@link #aggregateFields()}
     * @return the decoded relations. The stream is empty when the id aggregates are null.
     */
    public static Stream<ModelSegmentRelation> readAggregated(Record record) {
        String[] modelIds = record.get(MODEL_IDS_ARRAY);
        String[] segmentIds = record.get(SEGMENT_IDS_ARRAY);
        if (modelIds == null || segmentIds == null) {
            // An array aggregate may come back as NULL rather than an empty array
            // (e.g. when it aggregates no rows), so treat that as "no relations".
            return Stream.empty();
        }
        Long[] relationCreated = record.get(CREATED_ARRAY);
        Long[] relationModified = record.get(MODIFIED_ARRAY);
        Double[] lowThresholds = record.get(LOW_THRESHOLD_ARRAY);
        Double[] highThresholds = record.get(HIGH_THRESHOLD_ARRAY);

        // Only positions present in both id arrays can form a relation.
        int size = Math.min(modelIds.length, segmentIds.length);
        return IntStream.range(0, size)
                .filter(index -> Objects.nonNull(segmentIds[index]) && Objects.nonNull(modelIds[index]))
                .mapToObj(index -> {
                    ModelSegmentRelation.Builder relation = ModelSegmentRelation
                            .newBuilder()
                            .setModelId(modelIds[index])
                            .setSegmentId(segmentIds[index]);
                    relation.getThresholdsBuilder()
                            .setLow(elementOr(lowThresholds, index, 0.0))
                            .setHigh(elementOr(highThresholds, index, 1.0));
                    relation.getTimestampsBuilder()
                            .setCreated(elementOr(relationCreated, index, 0L))
                            .setModified(elementOr(relationModified, index, 0L));
                    return relation.build();
                }).distinct();
    }

    /**
     * Returns the {@code array_agg} projections over this table's columns. Use this in a select
     * whose result rows are later decoded with {@link #readAggregated(Record)}.
     *
     * @return a new array of the aggregate fields (ids, timestamps, thresholds)
     */
    public static Field[] aggregateFields() {
        return new Field[]{
                MODEL_IDS_ARRAY,
                SEGMENT_IDS_ARRAY,
                CREATED_ARRAY,
                MODIFIED_ARRAY,
                LOW_THRESHOLD_ARRAY,
                HIGH_THRESHOLD_ARRAY
        };
    }

    /**
     * Maps one plain table row to a relation.
     *
     * <p>Thresholds are set only when both the low and the high threshold are non-null.
     * Null timestamps are mapped to {@code 0}, matching {@link #readAggregated(Record)}.
     */
    @Override
    protected ModelSegmentRelation read(Record record) {
        ModelSegmentRelation.Builder prototype = ModelSegmentRelation.newBuilder();
        prototype.setModelId(record.get(MODEL_ID));
        prototype.setSegmentId(record.get(SEGMENT_ID));

        Double lowThreshold = record.get(LOW_THRESHOLD);
        Double highThreshold = record.get(HIGH_THRESHOLD);
        if (lowThreshold != null && highThreshold != null) {
            Thresholds.Builder thresholds = Thresholds.newBuilder();
            thresholds.setLow(lowThreshold);
            thresholds.setHigh(highThreshold);
            prototype.setThresholds(thresholds);
        }

        // The protobuf setters take primitive longs; unboxing a NULL column would throw an NPE.
        prototype.setTimestamps(Timestamps
                .newBuilder()
                .setCreated(Optional.ofNullable(record.get(CREATED)).orElse(0L))
                .setModified(Optional.ofNullable(record.get(MODIFIED)).orElse(0L))
        );

        return prototype.build();
    }

    /** Builds a query that selects every relation of the given model. */
    public SelectConditionStep<Record> selectByModelIdQuery(String modelId) {
        return dsl.selectFrom(TABLE)
                .where(MODEL_ID.equal(modelId));
    }

    /** Builds a query that deletes every relation of the given model. */
    public DeleteConditionStep<Record> deleteByModelIdQuery(String modelId) {
        return dsl.deleteFrom(TABLE)
                .where(MODEL_ID.equal(modelId));
    }

    /**
     * Builds an INSERT of {@code relation} into {@code table}, which is expected to have this
     * table's columns.
     *
     * <p>The threshold columns are set only when the relation has thresholds. The model id,
     * segment id and both timestamps are always set.
     *
     * @param table the target table
     * @param relation the relation to insert
     * @return the insert query, which callers may extend further
     */
    public InsertSetMoreStep<Record> genericInsertQuery(Table<Record> table, ModelSegmentRelation relation) {
        InsertSetMoreStep<Record> query = dsl.insertInto(table)
                .set(MODEL_ID, relation.getModelId())
                .set(SEGMENT_ID, relation.getSegmentId());
        if (relation.hasThresholds()) {
            query = query.set(LOW_THRESHOLD, relation.getThresholds().getLow())
                    .set(HIGH_THRESHOLD, relation.getThresholds().getHigh());
        }
        return query.set(CREATED, relation.getTimestamps().getCreated())
                .set(MODIFIED, relation.getTimestamps().getModified());
    }

    /** Builds an INSERT of {@code relation} into {@link #TABLE}. */
    public InsertSetMoreStep<Record> insertQuery(ModelSegmentRelation relation) {
        return genericInsertQuery(TABLE, relation);
    }

    /** Builds a query that deletes the single relation between the given model and segment. */
    public DeleteConditionStep<Record> deleteByModelAndSegmentIdQuery(String modelId, String segmentId) {
        return dsl.deleteFrom(TABLE)
                .where(MODEL_ID.equal(modelId))
                .and(SEGMENT_ID.equal(segmentId));
    }

    /** Builds a query that selects every relation of the given segment. */
    public SelectConditionStep<Record> selectBySegmentIdQuery(String segmentId) {
        return dsl.selectFrom(TABLE)
                .where(SEGMENT_ID.equal(segmentId));
    }

    /** Builds a query that deletes every relation of the given segment. */
    public DeleteConditionStep<Record> deleteBySegmentIdQuery(String segmentId) {
        return dsl.deleteFrom(TABLE)
                .where(SEGMENT_ID.equal(segmentId));
    }

    /** Builds a query that selects every row of {@link #TABLE}. */
    @Override
    public Select<Record> selectQuery() {
        return dsl.selectFrom(TABLE);
    }

    /**
     * Returns {@code values[index]}, or {@code fallback} in three cases: the array is null, the
     * index is past its end, or the element itself is null.
     */
    private static <T> T elementOr(T[] values, int index, T fallback) {
        if (values == null || index >= values.length) {
            return fallback;
        }
        return Optional.ofNullable(values[index]).orElse(fallback);
    }
}
