package ru.yandex.crypta.graph2.model.matching.component.score;

import java.util.Map;
import java.util.TreeMap;

import ru.yandex.crypta.graph2.model.matching.component.Component;
import ru.yandex.crypta.graph2.model.matching.component.GraphInfo;
import ru.yandex.crypta.graph2.model.matching.component.score.extractors.IdsCountExtractor;
import ru.yandex.crypta.graph2.model.matching.score.MetricsTree;
import ru.yandex.misc.lang.Check;

/**
 * Scores a component by mapping an integer count (produced by an {@link IdsCountExtractor})
 * onto a piecewise histogram of per-count probabilities.
 *
 * <p>The histogram is a sorted map from an inclusive upper bound to the probability assigned to
 * each count in the range ending at that bound; a count is scored with the value of the smallest
 * bound that is {@code >=} the count. Counts above every configured bound receive
 * {@link #getTheRestProb()}. Instances are created through {@link #metric(String, String)} and
 * the returned {@link Builder}.
 */
public class TrueHistogramCountsScoringStrategy implements ComponentScoringStrategy {
    /** Largest count the "rest" interval created by {@link Builder#andTheRest()} extends to. */
    private static final int COUNT_UPPER_BOUND = Integer.MAX_VALUE;
    /** Key under which the "rest" interval's total mass is stored in {@link #intervalsProbs}. */
    private static final int THE_REST_KEY = -1;

    private final String name;
    private final String description;
    /** Inclusive upper bound -> probability of each single count in the interval ending there. */
    private final TreeMap<Integer, Double> countsProbs = new TreeMap<>();
    /** Inclusive upper bound -> total probability mass of the interval ending there. */
    private final TreeMap<Integer, Double> intervalsProbs = new TreeMap<>();
    /** Per-count probability for counts greater than every key of {@link #countsProbs}. */
    private double theRestProb;
    /** When set, a count of zero is scored {@code -1} instead of being looked up. */
    private boolean penalizeEmpty = false;
    private IdsCountExtractor toCountFunc;

    private TrueHistogramCountsScoringStrategy(String name, String description) {
        this.name = name;
        this.description = description;
    }

    /**
     * Starts building a histogram metric.
     *
     * @param name        metric name returned by {@link #getName()}
     * @param description human-readable description returned by {@link #getDescription()}
     * @return a builder for the new metric
     */
    public static Builder metric(String name, String description) {
        return new Builder(name, description);
    }

    /**
     * Scores the component by the histogram value of its extracted count.
     *
     * @return {@code MetricsTree(-1)} for an empty count when empty components are penalized;
     *         otherwise the per-count probability of the smallest bound {@code >=} the count, or
     *         {@link #getTheRestProb()} when the count exceeds every configured bound
     */
    @Override
    public MetricsTree scoreTree(Component component, GraphInfo graphInfo) {
        int count = toCountFunc.apply(component, graphInfo);
        if (count == 0 && penalizeEmpty) {
            return new MetricsTree(-1);
        }
        Map.Entry<Integer, Double> countToProb = countsProbs.ceilingEntry(count);
        if (countToProb == null) {
            // Beyond the last explicit bound: fall into the "rest" interval.
            return new MetricsTree(theRestProb);
        }
        return new MetricsTree(countToProb.getValue());
    }

    @Override
    public String getName() {
        return name;
    }

    public String getDescription() {
        return description;
    }

    /**
     * Returns the per-count histogram (inclusive upper bound -> per-count probability).
     *
     * <p>This is the live backing map, not a copy; modifying it changes how this strategy scores.
     */
    public TreeMap<Integer, Double> getCountsProbs() {
        return countsProbs;
    }

    /** Returns the per-count probability used for counts above every configured bound. */
    public double getTheRestProb() {
        return theRestProb;
    }

    /**
     * Fluent builder that fills the histogram and validates that its total mass is 1.0.
     *
     * <p>Every builder method mutates the same instance that {@link #build()} returns.
     */
    public static class Builder {
        private final TrueHistogramCountsScoringStrategy instance;

        private Builder(String name, String description) {
            instance = new TrueHistogramCountsScoringStrategy(name, description);
        }

        /** Records a per-count probability; rejects a bound that was already registered. */
        private void putCountProb(int count, double prob) {
            if (this.instance.countsProbs.containsKey(count)) {
                throw new IllegalStateException("Interval already exists, they shouldn't intersect");
            }
            this.instance.countsProbs.put(count, prob);
        }

        /** Records an interval's total probability mass; rejects a bound that was already registered. */
        private void putIntervalProb(int count, double prob) {
            if (this.instance.intervalsProbs.containsKey(count)) {
                throw new IllegalStateException("Interval already exists, they shouldn't intersect");
            }
            this.instance.intervalsProbs.put(count, prob);
        }

        /** Sets the function that extracts the count to score from a component. */
        public Builder scoringCount(IdsCountExtractor toCountFunc) {
            instance.toCountFunc = toCountFunc;
            return this;
        }

        /**
         * Adds an interval ending at {@code count} (inclusive) with total mass {@code prob}.
         *
         * <p>The interval starts right after the closest bound already registered below
         * {@code count}, or after 0 when no such bound exists. The mass is spread evenly over
         * the counts it spans; a zero-width interval gets a per-count value of 0.
         *
         * @throws IllegalStateException if {@code count} is already a registered bound
         */
        public Builder lessOrEqualAs(int count, double prob) {
            // lowerEntry is null both for an empty histogram and when every existing bound is
            // >= count; in both cases the interval starts at 0.
            Map.Entry<Integer, Double> lower = instance.countsProbs.lowerEntry(count);
            int prev = lower == null ? 0 : lower.getKey();

            int width = count - prev;
            double valueProb = 0;
            if (width != 0) {
                valueProb = prob / width;
            }

            putIntervalProb(count, prob);
            putCountProb(count, valueProb);
            return this;
        }

        /**
         * Distributes {@code probabilityRange} over counts {@code from + 1 .. to} with linearly
         * increasing weights: count {@code from + k} gets {@code k * step}, where the step is
         * chosen so the weights sum to {@code probabilityRange}.
         *
         * @throws IllegalStateException if any count in the range is already a registered bound
         */
        public Builder uniformIncreasingRange(int from, int to, double probabilityRange) {
            Check.isTrue(to > from);
            int range = to - from;
            double probabilityInterval = probabilityRange / triangularNumber(range);
            double currentProbability = probabilityInterval;
            for (int i = from; i < to; i++) {
                putIntervalProb(i + 1, currentProbability);
                putCountProb(i + 1, currentProbability);
                currentProbability += probabilityInterval;
            }

            return this;
        }

        /**
         * Distributes {@code probabilityRange} over counts {@code from + 1 .. to} with linearly
         * decreasing weights: count {@code to} gets the smallest step, {@code from + 1} the largest,
         * and the weights sum to {@code probabilityRange}.
         *
         * @throws IllegalStateException if any count in the range is already a registered bound
         */
        public Builder uniformDecreasingRange(int from, int to, double probabilityRange) {
            Check.isTrue(to > from);
            int range = to - from;
            double probabilityInterval = probabilityRange / triangularNumber(range);
            double currentProbability = probabilityInterval;
            for (int i = to; i > from; i--) {
                putIntervalProb(i, currentProbability);
                putCountProb(i, currentProbability);
                currentProbability += probabilityInterval;
            }

            return this;
        }

        /**
         * Assigns the mass not yet allocated ({@code 1 - sum of interval masses}) to every count
         * above the last registered bound, spread evenly up to {@link Integer#MAX_VALUE}.
         *
         * @throws IllegalStateException if called more than once
         */
        public Builder andTheRest() {
            double intervalRestProb = 1 - sumOfWeights();

            double lastCount = 0;
            if (!instance.countsProbs.isEmpty()) {
                lastCount = instance.countsProbs.lastKey();
            }

            Check.isTrue(lastCount < COUNT_UPPER_BOUND);

            // Stored under a sentinel key so build() includes the rest mass in its total.
            putIntervalProb(THE_REST_KEY, intervalRestProb);
            instance.theRestProb = intervalRestProb / (COUNT_UPPER_BOUND - lastCount);
            return this;
        }

        /** Makes components with a zero count score {@code -1}. */
        public Builder andPenalizeEmpty() {
            instance.penalizeEmpty = true;
            return this;
        }

        /**
         * Validates the histogram and returns the configured strategy.
         *
         * @throws IllegalStateException if the interval masses do not sum to 1.0 within 0.001
         *                               (as checked by {@code Check.equals}); the check's failure
         *                               is attached as the cause
         */
        public TrueHistogramCountsScoringStrategy build() {
            double totalProbSum = sumOfWeights();
            try {
                Check.equals(1.0d, totalProbSum, 0.001);
            } catch (Exception e) {
                String message = String.format("Histogram %s should sum up to 1.0, but is %f instead", instance.name, totalProbSum);
                throw new IllegalStateException(message, e);
            }
            return instance;
        }

        /** Sum of all registered interval masses, including the "rest" interval if present. */
        private double sumOfWeights() {
            return instance.intervalsProbs.values().stream()
                    .mapToDouble(Double::doubleValue).sum();
        }

        /**
         * Returns {@code 1 + 2 + ... + n}, the sum of the linear weights for a range of width n.
         *
         * <p>Computed in double to avoid the int overflow of {@code (1 + n) * n} for large n.
         */
        private static double triangularNumber(int n) {
            return (1.0 + n) * n / 2;
        }
    }
}
