package ru.yandex.market.clickphite.monitoring.kronos;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import org.apache.commons.math3.stat.StatUtils;

import java.util.Arrays;
import java.util.List;

/**
 * Static helper that derives a {@link KronosModel} (per-index centre values and upper/lower
 * threshold bands) from a set of historical series.
 *
 * @author Dmitry Andreev <a href="mailto:AndreevDm@yandex-team.ru"></a>
 * @date 28/10/15
 */
public class KronosTrainer {

    /** Minimum number of cleaned points a window must contain to produce thresholds. */
    private static final int MIN_POINTS_COUNT = 10;

    // Sigma multipliers for the individual bounds.
    // NOTE(review): "RED"/"YELLOW" presumably map to critical/warning levels - confirm with consumers.
    private static final double RED_TOP_MAGIC = 21;
    private static final double YELLOW_TOP_MAGIC = 10;
    private static final double YELLOW_BOTTOM_MAGIC = 3;
    private static final double RED_BOTTOM_MAGIC = 7;
    /** Lower limit for the per-index band multiplier. */
    private static final double MIN_MULT = 0.1;

    /** A window is accepted by {@link #cleanAnomaly} once {@code max < min * MAX_VALUES_RATIO}. */
    private static final double MAX_VALUES_RATIO = 1.5;

    private KronosTrainer() {
    }

    /**
     * Same as {@link #estimateConfidenceIntervals(List, int, double)} with {@code mult = 1.0}.
     *
     * @param data       training series; see the three-argument overload
     * @param smoothness number of neighbouring indices on each side pooled into a window
     * @return the estimated model
     */
    public static KronosModel estimateConfidenceIntervals(List<double[]> data, int smoothness) {
        return estimateConfidenceIntervals(data, smoothness, 1.0);
    }

    /**
     * Estimates, for every index of the first series, a centre value and four bounds from the
     * points of all series that fall into the window {@code [i - smoothness, i + smoothness]}.
     *
     * <p>NaN points are skipped and each window is trimmed by {@link #cleanAnomaly} first. If fewer
     * than {@link #MIN_POINTS_COUNT} points remain, every output value for that index is NaN.
     *
     * @param data       training series; the length of the first one defines the number of output
     *                   points. Must contain at least one series, otherwise {@code data.get(0)}
     *                   fails (typically with {@link IndexOutOfBoundsException}).
     * @param smoothness number of neighbouring indices on each side pooled into a window
     * @param mult       base multiplier applied to the band width
     * @return a model built from the arrays (extTop, top, average, mean, bottom, extBottom)
     */
    public static KronosModel estimateConfidenceIntervals(List<double[]> data, int smoothness, double mult) {
        int pointCount = data.get(0).length;

        double[] extTop = new double[pointCount];
        double[] top = new double[pointCount];
        double[] average = new double[pointCount];
        double[] mean = new double[pointCount];
        double[] bottom = new double[pointCount];
        double[] extBottom = new double[pointCount];

        // Mean over every non-NaN point of every series (no anomaly trimming here).
        double commonMean = StatUtils.mean(getPoints(data));

        for (int i = 0; i < pointCount; i++) {
            double[] points = getPoints(data, i, smoothness);
            if (points.length < MIN_POINTS_COUNT) {
                // Not enough data: mark every output, including mean, as unknown.
                extTop[i] = Double.NaN;
                top[i] = Double.NaN;
                average[i] = Double.NaN;
                mean[i] = Double.NaN;
                bottom[i] = Double.NaN;
                extBottom[i] = Double.NaN;
                continue;
            }
            // NOTE(review): both values are arithmetic means of the same points, computed by
            // different libraries - confirm whether one was meant to be a different statistic.
            average[i] = Arrays.stream(points).average().getAsDouble();
            mean[i] = StatUtils.mean(points);

            double middle = mean[i];

            double sigma = calcSigma(points, middle);
            // Positive when this window's mean is below the overall mean, negative when above,
            // so low-valued indices get wider bands and high-valued ones narrower bands.
            double meanDiff = Math.log10(commonMean / middle);

            double diffMult = Math.max(MIN_MULT, mult + meanDiff);

            extTop[i] = middle + RED_TOP_MAGIC * sigma * diffMult;
            top[i] = middle + YELLOW_TOP_MAGIC * sigma * diffMult;
            // Lower bounds never drop below 1% of the centre value.
            bottom[i] = Math.max(
                middle / 100,
                middle - YELLOW_BOTTOM_MAGIC * sigma * diffMult
            );
            extBottom[i] = Math.max(
                middle / 100,
                middle - RED_BOTTOM_MAGIC * sigma * diffMult
            );
        }
        return new KronosModel(extTop, top, average, mean, bottom, extBottom);
    }

    /**
     * Sample standard deviation (divisor {@code n - 1}) of {@code points} around {@code middle}.
     */
    private static double calcSigma(double[] points, double middle) {
        double sigma2 = 0;
        for (int j = 0; j < points.length; j++) {
            sigma2 += Math.pow(points[j] - middle, 2);
        }
        return Math.sqrt(sigma2 / (points.length - 1));
    }

    /**
     * Collects the non-NaN points of every series whose index lies within {@code smoothness} of
     * {@code idx} (clamped to the bounds of each series), then trims them with
     * {@link #cleanAnomaly}.
     */
    private static double[] getPoints(List<double[]> data, int idx, int smoothness) {
        DoubleList points = new DoubleArrayList(data.size() * 2);
        for (int i = 0; i < data.size(); i++) {
            double[] trainPoints = data.get(i);
            for (int j = Math.max(idx - smoothness, 0); j < Math.min(idx + 1 + smoothness, trainPoints.length); j++) {
                double point = trainPoints[j];
                if (!Double.isNaN(point)) {
                    points.add(point);
                }
            }
        }
        cleanAnomaly(points);
        return points.toDoubleArray();
    }

    /** Collects every non-NaN point of every series, without trimming. */
    private static double[] getPoints(List<double[]> data) {
        DoubleList points = new DoubleArrayList();
        for (int i = 0; i < data.size(); i++) {
            double[] trainPoints = data.get(i);
            for (int j = 0; j < trainPoints.length; j++) {
                double point = trainPoints[j];
                if (!Double.isNaN(point)) {
                    points.add(point);
                }
            }
        }
        return points.toDoubleArray();
    }

    /**
     * Repeatedly removes the extreme (max or min, whichever lies farther from the current
     * average) until {@code max < min * MAX_VALUES_RATIO} holds or no values are left.
     *
     * <p>The condition cannot hold while the minimum is zero or negative, so such input may be
     * trimmed down to an empty list; the loop stops there instead of reading element 0 of an
     * empty list.
     *
     * @param values list to trim in place
     */
    private static void cleanAnomaly(DoubleList values) {
        while (!values.isEmpty()) {
            double min = values.getDouble(0);
            double max = min;
            double sum = 0;
            for (int i = 0; i < values.size(); i++) {
                double value = values.getDouble(i);
                if (value < min) {
                    min = value;
                }
                if (value > max) {
                    max = value;
                }
                sum += value;
            }
            double average = sum / values.size();

            if (max < min * MAX_VALUES_RATIO) {
                return;
            }

            double toRemove = (max - average) > (average - min) ? max : min;
            // rem(double) removes the first occurrence by value; unlike remove(...) it cannot be
            // mistaken for index-based removal and needs no boxing.
            values.rem(toRemove);
        }
    }


}
