package ru.yandex.solomon.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import com.google.common.collect.Comparators;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

import ru.yandex.misc.random.Random2;
import ru.yandex.solomon.util.text.TextWithNumbersComparator;


/**
 * Results from solomon-dev-myt-00.search.yandex.net
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Benchmark                         Mode  Cnt    Score    Error  Units
 * TopKBenchmark.least              thrpt   10  237.332 ± 26.473  ops/s
 * TopKBenchmark.priorityQueue      thrpt   10   35.411 ±  0.970  ops/s
 * TopKBenchmark.quickSelect        thrpt   10  145.865 ±  5.409  ops/s
 * TopKBenchmark.quickSelectNoSort  thrpt   10  156.396 ±  7.649  ops/s
 * TopKBenchmark.sortLimit          thrpt   10   41.368 ±  0.154  ops/s
 *
 * @author Sergey Polovko
 */
@Fork(value = 1)
@Measurement(iterations = 10, time = 3, timeUnit = TimeUnit.SECONDS)
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
@State(Scope.Thread)
@Threads(1)
@BenchmarkMode(Mode.Throughput)
public class TopKBenchmark {

    private static final int MAX_COUNT = 1_000;

    private static final Random2 RND = new Random2(17);
    private static final Set<String> STRINGS = LongStream.range(1, 100_000)
        .mapToObj(i -> RND.nextString(5))
        .collect(Collectors.toSet());

    @Benchmark
    public List<String> sortLimit() {
        return STRINGS.stream()
            .sorted(TextWithNumbersComparator.instance)
            .limit(MAX_COUNT)
            .collect(Collectors.toList());
    }

    @Benchmark
    public List<String> least() {
        return STRINGS.stream()
            .collect(Comparators.least(MAX_COUNT, TextWithNumbersComparator.instance));
    }

    @Benchmark
    public List<String> priorityQueue() {
        PriorityQueue<String> pq = new PriorityQueue<>(MAX_COUNT, TextWithNumbersComparator.instance);
        for (String s : STRINGS) {
            pq.offer(s);
            if (pq.size() > MAX_COUNT) {
                pq.poll();
            }
        }
        ArrayList<String> result = new ArrayList<>(MAX_COUNT);
        while (!pq.isEmpty()) {
            result.add(pq.poll());
        }
        return result;
    }

    @Benchmark
    public String[] quickSelect() {
        String[] strings = STRINGS.toArray(new String[0]);
        quickSelect(strings, MAX_COUNT, TextWithNumbersComparator.instance);
        Arrays.sort(strings, 0, MAX_COUNT);
        return Arrays.copyOf(strings, MAX_COUNT);
    }

    @Benchmark
    public String[] quickSelectNoSort() {
        String[] strings = STRINGS.toArray(new String[0]);
        quickSelect(strings, MAX_COUNT, TextWithNumbersComparator.instance);
        return Arrays.copyOf(strings, MAX_COUNT);
    }

    /**
     * Iterative version of <a href="https://en.wikipedia.org/wiki/Quickselect">Quickselect</a>
     */
    private static <T> void quickSelect(T[] items, int k, Comparator<T> cmp) {
        if (items.length <= k) {
            throw new IllegalArgumentException("too small items");
        }

        int left = 0, right = items.length - 1;
        while (left < right) {
            T mid = items[(left + right) / 2];

            int i = left, j = right;
            while (i < j) {
                if (cmp.compare(items[i], mid) >= 0) {
                    T tmp = items[j];
                    items[j] = items[i];
                    items[i] = tmp;
                    j--;
                } else {
                    i++;
                }
            }

            if (cmp.compare(items[i], mid) > 0) {
                i--;
            }

            if (k <= i) {
                right = i;
            } else {
                left = i + 1;
            }
        }
    }

    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder()
            .include(TopKBenchmark.class.getName())
            .detectJvmArgs()
            .jvmArgs("-server", "-Xmx2g", "-Xms2g")
            .build();

        new Runner(opt).run();
    }
}
