package ru.yandex.solomon.math.stat;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.model.timeseries.GraphData;
import ru.yandex.solomon.model.timeseries.GraphDataArrayList;
import ru.yandex.solomon.util.collection.array.LongArrayView;

/**
 * Calendar-periodic trend model built from per-bucket sample statistics.
 *
 * <p>A {@link CalendarSplitter} assigns each timestamp to a bucket; {@link #fit(GraphData)} adds
 * sample values to those buckets. Predictions for an arbitrary timestamp blend the statistics of
 * the two buckets returned by {@code CalendarSplitter.bucketPairOf} using that pair's weights.
 *
 * <p>Instances are mutable ({@link #fit(GraphData)} adds samples) and do no synchronization.
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public class SeasonalTrend {
    /** One sample accumulator per calendar bucket, indexed by the splitter's bucket index. */
    private final SampleBucket[] buckets;
    private final CalendarSplitter calendar;

    /**
     * Creates an empty model with one {@link SampleBucket} per calendar bucket.
     *
     * @param calendar     maps timestamps to bucket indices and to interpolation bucket pairs
     * @param dropFraction forwarded unchanged to every {@link SampleBucket}; presumably the fraction
     *                     of outlying samples a bucket discards (TODO confirm against SampleBucket)
     */
    public SeasonalTrend(CalendarSplitter calendar, double dropFraction) {
        this.calendar = calendar;
        this.buckets = new SampleBucket[calendar.bucketCount()];
        for (int i = 0; i < buckets.length; i++) {
            buckets[i] = new SampleBucket(dropFraction);
        }
    }

    /**
     * Blends two bucket statistics with the given weights.
     *
     * <p>If both values are defined the result is {@code w1 * v1 + w2 * v2} (a linear interpolation
     * when the weights sum to 1, which is assumed here). If exactly one value is NaN, the other value
     * is returned only when its weight is strictly greater; on a weight tie, or when both values are
     * NaN, the result is NaN.
     */
    private static double interpolate(double v1, double v2, double w1, double w2) {
        if (Double.isNaN(v1)) {
            return w2 > w1 ? v2 : Double.NaN;
        }
        if (Double.isNaN(v2)) {
            return w1 > w2 ? v1 : Double.NaN;
        }
        return w1 * v1 + w2 * v2;
    }

    /** Interpolated bucket mean for the given bucket pair (left bucket is queried first). */
    private double meanAt(CalendarSplitter.BucketPair pair) {
        return interpolate(
            buckets[pair.left].getMean(),
            buckets[pair.right].getMean(),
            pair.leftWeight,
            pair.rightWeight);
    }

    /** Interpolated bucket variance for the given bucket pair (left bucket is queried first). */
    private double varianceAt(CalendarSplitter.BucketPair pair) {
        return interpolate(
            buckets[pair.left].getVariance(),
            buckets[pair.right].getVariance(),
            pair.leftWeight,
            pair.rightWeight);
    }

    /**
     * Adds every non-NaN point of {@code fit} to the bucket its timestamp belongs to.
     *
     * @param fit training series; NaN values are filtered out before bucketing
     */
    public void fit(GraphData fit) {
        GraphData noNanFitData = fit.filterNonNan();

        noNanFitData.visit((pointTime, value) -> buckets[calendar.bucketOf(pointTime)].add(value));
    }

    /**
     * Predicts the interpolated bucket mean at each requested timestamp.
     *
     * @param predictTimes timestamps to predict at, in the order the output points are added
     * @return one point per timestamp; NaN where neither bucket can supply a value (see
     *         {@link #interpolate})
     */
    public GraphData predictMean(LongArrayView predictTimes) {
        GraphDataArrayList mean = new GraphDataArrayList();

        for (int i = 0; i < predictTimes.length(); i++) {
            long pointTime = predictTimes.at(i);
            mean.add(pointTime, meanAt(calendar.bucketPairOf(pointTime)));
        }

        return mean.buildGraphData();
    }

    /**
     * Predicts the interpolated bucket variance at each requested timestamp.
     *
     * @param predictTimes timestamps to predict at, in the order the output points are added
     * @return one point per timestamp; NaN where neither bucket can supply a value (see
     *         {@link #interpolate})
     */
    public GraphData predictVariance(LongArrayView predictTimes) {
        GraphDataArrayList variances = new GraphDataArrayList();

        for (int i = 0; i < predictTimes.length(); i++) {
            long pointTime = predictTimes.at(i);
            variances.add(pointTime, varianceAt(calendar.bucketPairOf(pointTime)));
        }

        return variances.buildGraphData();
    }

    /**
     * Same as {@link #predictAdjusted(GraphData, double, double)} with a zero variance floor.
     * Note that with a zero floor, a bucket variance of zero yields an infinite (or NaN) score.
     */
    public GraphData predictAdjusted(GraphData source) {
        return predictAdjusted(source, 0, 0);
    }

    /**
     * Standardizes each point of {@code source} against the seasonal trend:
     * {@code (value - mean) / sqrt(variance)}, where mean and variance are interpolated from the
     * buckets at the point's timestamp.
     *
     * <p>The variance is floored at {@code absMinVariance + relMinVariance * mean * mean}; a NaN
     * variance is replaced by that floor. If the interpolated mean is NaN, the floor and the result
     * are NaN as well.
     *
     * @param source         series to standardize
     * @param relMinVariance floor component proportional to the squared interpolated mean
     * @param absMinVariance absolute floor component
     * @return one standardized point per point of {@code source}
     */
    public GraphData predictAdjusted(GraphData source, final double relMinVariance, final double absMinVariance) {
        GraphDataArrayList adjusted = new GraphDataArrayList();

        source.visit((pointTime, value) -> {
            CalendarSplitter.BucketPair pair = calendar.bucketPairOf(pointTime);

            double mean = meanAt(pair);
            double minVariance = absMinVariance + relMinVariance * mean * mean;
            double variance = varianceAt(pair);
            // NaN compares false with everything, so it must be checked explicitly.
            if (Double.isNaN(variance) || variance < minVariance) {
                variance = minVariance;
            }

            adjusted.add(pointTime, (value - mean) / Math.sqrt(variance));
        });

        return adjusted.buildGraphData();
    }

}
