package ru.yandex.solomon.math.stat;


import java.util.Arrays;

import org.apache.commons.math3.transform.DftNormalization;
import org.apache.commons.math3.transform.FastFourierTransformer;
import org.apache.commons.math3.transform.TransformType;
import org.apache.commons.math3.util.FastMath;

import ru.yandex.solomon.model.timeseries.GraphData;
import ru.yandex.solomon.model.timeseries.SortedOrCheck;
import ru.yandex.solomon.util.collection.array.DoubleArrayView;
import ru.yandex.solomon.util.collection.array.LongArrayView;

/**
 * ASAP (Automatic Smoothing for Attention Prioritization) smoothing of a time series.
 *
 * <p>Each pass picks a trailing simple-moving-average (SMA) window. The chosen window minimizes the
 * roughness of the smoothed series (the standard deviation of its first differences). It must also
 * keep the kurtosis of the smoothed series at least as large as the kurtosis of the input.
 * Candidate windows are first taken from the peaks of the autocorrelation function. They are then
 * refined by a binary search over the remaining range. {@link #smooth(GraphData)} applies two such
 * passes.
 *
 * @see <a href="http://futuredata.stanford.edu/asap/">ASAP: Prioritizing Attention via Time Series Smoothing</a>
 * @see <a href="https://en.wikipedia.org/wiki/Kurtosis">Kurtosis</a>
 */
public class ASAP {
    /** Minimum autocorrelation for a local maximum of the ACF to be recorded as a peak. */
    private static final double ACF_THRESH = 0.2;
    /** Inputs whose kurtosis exceeds this value are returned unsmoothed. */
    private static final double KURT_THRESH = 90.0;
    /**
     * Divisor bounding the search. Windows and ACF lags go up to {@code n / MAX_WINDOW}. It is also
     * the minimum number of points a series must have to be smoothed at all.
     */
    private static final int MAX_WINDOW = 10;

    /**
     * Smooths {@code data} by applying two ASAP passes.
     *
     * <p>A pass returns its input unchanged when it decides not to smooth. The result may therefore
     * be {@code data} itself.
     *
     * @param data the series to smooth
     * @return the smoothed series
     */
    public static GraphData smooth(GraphData data) {
        return smoothIteration(smoothIteration(data));
    }

    /**
     * Runs one ASAP pass.
     *
     * <p>Leading and trailing NaN points are dropped. Interior NaN values are treated as 0. The
     * input is returned unchanged in any of these cases:
     * <ul>
     *   <li>it has fewer than {@link #MAX_WINDOW} points, before or after trimming;</li>
     *   <li>its kurtosis is NaN or exceeds {@link #KURT_THRESH};</li>
     *   <li>no window larger than 1 is selected.</li>
     * </ul>
     */
    private static GraphData smoothIteration(GraphData data) {
        if (data.length() < MAX_WINDOW) {
            return data;
        }

        var work = new Data(data);
        work.dropLeadingNans();
        work.dropTrailingNans();
        if (work.timestamps.length() < MAX_WINDOW) {
            return data;
        }

        var source = work.prepare();
        double originalKurt = source.kurtosis();

        // Only search for time series whose initial kurtosis is defined and not too large
        if (originalKurt > KURT_THRESH || Double.isNaN(originalKurt)) {
            return data;
        }

        Opts opts = new Opts(source);
        Source best = findWindow(source, opts);
        if (opts.window == 1) {
            return data;
        }

        return work.prepareResult(best, opts.window);
    }

    /**
     * Evaluates the ACF peaks as candidate windows, from the largest lag down.
     *
     * <p>A candidate is feasible when its SMA keeps kurtosis at least {@code opts.kurtosis}. Among
     * feasible candidates the one with the smallest roughness is recorded in {@code opts}. Each
     * feasible candidate also raises {@code opts.lowerBoundWindow} and records its peak position in
     * {@code opts.largestFeasibleIdx}.
     */
    private static void searchPeriodic(Source source, Autocorrelation acf, Opts opts) {
        for (int i = acf.peakIndexes.length - 1; i >= 0; i--) {
            int peakIndex = acf.peakIndexes[i];
            if (peakIndex == opts.window) {
                continue;
            } else if (peakIndex < opts.lowerBoundWindow || peakIndex == 1) {
                // Peaks are ascending, so every remaining peak is below the bound as well
                break;
            } else if (roughnessGreaterThanOpt(peakIndex, acf, opts)) {
                continue;
            }

            Source smoothed = source.sma(peakIndex);
            double kurtosis = smoothed.kurtosis();
            if (Double.compare(kurtosis, opts.kurtosis) < 0) {
                continue;
            }

            double roughness = smoothed.roughness();
            if (Double.compare(roughness, opts.roughness) < 0) {
                opts.roughness = roughness;
                opts.window = peakIndex;
            }

            opts.largestFeasibleIdx = Math.max(opts.largestFeasibleIdx, i);
            opts.lowerBoundWindow = updateLB(acf, opts.lowerBoundWindow, peakIndex);
        }
    }

    /**
     * Chooses the SMA window for {@code source}, updating {@code opts} with the best window found.
     *
     * <p>The periodic candidates are tried first. The result then narrows the binary-search range
     * {@code [head, tail]}, whose default is {@code [lowerBoundWindow, n / MAX_WINDOW]}.
     *
     * @return the SMA of {@code source} for the finally selected {@code opts.window}
     */
    private static Source findWindow(Source source, Opts opts) {
        Autocorrelation candidates = new Autocorrelation(source);
        if (opts.window > 1) {
            opts.lowerBoundWindow = updateLB(candidates, opts.lowerBoundWindow, opts.window);
        }
        searchPeriodic(source, candidates, opts);
        int head = opts.lowerBoundWindow;
        int tail = source.values.length / MAX_WINDOW;
        // NOTE(review): "> 0" ignores a feasible peak at position 0, and "< length - 2" ignores the
        // next peak when the feasible one is second-to-last. Both bounds are kept as-is;
        // confirm against the reference ASAP implementation before changing them.
        if (opts.largestFeasibleIdx > 0) {
            if (opts.largestFeasibleIdx < candidates.peakIndexes.length - 2) {
                tail = candidates.peakIndexes[opts.largestFeasibleIdx + 1];
            }
            head = Math.max(head, candidates.peakIndexes[opts.largestFeasibleIdx] + 1);
        }
        return binarySearch(source, head, tail, opts);
    }

    /**
     * Checks whether window {@code w} is estimated to give a rougher series than the current
     * optimum {@code opts.window}.
     *
     * <p>The estimate compares {@code sqrt(1 - acf[w]) / w} with the same quantity for the current
     * window; the products are cross-multiplied to avoid division. It returns {@code false} when
     * the current window lies outside the computed correlations.
     */
    private static boolean roughnessGreaterThanOpt(int w, Autocorrelation acf, Opts opts) {
        if (opts.window >= acf.correlations.length) {
            return false;
        }

        return Math.sqrt(1 - acf.correlations[w]) * opts.window > Math.sqrt(1 - acf.correlations[opts.window]) * w;
    }

    /**
     * Computes the updated lower bound of the window sizes to search.
     *
     * <p>The bound is {@code max(lowerBoundWindow, round(w * sqrt((maxACF - 1) / (acf[w] - 1))))}.
     * The current bound is returned unchanged when {@code w} lies outside the computed correlations.
     */
    private static int updateLB(Autocorrelation acf, int lowerBoundWindow, int w) {
        if (acf.correlations.length <= w) {
            return lowerBoundWindow;
        }

        return (int) Math.round(Math.max(w * Math.sqrt((acf.maxACF - 1) / (acf.correlations[w] - 1)), lowerBoundWindow));
    }

    /**
     * Binary-searches window sizes in {@code [head, tail]}.
     *
     * <p>When a window keeps kurtosis at least {@code opts.kurtosis}, it is recorded if it lowers
     * the roughness, and the search continues with larger windows. Otherwise it continues with
     * smaller windows.
     *
     * @return the SMA for the best window found, or for the incoming {@code opts.window} if none improved
     */
    private static Source binarySearch(Source source, int head, int tail, Opts opts) {
        Source result = source.sma(opts.window);
        while (head <= tail) {
            int mid = (head + tail) >>> 1;
            Source smoothed = source.sma(mid);
            double kurtosis = smoothed.kurtosis();
            if (Double.compare(kurtosis, opts.kurtosis) >= 0) {
                /* Feasible: remember it if smoother, then search the second half */
                double roughness = smoothed.roughness();
                if (Double.compare(roughness, opts.roughness) < 0) {
                    result = smoothed;
                    opts.roughness = roughness;
                    opts.window = mid;
                }
                head = mid + 1;
            } else {
                /* Infeasible: search the first half */
                tail = mid - 1;
            }
        }
        return result;
    }

    /** Sample values together with their mean; the kurtosis is computed lazily and cached. */
    private static class Source {
        private final double[] values;
        private final double mean;
        // Cached kurtosis, valid once kurtosisComputed is set. A double field defaults to 0.0, so
        // it cannot serve as its own NaN sentinel. The computed value may also legitimately be NaN.
        private double kurtosis;
        private boolean kurtosisComputed;

        Source(double[] values, double mean) {
            this.values = values;
            this.mean = mean;
        }

        double getMean() {
            return mean;
        }

        /**
         * Returns the (non-excess) kurtosis {@code n * sum(d^4) / (sum(d^2))^2}, where
         * {@code d = value - mean}. The result is NaN when all values are equal.
         */
        double kurtosis() {
            if (!kurtosisComputed) {
                kurtosis = computeKurtosis();
                kurtosisComputed = true;
            }
            return kurtosis;
        }

        private double computeKurtosis() {
            double m = getMean();
            double fourth = 0;
            double second = 0;
            for (double value : values) {
                double centered = value - m;
                double squared = centered * centered;
                fourth += squared * squared;
                second += squared;
            }
            return values.length * fourth / (second * second);
        }

        /**
         * Returns the population standard deviation of the first differences of the values.
         * The result is 0 when there are fewer than two values.
         */
        double roughness() {
            int diffCount = values.length - 1;
            if (diffCount <= 0) {
                return 0;
            }

            double[] diffs = new double[diffCount];
            double sum = 0;
            for (int i = 0; i < diffCount; i++) {
                diffs[i] = values[i + 1] - values[i];
                sum += diffs[i];
            }

            double diffMean = sum / diffCount;
            double squares = 0;
            for (double d : diffs) {
                double centered = d - diffMean;
                squares += centered * centered;
            }
            // Divide by the number of differences, not the number of samples
            return Math.sqrt(squares / diffCount);
        }

        /**
         * Returns the values minus their mean, zero-padded to a power-of-two length of at least
         * {@code 2 * n}. This keeps the circular FFT correlation free of wrap-around for every
         * lag below {@code n}.
         */
        double[] getVarianceArray() {
            int n = values.length;
            double m = getMean();
            final int padding = (int) FastMath.pow(2, 32 - Integer.numberOfLeadingZeros(2 * n - 1));
            double[] result = new double[padding];
            for (int i = 0; i < n; i++) {
                result[i] = values[i] - m;
            }
            return result;
        }

        /**
         * Returns the trailing simple moving average over {@code points} samples.
         *
         * <p>Element {@code k} is the mean of {@code values[k .. k + points - 1]}. The result has
         * {@code values.length - points + 1} elements, one per full window including the last. This
         * source is returned as-is when {@code points <= 1} or when fewer than {@code points}
         * values are available.
         */
        Source sma(int points) {
            if (points <= 1 || values.length < points) {
                return this;
            }

            double windowSum = 0;
            for (int i = 0; i < points; i++) {
                windowSum += values[i];
            }

            double[] result = new double[values.length - points + 1];
            result[0] = windowSum / points;
            double totalSum = result[0];
            for (int end = points; end < values.length; end++) {
                // Slide the window one step: add the new last sample, drop the old first one
                windowSum += values[end] - values[end - points];
                double avg = windowSum / points;
                result[end - points + 1] = avg;
                totalSum += avg;
            }

            return new Source(result, totalSum / result.length);
        }
    }

    /** Normalized autocorrelation of a {@link Source} up to lag {@code n / MAX_WINDOW}, plus its peaks. */
    private static class Autocorrelation {
        /**
         * {@code correlations[i]} is the autocorrelation at lag {@code i}, normalized by lag 0.
         * Index 0 is left at 0.0.
         */
        final double[] correlations;
        /** Largest correlation among the recorded peaks, or 0 if there are none. */
        final double maxACF;
        /** Lags (> 1) of local ACF maxima above {@link #ACF_THRESH}, in ascending order. */
        final int[] peakIndexes;

        public Autocorrelation(Source source) {
            int maxLag = source.values.length / MAX_WINDOW;

            double[] padded = source.getVarianceArray();
            double[][] dataRI = new double[][] {padded, new double[padded.length]};

            /* F_R(f) = FFT(X) */
            FastFourierTransformer.transformInPlace(dataRI, DftNormalization.STANDARD, TransformType.FORWARD);
            /* S(f) = F_R(f)F_R*(f) */
            for (int i = 0; i < dataRI[0].length; i++) {
                dataRI[0][i] = (dataRI[0][i] * dataRI[0][i]) + (dataRI[1][i] * dataRI[1][i]);
                dataRI[1][i] = 0;
            }
            /* R(t) = IFFT(S(f)) */
            FastFourierTransformer.transformInPlace(dataRI, DftNormalization.STANDARD, TransformType.INVERSE);

            // NOTE(review): the true lag-0 autocorrelation is 1, but index 0 stays 0.0 here. The
            // initial direction of the peak scan below depends on this value, so changing it
            // changes which peaks are found.
            correlations = new double[maxLag];
            for (int i = 1; i < maxLag; i++) {
                correlations[i] = dataRI[0][i] / dataRI[0][0];
            }

            int peakSize = 0;
            int[] peaks = new int[correlations.length];
            int max = 1;
            double maxACF = 0;
            if (correlations.length > 1) {
                boolean positive = (correlations[1] > correlations[0]);
                for (int i = 2; i < correlations.length; i++) {
                    if (!positive && correlations[i] > correlations[i - 1]) {
                        // Start of a rise: track a new candidate maximum
                        max = i;
                        positive = true;
                    } else if (positive && correlations[i] > correlations[max]) {
                        max = i;
                    } else if (positive && correlations[i] < correlations[i - 1]) {
                        // The rise ended: keep the maximum if it is strong enough
                        if (max > 1 && correlations[max] > ACF_THRESH) {
                            peaks[peakSize++] = max;
                            if (correlations[max] > maxACF) {
                                maxACF = correlations[max];
                            }
                        }
                        positive = false;
                    }
                }
            }
            this.maxACF = maxACF;
            this.peakIndexes = Arrays.copyOf(peaks, peakSize);
        }
    }

    /** Mutable state of the window search: the best result so far and the search bounds. */
    private static class Opts {
        /** Roughness of the best window so far (initially of the unsmoothed source). */
        double roughness;
        /** Kurtosis of the unsmoothed source; smoothed series must not fall below it. */
        double kurtosis;
        int lowerBoundWindow = 1;
        /** Best window so far; 1 means "no smoothing". */
        int window = 1;
        /** Largest position in {@code peakIndexes} of a feasible peak, or -1 if none. */
        int largestFeasibleIdx = -1;

        Opts(Source source) {
            this.roughness = source.roughness();
            this.kurtosis = source.kurtosis();
        }
    }

    /** Working copy of the input series: timestamps and values, trimmed of leading/trailing NaNs. */
    private static class Data {
        private LongArrayView timestamps;
        private DoubleArrayView values;

        public Data(GraphData graphData) {
            this.timestamps = graphData.getTimestamps();
            this.values = graphData.getValues();
        }

        void dropLeadingNans() {
            int index = 0;
            while (index < values.length() && Double.isNaN(values.at(index))) {
                index++;
            }

            timestamps = timestamps.slice(index, timestamps.length());
            values = values.slice(index, values.length());
        }

        void dropTrailingNans() {
            int index = values.length() - 1;
            while (index >= 0 && Double.isNaN(values.at(index))) {
                index--;
            }

            timestamps = timestamps.slice(0, index + 1);
            values = values.slice(0, index + 1);
        }

        /** Copies the values into a {@link Source}, replacing NaN with 0 (the mean includes those zeros). */
        Source prepare() {
            double sum = 0;
            double[] data = values.copyToArray();
            for (int index = 0; index < data.length; index++) {
                var value = data[index];
                if (Double.isNaN(value)) {
                    data[index] = 0;
                } else {
                    sum += value;
                }
            }
            return new Source(data, sum / data.length);
        }

        /**
         * Builds the result series from an SMA produced with {@code window}.
         *
         * <p>Element {@code k} of the SMA averages samples {@code k .. k + window - 1}. It is
         * labelled with the timestamp of the window's last sample, so the first
         * {@code window - 1} timestamps are dropped.
         */
        GraphData prepareResult(Source best, int window) {
            timestamps = timestamps.slice(window - 1, timestamps.length());
            return new GraphData(timestamps.copyOrArray(), best.values, SortedOrCheck.SORTED_UNIQUE);
        }
    }
}
