package ru.yandex.solomon.math.stat;

import java.util.Arrays;

import javax.annotation.concurrent.NotThreadSafe;

import org.apache.commons.math3.util.FastMath;

/**
 * Caching quickselect algorithm extracted from {@link org.apache.commons.math3.stat.descriptive.rank.Percentile}
 * with minor changes.
 * <p>
 * Partitioning pivots can optionally be cached in a binary heap laid out in an array: the root pivot is
 * stored at index 0 and the pivots of the left/right partitions of node {@code n} at {@code 2n + 1} and
 * {@code 2n + 2}. Only the first {@value #MAX_CACHED_LEVELS} levels of that tree are cached.
 * <p>
 * Instances hold mutable cache state and must not be shared between threads.
 *
 * @author Vladimir Gordiychuk
 */
@NotThreadSafe
public class QuickSelect {
    /** Slices with at most this many elements are sorted directly instead of being partitioned further. */
    private static final int MIN_SELECT_SIZE = 15;

    /** Number of partitioning-tree levels whose pivots are cached (each level doubles the number of pivots). */
    private static final int MAX_CACHED_LEVELS = 10;

    /** Number of entries in the pivots heap, i.e. {@code 2^MAX_CACHED_LEVELS - 1}. */
    private static final int PIVOTS_HEAP_LENGTH = (0x1 << MAX_CACHED_LEVELS) - 1;

    /** Heap of cached pivot indices, {@code -1} meaning "not computed yet"; allocated lazily. */
    private int[] cachedPivots = null;

    /** {@code start} of the slice the cached pivots belong to; {@code -1} before the first cached call. */
    private int cachedStart = -1;

    /** {@code length} of the slice the cached pivots belong to; {@code -1} before the first cached call. */
    private int cachedLength = -1;

    /**
     * Invalidates the pivots cache, allocating the cache array on first use.
     * <p>
     * Must be called explicitly whenever the contents of the array used with cached
     * {@link #select} calls are modified, because such changes cannot be detected.
     */
    public void clearPivotsCache() {
        if (cachedPivots == null) {
            cachedPivots = new int[PIVOTS_HEAP_LENGTH];
        }
        Arrays.fill(cachedPivots, -1);
    }

    /**
     * Selects the value that would be at index {@code k} if the slice {@code work[start, start + length)}
     * were sorted in ascending order.
     * <p>
     * The slice is reordered in place (partitioned around pivots, small sub-slices sorted).
     * <p>
     * If {@code usePivotsCache} is {@code false} the cache is neither read nor updated. Otherwise pivots
     * found while partitioning are remembered, so later calls for other {@code k} on the same slice can
     * skip work already done. The cache is valid only for subsequent cached calls on the same
     * {@code work} array with the same {@code start} and {@code length}:
     * <ul>
     *     <li>a cached call with a different {@code start}/{@code length} discards the cache automatically;</li>
     *     <li>any other reordering or modification of the array, including a {@code select} call with
     *     {@code usePivotsCache == false} on the same slice, can make the cache stale and requires an
     *     explicit {@link #clearPivotsCache()} call.</li>
     * </ul>
     * NOTE(review): comparisons use {@code <}/{@code >}, which never order NaN, while {@link Arrays#sort}
     * places NaN last; callers should presumably remove NaN values first.
     *
     * @param work work array holding the slice; reordered in place
     * @param start index of the first element of the slice
     * @param length number of elements in the slice
     * @param usePivotsCache whether to read and populate the pivots cache
     * @param k index whose sorted-order value is requested; expected to lie in
     *          {@code [start, start + length)}, otherwise the result is not meaningful
     * @return the K<sup>th</sup> value of the sorted slice
     */
    public double select(final double[] work, final int start, final int length, boolean usePivotsCache, final int k) {
        if (usePivotsCache) {
            preparePivotsCache(start, length);
        }
        int begin = start;
        int end = start + length;
        // Index of the current partition in the pivots heap.
        int node = 0;
        while (end - begin > MIN_SELECT_SIZE) {
            final int pivot;

            if (usePivotsCache && node < cachedPivots.length &&
                cachedPivots[node] >= 0) {
                // the pivot has already been found in a previous call
                // and the array has already been partitioned around it
                pivot = cachedPivots[node];
            } else {
                // select a pivot and partition work array around it
                pivot = partition(work, begin, end, medianOf3(work, begin, end));
                if (usePivotsCache && node < cachedPivots.length) {
                    cachedPivots[node] = pivot;
                }
            }

            // Clamping keeps node bounded: once it reaches the heap length, deeper levels are
            // neither read from nor written to the cache. Without the cache node is unused.
            if (k == pivot) {
                // the pivot was exactly the element we wanted
                return work[k];
            } else if (k < pivot) {
                // the element is in the left partition
                end  = pivot;
                node = FastMath.min(2 * node + 1, usePivotsCache ? cachedPivots.length : end);
            } else {
                // the element is in the right partition
                begin = pivot + 1;
                node  = FastMath.min(2 * node + 2, usePivotsCache ? cachedPivots.length : end);
            }
        }
        // Small enough: finish by sorting the remaining sub-slice, which puts work[k] in place.
        Arrays.sort(work, begin, end);
        return work[k];
    }

    /**
     * Makes sure the pivots cache exists and belongs to the slice {@code [start, start + length)},
     * clearing it when it was built for a different slice.
     */
    private void preparePivotsCache(final int start, final int length) {
        if (cachedPivots == null || start != cachedStart || length != cachedLength) {
            clearPivotsCache();
            cachedStart = start;
            cachedLength = length;
        }
    }

    /**
     * Partitions an array slice around a pivot. On return every element before the returned
     * index is less than or equal to the pivot value and every element after it is greater than
     * or equal to it (elements equal to the pivot may end up on either side).
     *
     * @param work work array
     * @param begin index of the first element of the slice of work array
     * @param end index after the last element of the slice of work array
     * @param pivot initial index of the pivot
     * @return index of the pivot after partition
     */
    private int partition(final double[] work, final int begin, final int end, final int pivot) {
        // Park the pivot value aside; its slot is refilled with work[begin] until the final placement.
        final double value = work[pivot];
        work[pivot] = work[begin];

        int i = begin + 1;
        int j = end - 1;
        while (i < j) {
            while (i < j && work[j] > value) {
                --j;
            }
            while (i < j && work[i] < value) {
                ++i;
            }

            if (i < j) {
                final double tmp = work[i];
                work[i++] = work[j];
                work[j--] = tmp;
            }
        }

        // work[i] may belong to the right side; step back so the element moved to begin is <= value.
        if (i >= end || work[i] > value) {
            --i;
        }
        work[begin] = work[i];
        work[i] = value;
        return i;
    }

    /**
     * Selects a pivot index as the median of three.
     *
     * @param work data array
     * @param begin index of the first element of the slice
     * @param end index after the last element of the slice
     * @return the index of the median element chosen between the
     * first, the middle and the last element of the array slice
     */
    private int medianOf3(final double[] work, final int begin, final int end) {

        final int inclusiveEnd = end - 1;
        final int    middle    = begin + (inclusiveEnd - begin) / 2;
        final double wBegin    = work[begin];
        final double wMiddle   = work[middle];
        final double wEnd      = work[inclusiveEnd];

        if (wBegin < wMiddle) {
            if (wMiddle < wEnd) {
                return middle;
            } else {
                return (wBegin < wEnd) ? inclusiveEnd : begin;
            }
        } else {
            if (wBegin < wEnd) {
                return begin;
            } else {
                return (wMiddle < wEnd) ? inclusiveEnd : middle;
            }
        }
    }
}
