package ru.yandex.solomon.math.stat;

import java.time.Duration;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.model.timeseries.GraphData;
import ru.yandex.solomon.model.timeseries.GraphDataArrayList;
import ru.yandex.solomon.util.collection.array.LongArrayView;

/**
 * Intraday trend model built from per-time-of-day sample buckets.
 *
 * <p>The day is split into {@code intervalsPerDay} time buckets by a {@link CalendarSplitter}.
 * {@link #fit} adds every non-NaN point to its own bucket and to up to {@code smoothBuckets}
 * neighbouring buckets on each side; the {@code predict*} methods then read the statistics of the
 * bucket a timestamp falls into. Points whose day bucket is {@code 0} (work days) use the bucket's
 * primary statistics ({@code getMean}/{@code getVariance}); all other days use the unfiltered
 * statistics ({@code getUnfilteredMean}/{@code getUnfilteredVariance}). A bucket holding fewer than
 * {@code minBucketPoints} samples produces {@code NaN}.
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public class KronosTrend {

    /** Value returned by {@code CalendarSplitter#dayBucketOf} for work days. */
    private static final int WORK_DAY_BUCKET = 0;

    /** Number of neighbouring buckets on each side that also receive every fitted point. */
    private final int smoothBuckets;
    /** Minimum bucket size required for a prediction; smaller buckets yield {@code NaN}. */
    private final int minBucketPoints;
    /** One statistics accumulator per intraday time bucket. */
    private final SampleBucketWithUnfilteredStats[] buckets;
    /** Maps timestamps to time-of-day and day-type buckets. */
    private final CalendarSplitter calendar;

    /**
     * Creates a model with default settings: 5-minute smoothing, drop fraction {@code 0.1},
     * a 3-hour time-zone offset and at least 10 samples per bucket.
     *
     * @param intervalsPerDay number of time buckets per day
     */
    public KronosTrend(int intervalsPerDay) {
        this(intervalsPerDay, Duration.ofMinutes(5), 0.1, Duration.ofHours(3), 10);
    }

    /**
     * Creates an empty (unfitted) model.
     *
     * @param intervalsPerDay number of time buckets per day
     * @param smooth half-width of the smoothing window; rounded up to a whole number of buckets
     * @param dropFraction value passed to every {@link SampleBucketWithUnfilteredStats}
     *                     (presumably the share of samples excluded from the filtered statistics)
     * @param timeZone time-zone offset passed to the {@link CalendarSplitter}
     * @param minBucketPoints minimum number of samples a bucket needs to be used for prediction
     */
    public KronosTrend(
            int intervalsPerDay,
            Duration smooth,
            double dropFraction,
            Duration timeZone,
            int minBucketPoints)
    {
        this.calendar = new CalendarSplitter(intervalsPerDay, DailyProfile.WORK, timeZone);
        this.minBucketPoints = minBucketPoints;
        long bucketWidthMillis = calendar.getBucketWidthMillis();
        // Ceiling division: a partial bucket of smoothing widens the window by a whole bucket.
        this.smoothBuckets = (int) ((smooth.toMillis() + bucketWidthMillis - 1) / bucketWidthMillis);
        this.buckets = new SampleBucketWithUnfilteredStats[intervalsPerDay];
        for (int i = 0; i < buckets.length; i++) {
            buckets[i] = new SampleBucketWithUnfilteredStats(dropFraction);
        }
    }

    /**
     * Accumulates the non-NaN points of {@code fit} into the time buckets.
     *
     * <p>Each point is added to its own bucket and to up to {@code smoothBuckets} neighbours on
     * each side. The window is clipped at the first and last bucket of the day rather than wrapped
     * around midnight. Buckets are never reset, so repeated calls add to the same statistics.
     *
     * @param fit training data; NaN values are skipped
     */
    public void fit(GraphData fit) {
        fit.filterNonNan().visit((pointTime, value) -> {
            int bucketIndex = calendar.timeBucketOf(pointTime);
            int first = Math.max(0, bucketIndex - smoothBuckets);
            int last = Math.min(buckets.length, bucketIndex + smoothBuckets + 1);
            for (int i = first; i < last; i++) {
                buckets[i].add(value);
            }
        });
    }

    /**
     * Predicts the trend mean at each of {@code predictTimes}.
     *
     * @param predictTimes timestamps to predict for
     * @return one point per timestamp; {@code NaN} where the bucket has too few samples
     */
    public GraphData predictMean(LongArrayView predictTimes) {
        return predict(predictTimes, KronosTrend::meanOf);
    }

    /**
     * Predicts the trend variance at each of {@code predictTimes}.
     *
     * @param predictTimes timestamps to predict for
     * @return one point per timestamp; {@code NaN} where the bucket has too few samples
     */
    public GraphData predictVariance(LongArrayView predictTimes) {
        return predict(predictTimes, KronosTrend::varianceOf);
    }

    /**
     * Standardizes every point of {@code source} against the fitted trend:
     * {@code (value - mean) / sqrt(max(variance, absMinVariance + relMinVariance * mean^2))}.
     *
     * <p>The variance floor keeps nearly-constant buckets from producing huge scores. If both
     * the floor and the bucket variance are zero, the IEEE-754 division yields infinity or NaN.
     *
     * @param source points to standardize
     * @param relMinVariance variance floor relative to the squared mean
     * @param absMinVariance absolute variance floor
     * @return one point per source point; {@code NaN} where the bucket has too few samples
     */
    public GraphData predictAdjusted(GraphData source, double relMinVariance, double absMinVariance) {
        GraphDataArrayList adj = new GraphDataArrayList();

        source.visit((pointTime, value) -> {
            SampleBucketWithUnfilteredStats bucket = buckets[calendar.timeBucketOf(pointTime)];
            if (!hasEnoughPoints(bucket)) {
                adj.add(pointTime, Double.NaN);
                return;
            }
            boolean workDay = isWorkDay(pointTime);
            double mean = meanOf(bucket, workDay);
            double variance = varianceOf(bucket, workDay);
            double minVariance = absMinVariance + relMinVariance * mean * mean;
            // Explicit comparison (not Math.max) so a NaN floor leaves the variance untouched.
            if (variance < minVariance) {
                variance = minVariance;
            }
            adj.add(pointTime, (value - mean) / Math.sqrt(variance));
        });

        return adj.buildGraphData();
    }

    /** Extracts one statistic from a bucket, choosing filtered or unfiltered by day type. */
    @FunctionalInterface
    private interface BucketStat {
        double of(SampleBucketWithUnfilteredStats bucket, boolean workDay);
    }

    /** Evaluates {@code stat} for the bucket of every timestamp, or NaN if it is too small. */
    private GraphData predict(LongArrayView predictTimes, BucketStat stat) {
        GraphDataArrayList result = new GraphDataArrayList();

        for (int i = 0; i < predictTimes.length(); i++) {
            long pointTime = predictTimes.at(i);
            SampleBucketWithUnfilteredStats bucket = buckets[calendar.timeBucketOf(pointTime)];
            double predicted = hasEnoughPoints(bucket)
                    ? stat.of(bucket, isWorkDay(pointTime))
                    : Double.NaN;
            result.add(pointTime, predicted);
        }

        return result.buildGraphData();
    }

    /** Returns true if {@code bucket} holds at least {@code minBucketPoints} samples. */
    private boolean hasEnoughPoints(SampleBucketWithUnfilteredStats bucket) {
        return bucket.getSize() >= minBucketPoints;
    }

    /** Returns true if {@code pointTime} falls into the work-day bucket of the calendar. */
    private boolean isWorkDay(long pointTime) {
        return calendar.dayBucketOf(pointTime) == WORK_DAY_BUCKET;
    }

    /** Filtered mean on work days, unfiltered mean otherwise. */
    private static double meanOf(SampleBucketWithUnfilteredStats bucket, boolean workDay) {
        return workDay ? bucket.getMean() : bucket.getUnfilteredMean();
    }

    /** Filtered variance on work days, unfiltered variance otherwise. */
    private static double varianceOf(SampleBucketWithUnfilteredStats bucket, boolean workDay) {
        return workDay ? bucket.getVariance() : bucket.getUnfilteredVariance();
    }
}
