package ru.yandex.crypta.graph2.model.matching.component.score;

import java.util.Map;
import java.util.TreeMap;

import ru.yandex.crypta.graph2.model.matching.component.Component;
import ru.yandex.crypta.graph2.model.matching.component.GraphInfo;
import ru.yandex.crypta.graph2.model.matching.component.score.extractors.IdsCountExtractor;
import ru.yandex.crypta.graph2.model.matching.score.MetricsTree;
import ru.yandex.misc.lang.Check;

/**
 * Scores a component by extracting an integer count from it and looking that count up in a
 * step-function histogram.
 *
 * <p>The histogram is a sorted map from inclusive upper bounds to scores. A count gets the score of
 * the smallest bound that is greater than or equal to it. A count above the largest bound gets
 * {@link #getTheRestProb()}. Instances are created through {@link #metric(String, String)} and the
 * returned {@link Builder}. {@link Builder#build()} checks that all bucket scores plus the "rest"
 * score sum up to 1.0 within a tolerance of 0.001.
 */
public class HistogramCountsScoringStrategy implements ComponentScoringStrategy {
    /** Allowed absolute deviation of the histogram's total from 1.0. */
    private static final double SUM_TOLERANCE = 0.001;

    private final String name;
    private final String description;
    /** Inclusive upper bound of a count bucket -> score assigned to that bucket. */
    private final TreeMap<Integer, Double> countsProbs = new TreeMap<>();
    /** Score for counts greater than every bound in {@link #countsProbs}. */
    private double theRestProb;
    /** When set, a count of exactly 0 is scored as -1 instead of being looked up. */
    private boolean penalizeEmpty = false;
    /** Extracts the count to score; set through {@link Builder#scoringCount}. */
    private IdsCountExtractor toCountFunc;

    private HistogramCountsScoringStrategy(String name, String description) {
        this.name = name;
        this.description = description;
    }

    /**
     * Creates an independent copy of {@code other}. It is used by {@link Builder#build()} so that
     * further builder calls cannot mutate an already-built strategy.
     */
    private HistogramCountsScoringStrategy(HistogramCountsScoringStrategy other) {
        this(other.name, other.description);
        this.countsProbs.putAll(other.countsProbs);
        this.theRestProb = other.theRestProb;
        this.penalizeEmpty = other.penalizeEmpty;
        this.toCountFunc = other.toCountFunc;
    }

    /**
     * Starts building a histogram scoring strategy.
     *
     * @param name metric name returned by {@link #getName()}
     * @param description human-readable description returned by {@link #getDescription()}
     * @return a fresh builder
     */
    public static Builder metric(String name, String description) {
        return new Builder(name, description);
    }

    /**
     * Scores {@code component} by its extracted count.
     *
     * <p>Returns -1 for a count of 0 when {@link Builder#andPenalizeEmpty()} was requested.
     * Otherwise returns the score of the smallest bucket bound that is greater than or equal to the
     * count, or the "rest" score when the count exceeds every bound.
     *
     * @throws IllegalStateException if no count extractor was configured via
     *     {@link Builder#scoringCount}
     */
    @Override
    public MetricsTree scoreTree(Component component, GraphInfo graphInfo) {
        if (toCountFunc == null) {
            throw new IllegalStateException(
                    "Histogram " + name + " has no count extractor, call scoringCount() on the builder");
        }
        int count = toCountFunc.apply(component, graphInfo);
        if (count == 0 && penalizeEmpty) {
            return new MetricsTree(-1);
        }
        Map.Entry<Integer, Double> countToProb = countsProbs.ceilingEntry(count);
        if (countToProb == null) {
            // No bound >= count: the count is above the histogram's specified range.
            return new MetricsTree(theRestProb);
        }
        return new MetricsTree(countToProb.getValue());
    }

    @Override
    public String getName() {
        return name;
    }

    /** Returns the human-readable description given to {@link #metric(String, String)}. */
    public String getDescription() {
        return description;
    }

    /**
     * Returns a copy of the histogram, mapping inclusive upper bounds to scores. Changes to the
     * returned map do not affect this strategy.
     */
    public TreeMap<Integer, Double> getCountsProbs() {
        return new TreeMap<>(countsProbs);
    }

    /** Returns the score used for counts greater than every histogram bound. */
    public double getTheRestProb() {
        return theRestProb;
    }

    /**
     * Fluent builder for {@link HistogramCountsScoringStrategy}. Buckets are identified by their
     * inclusive upper bound, so registering the same bound twice is rejected.
     */
    public static class Builder {
        private final HistogramCountsScoringStrategy instance;

        private Builder(String name, String description) {
            instance = new HistogramCountsScoringStrategy(name, description);
        }

        /** Registers a bucket, rejecting a bound that was already registered. */
        private void putProb(int count, double prob) {
            if (instance.countsProbs.containsKey(count)) {
                throw new IllegalStateException("Interval already exists, they shouldn't intersect");
            }
            instance.countsProbs.put(count, prob);
        }

        /**
         * Returns the per-step increment for a linear ramp over the counts in (from, to]. The step is
         * chosen so that {@code step * (1 + 2 + ... + range)} equals {@code probabilityRange}.
         * It is computed in double arithmetic to avoid int overflow of {@code range * (range + 1)}.
         */
        private static double linearRampStep(int from, int to, double probabilityRange) {
            Check.isTrue(to > from);
            double range = (double) to - from;
            return probabilityRange / (range * (range + 1) / 2);
        }

        /** Sets the function that extracts the count to score from a component. */
        public Builder scoringCount(IdsCountExtractor toCountFunc) {
            instance.toCountFunc = toCountFunc;
            return this;
        }

        /**
         * Assigns {@code prob} to all counts up to and including {@code count} that are not covered
         * by a smaller registered bound.
         *
         * @throws IllegalStateException if {@code count} is already a registered bound
         */
        public Builder lessOrEqualAs(int count, double prob) {
            putProb(count, prob);
            return this;
        }

        /**
         * Assigns linearly increasing scores to each count in (from, to]. Count {@code from + k} gets
         * {@code k * step}, so the scores of the range sum to {@code probabilityRange}.
         *
         * @throws IllegalStateException if any count in the range is already a registered bound
         */
        public Builder uniformIncreasingRange(int from, int to, double probabilityRange) {
            double step = linearRampStep(from, to, probabilityRange);
            double currentProbability = step;
            for (int i = from; i < to; i++) {
                putProb(i + 1, currentProbability);
                currentProbability += step;
            }
            return this;
        }

        /**
         * Assigns linearly decreasing scores to each count in (from, to]. Count {@code to} gets one
         * step, {@code to - 1} gets two steps, and so on, so the scores of the range sum to
         * {@code probabilityRange}.
         *
         * @throws IllegalStateException if any count in the range is already a registered bound
         */
        public Builder uniformDecreasingRange(int from, int to, double probabilityRange) {
            double step = linearRampStep(from, to, probabilityRange);
            double currentProbability = step;
            for (int i = to; i > from; i--) {
                putProb(i, currentProbability);
                currentProbability += step;
            }
            return this;
        }

        /** Sets the score for counts greater than every registered bound. */
        public Builder andTheRestAs(double restProb) {
            instance.theRestProb = restProb;
            return this;
        }

        /** Makes a count of exactly 0 score -1 instead of being looked up in the histogram. */
        public Builder andPenalizeEmpty() {
            instance.penalizeEmpty = true;
            return this;
        }

        /**
         * Validates the histogram and returns an independent strategy instance.
         *
         * @throws IllegalStateException if the bucket scores plus the "rest" score do not sum up to
         *     1.0 within 0.001, or if the sum is NaN
         */
        public HistogramCountsScoringStrategy build() {
            double totalProbSum = sumOfWeights();
            // Written as !(diff <= tol) so that a NaN sum is rejected too.
            if (!(Math.abs(totalProbSum - 1.0d) <= SUM_TOLERANCE)) {
                String message = String.format("Histogram %s should sum up to 1.0, but is %f instead", instance.name, totalProbSum);
                throw new IllegalStateException(message);
            }
            // Copy so that later builder calls cannot change the returned strategy.
            return new HistogramCountsScoringStrategy(instance);
        }

        /** Returns the sum of all bucket scores plus the "rest" score. */
        private double sumOfWeights() {
            return instance.countsProbs.values().stream()
                    .mapToDouble(Double::doubleValue)
                    .sum() + instance.theRestProb;
        }
    }
}
