package ru.yandex.solomon.math.stat;

import java.util.EnumSet;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.model.point.HasMutableTs;
import ru.yandex.solomon.model.timeseries.iterator.GenericIterator;

/**
 * Wraps an iterator of histogram points and turns every point into a {@link VectorPoint}.
 * The i-th value of that vector is the cumulative distribution evaluated at
 * {@code upperBounds[i]} minus the cumulative distribution evaluated at {@code lowerBounds[i]}.
 *
 * <p>All requested bounds are de-duplicated and sorted once, in the constructor. As a result,
 * {@link #computeCumulativeDistribution} is asked for each distinct bound exactly once per point.
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public abstract class HistogramCumulativeDistributionIterator<HistogramPoint extends HasMutableTs> implements GenericIterator<VectorPoint> {
    private final GenericIterator<HistogramPoint> iterator;
    /** For every index of the original {@code lowerBounds}, its position in {@link #sortedBounds}. */
    private final int[] lowerIdxToSortedIdx;
    /** For every index of the original {@code upperBounds}, its position in {@link #sortedBounds}. */
    private final int[] upperIdxToSortedIdx;
    /** Position of +Infinity in {@link #sortedBounds} when normalizing, otherwise {@code -1}. */
    private final int totalIdx;
    /** Distinct bound values in ascending order. */
    private final double[] sortedBounds;

    /**
     * Reusable buffer passed to {@code iterator.next(...)}. It is created lazily on the first
     * {@link #next} call, so that the abstract {@link #newPoint()} is never invoked while this
     * class's constructor is running (at that point a subclass's own fields are not yet set).
     */
    private HistogramPoint point;
    /** Reusable buffer holding the cumulative value for every entry of {@link #sortedBounds}. */
    private final double[] valuesForSortedBounds;

    /** Optional behaviors accepted by the constructor. */
    public enum Options {
        /**
         * Multiply every output value by {@code 100 / total}, where total is the cumulative
         * value at +Infinity.
         */
        NORMALIZE,
        /** Replace upper bounds equal to zero by {@link Double#MIN_VALUE} before sorting. */
        ZERO_UPPER_BOUND_CORRECTION,
    }

    /** A bound value remembering where it came from: lower array, upper array, or the total. */
    private static class TaggedBound {
        public enum Kind {
            LOWER,
            UPPER,
            TOTAL
        }
        private final Kind kind;
        private final int originalIndex;
        private final double value;

        private TaggedBound(Kind kind, int originalIndex, double value) {
            this(kind, originalIndex, value, false);
        }

        // Zero correction is needed for LogHistogram when upper bound is zero
        private TaggedBound(Kind kind, int originalIndex, double value, boolean zeroCorrection) {
            this.kind = kind;
            this.originalIndex = originalIndex;
            this.value = (zeroCorrection && value == 0) ? Double.MIN_VALUE : value;
        }

        private double getValue() {
            return value;
        }
    }

    /**
     * @param iterator    source of histogram points
     * @param lowerBounds lower bound of every output component
     * @param upperBounds upper bound of every output component; must have the same length
     *                    as {@code lowerBounds}
     * @param flags       optional behaviors, see {@link Options}
     * @throws IllegalArgumentException if {@code lowerBounds} and {@code upperBounds} have
     *                                  different lengths
     */
    public HistogramCumulativeDistributionIterator(
            GenericIterator<HistogramPoint> iterator,
            double[] lowerBounds,
            double[] upperBounds,
            EnumSet<Options> flags)
    {
        // next() pairs lowerBounds[i] with upperBounds[i]; reject a mismatch here instead of
        // failing later with an ArrayIndexOutOfBoundsException.
        if (lowerBounds.length != upperBounds.length) {
            throw new IllegalArgumentException("lowerBounds and upperBounds must have the same length, got "
                    + lowerBounds.length + " and " + upperBounds.length);
        }
        this.iterator = iterator;
        boolean normalize = flags.contains(Options.NORMALIZE);
        boolean zeroCorrection = flags.contains(Options.ZERO_UPPER_BOUND_CORRECTION);

        // Group tagged bounds by value. The TreeMap keeps the distinct values in ascending
        // Double.compareTo order, so no separate sorting pass is needed.
        var boundsByValue = Stream.concat(
                    normalize
                            ? Stream.of(new TaggedBound(TaggedBound.Kind.TOTAL, 0, Double.POSITIVE_INFINITY))
                            : Stream.empty(),
                    Stream.concat(
                            IntStream.range(0, lowerBounds.length)
                                    .mapToObj(i -> new TaggedBound(TaggedBound.Kind.LOWER, i, lowerBounds[i])),
                            IntStream.range(0, upperBounds.length)
                                    .mapToObj(i -> new TaggedBound(TaggedBound.Kind.UPPER, i, upperBounds[i], zeroCorrection))
                    ))
            .collect(Collectors.groupingBy(TaggedBound::getValue, TreeMap::new, Collectors.toList()));

        this.lowerIdxToSortedIdx = new int[lowerBounds.length];
        this.upperIdxToSortedIdx = new int[upperBounds.length];
        this.sortedBounds = new double[boundsByValue.size()];

        int totalPosition = -1;
        int sortedIdx = 0;
        for (var entry : boundsByValue.entrySet()) {
            sortedBounds[sortedIdx] = entry.getKey();
            for (var tag : entry.getValue()) {
                switch (tag.kind) {
                    case LOWER:
                        lowerIdxToSortedIdx[tag.originalIndex] = sortedIdx;
                        break;
                    case UPPER:
                        upperIdxToSortedIdx[tag.originalIndex] = sortedIdx;
                        break;
                    case TOTAL:
                        totalPosition = sortedIdx;
                        break;
                }
            }
            sortedIdx++;
        }
        this.totalIdx = totalPosition;

        this.valuesForSortedBounds = new double[sortedBounds.length];
    }

    /**
     * Reads the next histogram point and writes its range values into {@code target}.
     * Only the first {@code target.values.length} ranges are filled.
     *
     * @return {@code false} when the wrapped iterator is exhausted
     */
    @Override
    public boolean next(VectorPoint target) {
        if (point == null) {
            point = newPoint();
        }
        if (!iterator.next(point)) {
            return false;
        }

        target.tsMillis = point.getTsMillis();
        computeCumulativeDistribution(point, valuesForSortedBounds, sortedBounds);
        for (int i = 0; i < target.values.length; i++) {
            target.values[i] = valuesForSortedBounds[upperIdxToSortedIdx[i]] - valuesForSortedBounds[lowerIdxToSortedIdx[i]];
        }
        if (totalIdx != -1) { // Normalization
            // NOTE(review): if the cumulative value at +Infinity is 0, scale is infinite and the
            // products below become NaN (0 * Inf) or +/-Infinity.
            double scale = 100d / valuesForSortedBounds[totalIdx];
            for (int i = 0; i < target.values.length; i++) {
                target.values[i] *= scale;
            }
        }

        return true;
    }

    /**
     * Creates the mutable point that is reused as the read buffer for every
     * {@code iterator.next(...)} call. This class invokes it once, lazily, on the first
     * {@link #next} call.
     */
    public abstract HistogramPoint newPoint();

    /**
     * Implementations are expected to store in {@code cumulativeCount[i]} the cumulative
     * value of {@code point} at {@code sortedBounds[i]}, for every index of {@code sortedBounds}.
     *
     * @param point           histogram point just read from the wrapped iterator
     * @param cumulativeCount output buffer, same length as {@code sortedBounds}
     * @param sortedBounds    distinct bounds in ascending order. Contains +Infinity when
     *                        normalizing, and may contain {@link Double#MIN_VALUE} when zero
     *                        correction is enabled.
     */
    public abstract void computeCumulativeDistribution(HistogramPoint point, double[] cumulativeCount, double[] sortedBounds);

    @Override
    public int estimatePointsCount() {
        return iterator.estimatePointsCount();
    }
}
