package ru.yandex.solomon.math.stat.trends;

import java.util.Arrays;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

import ru.yandex.solomon.model.timeseries.GraphData;

/**
 * See https://stackoverflow.com/questions/17592139/trend-lines-regression-curve-fitting-java-library
 *
 * @author Oleg Baryshnikov
 */
@ParametersAreNonnullByDefault
public abstract class OLSTrendLine implements TrendLine {

    /**
     * Fitted regression coefficients as a single-column matrix (one row per element returned by
     * {@link #xVector(long)}), or {@code null} if the model could not be fitted to the source data.
     */
    @Nullable
    private final RealMatrix coef;

    /**
     * Builds the vector of predictors for a single x value.
     *
     * <p>The same function is used both for fitting and for prediction, so it must always return
     * vectors of the same length. No intercept is added automatically (see
     * {@code setNoIntercept(true)} below): implementations that want a constant term must include
     * it in the returned vector themselves.
     *
     * <p>NOTE: this method is invoked from the {@link OLSTrendLine} constructor, i.e. before any
     * subclass fields are initialized; implementations must not depend on subclass instance state.
     *
     * @param x x value taken from the graph timestamps (presumably epoch millis — confirm with GraphData)
     * @return predictor vector for {@code x}
     */
    protected abstract double[] xVector(long x);

    /**
     * Whether the model predicts {@code ln(y)} instead of {@code y}.
     *
     * <p>When {@code true}, all source y values must be strictly positive; otherwise the fit is
     * rejected and {@link #canPredict()} returns {@code false}.
     *
     * <p>NOTE: like {@link #xVector(long)}, this is invoked from the constructor.
     *
     * @return {@code true} to fit and predict in log space
     */
    protected abstract boolean logY();

    /**
     * Fits the trend line to the given data.
     *
     * <p>Fitting failures are not reported as exceptions. Examples are empty data, mismatched
     * timestamp/value counts, a singular design matrix, and non-finite coefficients. In each case
     * the trend line is left in the "cannot predict" state.
     *
     * @param graphData source points; timestamps are used as x and values as y
     */
    public OLSTrendLine(GraphData graphData) {
        this.coef = constructCoef(graphData);
    }

    /**
     * Runs ordinary least squares over the points of {@code graphData}.
     *
     * @return coefficients as a column matrix, or {@code null} if no usable fit exists
     */
    @Nullable
    private RealMatrix constructCoef(GraphData graphData) {
        try {
            long[] xArray = graphData.getTimestamps().array;
            double[] yArray = graphData.getValues().array;
            if (xArray.length != yArray.length) {
                // every x must have exactly one y; a mismatch cannot be fitted
                return null;
            }

            double[][] xData = new double[xArray.length][];
            for (int i = 0; i < xArray.length; i++) {
                // the implementation determines how to produce a vector of predictors from a single x
                xData[i] = xVector(xArray[i]);
            }
            if (logY()) {
                // in some models we are predicting ln y, so we replace each y with ln y
                // (on a copy, so the caller's data is not modified)
                yArray = Arrays.copyOf(yArray, yArray.length);
                for (int i = 0; i < yArray.length; i++) {
                    yArray[i] = Math.log(yArray[i]);
                }
            }

            OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
            // let the implementation include a constant in xVector if desired
            ols.setNoIntercept(true);
            ols.newSampleData(yArray, xData);
            double[] params = ols.estimateRegressionParameters();

            // NaN/Infinite y values (e.g. ln of a non-positive y) do not make OLS throw;
            // they silently yield non-finite coefficients, which must not count as a valid fit
            for (double p : params) {
                if (!Double.isFinite(p)) {
                    return null;
                }
            }
            return MatrixUtils.createColumnRealMatrix(params);
        } catch (Exception ignored) {
            // deliberate best-effort: any failure to fit (no data, singular matrix,
            // dimension mismatch, failing xVector) means "cannot predict", not an error
            return null;
        }
    }

    /**
     * @return {@code true} if the model was successfully fitted to the source data
     */
    @Override
    public boolean canPredict() {
        return coef != null;
    }

    /**
     * Predicts y for the given x using the fitted coefficients.
     *
     * @param x x value (same domain as the source timestamps)
     * @return predicted y, or {@link Double#NaN} if the model could not be fitted
     */
    @Override
    public double predict(long x) {
        if (coef == null) {
            // cannot create trend line by source data, return NaNs only
            return Double.NaN;
        }

        double[] v = xVector(x);
        // v^T * coef: a 1x1 result, i.e. the dot product of predictors and coefficients
        double yhat = coef.preMultiply(v)[0];
        if (logY()) {
            // if we predicted ln y, we still need to get y
            yhat = Math.exp(yhat);
        }
        return yhat;
    }
}
