package ru.yandex.crypta.lab.tables;

import javax.ws.rs.core.SecurityContext;

import org.jooq.Condition;
import org.jooq.Configuration;
import org.jooq.Delete;
import org.jooq.Field;
import org.jooq.Insert;
import org.jooq.Query;
import org.jooq.Record;
import org.jooq.Select;
import org.jooq.Table;
import org.jooq.impl.DSL;

import ru.yandex.crypta.common.data.GenericTable;
import ru.yandex.crypta.lab.proto.AccessLevel;
import ru.yandex.crypta.lab.proto.TTrainingSample;
import ru.yandex.crypta.lib.python.custom_ml.proto.ETrainingSampleState;
import ru.yandex.crypta.lib.python.custom_ml.proto.TTrainingError;

public class TrainingSamplesTable extends GenericTable<TTrainingSample> {
    public static final Table<Record> TABLE = DSL.table("lab_training_samples");
    public static final Field<String> ID = DSL.field(DSL.name(TABLE.getName(), "id"), String.class);
    private static final Field<String> NAME = DSL.field(DSL.name(TABLE.getName(), "name"), String.class);
    public static final Field<String> AUTHOR = DSL.field(DSL.name(TABLE.getName(), "author"), String.class);
    public static final Field<String> ACCESS_LEVEL = DSL.field(DSL.name(TABLE.getName(), "access_level"), String.class);
    public static final Field<Long> TTL = DSL.field(DSL.name(TABLE.getName(), "ttl"), Long.class);
    private static final Field<Long> CREATED = DSL.field(DSL.name(TABLE.getName(), "created"), Long.class);
    private static final Field<Long> MODIFIED = DSL.field(DSL.name(TABLE.getName(), "modified"), Long.class);
    private static final Field<String> TRAINING_ERROR = DSL.field(DSL.name(TABLE.getName(), "training_error"), String.class);
    private static final Field<String> STATUS = DSL.field(DSL.name(TABLE.getName(), "status"), String.class);
    private static final Field<String> MODEL_NAME = DSL.field(DSL.name(TABLE.getName(), "model_name"), String.class);
    private static final Field<String> PARTNER = DSL.field(DSL.name(TABLE.getName(), "partner"), String.class);

    public static final Long DEFAULT_TTL = 604_800L;

    private final SecurityContext securityContext;

    public TrainingSamplesTable(Configuration configuration, SecurityContext securityContext) {
        super(configuration, TTrainingSample.class);
        this.securityContext = securityContext;
    }

    @Override
    protected TTrainingSample read(Record record) {
        TTrainingSample.Builder builder = TTrainingSample.newBuilder()
                .setId(record.get(ID))
                .setName(record.get(NAME))
                .setAuthor(record.get(AUTHOR))
                .setAccessLevel(AccessLevel.valueOf(record.get(ACCESS_LEVEL)))
                .setTtl(record.get(TTL))
                .setStatus(ETrainingSampleState.valueOf(record.get(STATUS)));

        if (record.get(TRAINING_ERROR) != null) {
            builder.setTrainingError(record.get(TRAINING_ERROR));
        }
        if (record.get(MODEL_NAME) != null && record.get(PARTNER) != null) {
            builder.setModelName(record.get(MODEL_NAME)).setPartner(record.get(PARTNER));
        }

        builder.getTimestampsBuilder()
                .setModified(record.get(MODIFIED))
                .setCreated(record.get(CREATED));

        return builder.build();
    }

    @Override
    public Select<Record> selectQuery() {
        return dsl.selectFrom(TABLE);
    }

    public Select<Record> selectAccessible() {
        return dsl.selectFrom(TABLE)
                .where(TrainingSamplesAcl.isListedFor(securityContext));
    }

    public Select<Record> selectByIdQuery(String id) {
        return dsl.selectFrom(TABLE)
                .where(ID.eq(id))
                .and(TrainingSamplesAcl.isAccessibleDirectlyBy(securityContext));
    }

    public Insert<Record> insertQuery(TTrainingSample sample) {
        return dsl.insertInto(TABLE)
                .set(ID, sample.getId())
                .set(NAME, sample.getName())
                .set(AUTHOR, sample.getAuthor())
                .set(ACCESS_LEVEL, sample.getAccessLevel().name())
                .set(TTL, sample.getTtl())
                .set(CREATED, sample.getTimestamps().getCreated())
                .set(MODIFIED, sample.getTimestamps().getModified())
                .set(MODEL_NAME, sample.getModelName())
                .set(PARTNER, sample.getPartner())
                .set(STATUS, sample.getStatus().name());
    }

    public Delete<Record> deleteByIdQuery(String id) {
        return dsl.deleteFrom(TABLE)
                .where(ID.eq(id))
                .and(TrainingSamplesAcl.isModifiableBy(securityContext));
    }

    public Select<Record> selectOutdatedSamplesQuery(Long now) {
        Condition overdue = DSL
                .value(now)
                .minus(TTL)
                .greaterOrEqual(MODIFIED);
        return dsl.selectFrom(TABLE)
                .where(overdue);
    }

    public Query addError(TTrainingError trainingError) {
        return dsl.update(TABLE)
                .set(TRAINING_ERROR, trainingError.getMessage())
                .set(STATUS, ETrainingSampleState.ERROR.name())
                .where(ID.eq(trainingError.getSampleId()));
    }

    public Query updateState(String id, ETrainingSampleState state) {
        return dsl.update(TABLE)
                .set(STATUS, state.name())
                .where(ID.eq(id));
    }
}
