package ru.yandex.crypta.graph2.model.soup.edge.weight.estimator;

import java.time.LocalDate;
import java.time.ZoneOffset;
import java.util.List;

import com.google.protobuf.InvalidProtocolBufferException;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.crypta.graph2.dao.Dao;
import ru.yandex.crypta.graph2.model.soup.edge.EdgeProtoHelper;
import ru.yandex.crypta.graph2.model.soup.proto.Edge;
import ru.yandex.crypta.graph2.model.soup.proto.EdgeType;
import ru.yandex.crypta.soup.edge_weights.proto.TStatsQuery;
import ru.yandex.crypta.soup.edge_weights.proto.TTotalHistogram;
import ru.yandex.inside.yt.kosher.Yt;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.operations.utils.YtSerializable;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;


/**
 * Edge weight estimator backed by per-edge-type survival histograms.
 *
 * <p>Histograms ({@link TTotalHistogram}) are read from a YT table whose rows carry the columns
 * {@code id1Type}, {@code id2Type}, {@code sourceType}, {@code logSource} and {@code TotalHistogram}.
 * The survival function itself is evaluated by native code from the {@code soup-model_java} library.
 *
 * <p>The histogram map is filled only by the constructors that read a table; edge types without a
 * loaded histogram fall back to an empty histogram.
 */
public class SurvivalEdgeModel implements YtSerializable {

    static {
        System.loadLibrary("soup-model_java");
    }

    /** Loaded histogram per edge type. */
    private final MapF<EdgeType, TTotalHistogram> histogramsMap = Cf.hashMap();
    /** Reference date; its start of day is the timestamp used by the {@code Edge}-based overloads. */
    private final LocalDate localDate;

    /**
     * Creates a model with no histograms loaded.
     *
     * @param localDate reference date used to compute the evaluation timestamp
     */
    public SurvivalEdgeModel(LocalDate localDate) {
        this.localDate = localDate;
    }

    /**
     * Creates a model with no histograms loaded, using today's date in the JVM default time zone.
     */
    public SurvivalEdgeModel() {
        this(LocalDate.now());
    }

    /**
     * Creates a model and loads histograms from {@code table}.
     *
     * @param yt        YT client used to read the table
     * @param table     table with one histogram row per edge type
     * @param localDate reference date used to compute the evaluation timestamp
     * @throws IllegalStateException if a row's {@code TotalHistogram} cannot be parsed
     */
    public SurvivalEdgeModel(Yt yt, YPath table, LocalDate localDate) {
        this(localDate);
        yt.tables().read(table, YTableEntryTypes.YSON, this::readCallback);
    }

    /**
     * Same as {@link #SurvivalEdgeModel(Yt, YPath, LocalDate)} with today's date (JVM default zone).
     */
    public SurvivalEdgeModel(Yt yt, YPath table) {
        this(yt, table, LocalDate.now());
    }

    /**
     * Same as {@link #SurvivalEdgeModel(Yt, YPath, LocalDate)} using the YT client of {@code dao}.
     */
    public SurvivalEdgeModel(Dao dao, YPath table, LocalDate localDate) {
        this(dao.yt(), table, localDate);
    }

    /**
     * Same as {@link #SurvivalEdgeModel(Dao, YPath, LocalDate)} with today's date (JVM default zone).
     */
    public SurvivalEdgeModel(Dao dao, YPath table) {
        this(dao, table, LocalDate.now());
    }

    /**
     * Evaluates the survival function in native code.
     *
     * @param timestamp  evaluation time; the {@code Edge}-based overloads pass Unix epoch seconds
     * @param statsQuery serialized {@link TStatsQuery}
     * @return survival function value computed by the native library
     */
    public static native double getSurvivalFunctionValue(long timestamp, byte[] statsQuery);

    /**
     * Stores the histogram from one table row under the edge type built from the row's key columns.
     *
     * @throws IllegalStateException if the {@code TotalHistogram} column is not a valid serialized
     *                               {@link TTotalHistogram}
     */
    private void readCallback(YTreeMapNode row) {
        String id1Type = row.getOrThrow("id1Type").stringValue();
        String id2Type = row.getOrThrow("id2Type").stringValue();
        String sourceType = row.getOrThrow("sourceType").stringValue();
        String logSource = row.getOrThrow("logSource").stringValue();
        EdgeType edgeType = EdgeProtoHelper.createEdgeType(id1Type, id2Type, sourceType, logSource);

        TTotalHistogram histogram;
        try {
            histogram = TTotalHistogram.parseFrom(row.getOrThrow("TotalHistogram").bytesValue());
        } catch (InvalidProtocolBufferException e) {
            // Name the offending edge type so a corrupt row can be located in the source table.
            throw new IllegalStateException(
                    "Failed to parse TotalHistogram for edge type ("
                            + id1Type + ", " + id2Type + ", " + sourceType + ", " + logSource + ")",
                    e);
        }
        histogramsMap.put(edgeType, histogram);
    }

    /**
     * Computes the weight of {@code edge} at the start of the reference date, using the histogram
     * loaded for its edge type.
     */
    public double getEdgeWeight(Edge edge) {
        return getEdgeWeight(getLocalTimestamp(), getStatsQuery(edge));
    }

    /**
     * Computes the weight of {@code edge} at the start of the reference date without any histogram,
     * i.e. from the edge's dates alone.
     */
    public double getDefaultEdgeWeight(Edge edge) {
        TStatsQuery.Builder statsQueryBuilder = getStatsQuery(edge).toBuilder();
        statsQueryBuilder.clearTotalHistogram();
        return getEdgeWeight(getLocalTimestamp(), statsQueryBuilder.build());
    }

    /**
     * Evaluates the survival function for a prepared query.
     *
     * @param timestamp evaluation time; the {@code Edge}-based overloads pass Unix epoch seconds
     * @param query     dates and histogram to evaluate
     */
    public double getEdgeWeight(long timestamp, TStatsQuery query) {
        // TODO: pass dates and histogram separately, thus you don't need to serialize histogram each time
        return getSurvivalFunctionValue(timestamp, query.toByteArray());
    }

    /**
     * Builds the native query for {@code edge}: the histogram of its edge type (empty if none was
     * loaded) plus all of the edge's dates.
     */
    public TStatsQuery getStatsQuery(Edge edge) {
        List<String> dates = EdgeProtoHelper.getDates(edge);
        return TStatsQuery.newBuilder()
                .setTotalHistogram(getTotalHistogram(edge))
                .addAllDate(dates)
                .build();
    }

    /** Returns the start of the reference date as Unix epoch seconds. */
    private long getLocalTimestamp() {
        // NOTE(review): ZoneOffset.MAX (+18:00) puts "start of day" 18h before UTC midnight;
        // confirm this is intended rather than ZoneOffset.UTC before changing it.
        return localDate.atStartOfDay(ZoneOffset.MAX).toEpochSecond();
    }

    /** Returns the loaded histogram for the edge's type, or the empty default histogram. */
    private TTotalHistogram getTotalHistogram(Edge edge) {
        EdgeType edgeType = EdgeProtoHelper.getEdgeType(edge);
        return histogramsMap.getOrElseApply(edgeType, () -> TTotalHistogram.getDefaultInstance());
    }
}
