package ru.yandex.tools.benchmark;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.function.Supplier;

/**
 * Runs a list of benchmarks and reports the best observed timing of each.
 *
 * <p>Each benchmark is supplied as a {@link Supplier} of {@link Callable}.
 * {@link #benchmark()} first warms the code up, then runs measurement
 * cycles (via {@code BenchmarkCycle}) until the per-benchmark minimum
 * timings stop improving, and prints a comparison report to
 * {@code options.out()} when that stream is non-null.
 *
 * <p>NOTE(review): {@code @SuppressWarnings("try")} presumably silences the
 * javac lint about {@link AutoCloseable#close()} being declared to throw
 * {@link Exception} (and thus possibly {@link InterruptedException}).
 */
@SuppressWarnings("try")
public class Benchmark implements AutoCloseable {
    // NOTE(review): not referenced anywhere in this class.
    private static final long JIT_INTERVAL = 100L;
    private static final String OF = " of ";
    // Characters a "%.3f" conversion adds after the integer part: ".ddd".
    private static final int FRACTION_WIDTH = 4;

    private final BenchmarkOptions options;
    private final List<? extends Supplier<? extends Callable<?>>> benchs;

    /**
     * @param options run configuration (iterations, repeats, output stream)
     * @param benchs factories producing the benchmarked callables; suppliers
     *     that also implement {@link AutoCloseable} are closed by
     *     {@link #close()}
     */
    public Benchmark(
        final BenchmarkOptions options,
        final List<? extends Supplier<? extends Callable<?>>> benchs)
    {
        this.options = options;
        this.benchs = benchs;
    }

    /**
     * Closes every benchmark supplier that implements {@link AutoCloseable}.
     *
     * <p>All suppliers are attempted even if one fails; the first exception
     * is rethrown with later ones attached as suppressed exceptions.
     *
     * @throws Exception the first exception thrown by a supplier's close
     */
    @Override
    public void close() throws Exception {
        Exception e = null;
        for (Object bench: benchs) {
            if (bench instanceof AutoCloseable) {
                try {
                    ((AutoCloseable) bench).close();
                } catch (Exception ex) {
                    if (e == null) {
                        e = ex;
                    } else {
                        e.addSuppressed(ex);
                    }
                }
            }
        }
        if (e != null) {
            throw e;
        }
    }

    /**
     * Materializes one callable per supplier, in supplier order.
     *
     * @return a new mutable list of freshly obtained callables
     */
    private List<Callable<?>> benchs() {
        List<Callable<?>> benchs = new ArrayList<>(this.benchs.size());
        for (Supplier<? extends Callable<?>> bench: this.benchs) {
            benchs.add(bench.get());
        }
        return benchs;
    }

    /**
     * Warms the benchmarks up in three phases.
     *
     * <ol>
     *   <li>Calls every callable directly {@code warmupIterations} times.</li>
     *   <li>Runs {@code warmupCycles} reduced cycles with no GC sleep.</li>
     *   <li>Runs reduced cycles with the configured GC sleep until
     *       {@code warmupCycles} cycles have failed to improve any timing,
     *       then prints the best warmup timings.</li>
     * </ol>
     */
    private void warmup() throws Exception {
        // Phase 1: plain invocations.
        List<Callable<?>> benchs = benchs();
        for (int i = 0; i < options.warmupIterations(); ++i) {
            for (Callable<?> bench: benchs) {
                bench.call();
            }
            if (options.shouldPrint()) {
                options.out().println(
                    i + 1 + OF + options.warmupIterations()
                    + " warmup operations completed");
            }
        }
        if (options.out() != null) {
            options.out().println("Benchmarks warmup completed");
        }
        // Phase 2: silent, shortened cycles (a quarter of the iterations).
        BenchmarkOptions warmupOptions = new BenchmarkOptions()
            .out(null)
            .gcSleep(0L)
            .repeats(options.warmupRepeats())
            .benchmarkIterations(options.benchmarkIterations() >> 2);
        BenchmarkCycle warmupCycle =
            new BenchmarkCycle(warmupOptions, this.benchs);
        for (int i = 0; i < options.warmupCycles(); ++i) {
            warmupCycle.benchmark();
            if (options.shouldPrint()) {
                options.out().println(
                    i + 1 + OF + options.warmupCycles()
                    + " warmup cycles completed");
            }
        }
        // Phase 3: same shortened cycles but with the real GC sleep.
        warmupCycle = new BenchmarkCycle(
            warmupOptions.gcSleep(options.gcSleep()),
            this.benchs);
        int cycles = options.warmupCycles();
        long[] timings = new long[benchs.size()];
        Arrays.fill(timings, Long.MAX_VALUE);
        for (int i = 0; i < cycles; ++i) {
            // A cycle that improves any timing does not count towards the
            // budget, so the loop ends after `warmupCycles` non-improving
            // cycles in total.
            if (mergeTimings(timings, warmupCycle.benchmark())) {
                ++cycles;
            }
            if (options.shouldPrint()) {
                options.out().println(
                    i + 1 + OF + cycles
                    + " second warmup cycles completed");
            }
        }
        if (options.out() != null) {
            options.out().println("Warmup results:");
            for (int i = 0; i < benchs.size(); ++i) {
                // Report to the configured stream, like every other message.
                options.out().println(
                    benchs.get(i) + " taken "
                    + String.format(
                        "%.3f ns per iteration",
                        ((double) timings[i]) / options.warmupRepeats()));
            }
        }
    }

    /**
     * Stores the element-wise minimum of {@code dst} and {@code src} into
     * {@code dst}.
     *
     * @param dst running best timings, updated in place
     * @param src timings of the latest cycle; must have at least
     *     {@code dst.length} elements
     * @return {@code true} if at least one element of {@code dst} decreased
     */
    public static boolean mergeTimings(final long[] dst, final long[] src) {
        boolean timingImproved = false;
        for (int i = 0; i < dst.length; ++i) {
            if (dst[i] > src[i]) {
                timingImproved = true;
                dst[i] = src[i];
            }
        }
        return timingImproved;
    }

    /**
     * Warms up, then runs measurement cycles until the best timings settle.
     *
     * <p>Cycles that improve any timing do not count towards the budget, so
     * measurement stops after {@code benchs.size()} non-improving cycles in
     * total (always at least one cycle).
     *
     * @return best (minimum) timing per benchmark, in supplier order, as
     *     returned by {@code BenchmarkCycle#benchmark()}
     * @throws Exception anything thrown by the benchmarked callables
     */
    public long[] benchmark() throws Exception {
        warmup();
        if (options.out() != null) {
            options.out().println("Benchmark warmup completed");
        }
        long[] timings = new long[benchs.size()];
        Arrays.fill(timings, Long.MAX_VALUE);
        BenchmarkCycle cycle = new BenchmarkCycle(options, benchs);
        // NOTE(review): the cycle budget is the number of benchmarks, not a
        // BenchmarkOptions value as in warmup() -- confirm this is intended.
        int cyclesLeft = benchs.size();
        do {
            long[] cycleTimings = cycle.benchmark();
            if (mergeTimings(timings, cycleTimings)) {
                ++cyclesLeft;
            }
        } while (--cyclesLeft > 0);
        benchmarkCompleted(benchs(), timings);
        return timings;
    }

    /**
     * Prints a report sorted from fastest to slowest, with each timing also
     * expressed as a percentage of the slowest one.
     *
     * <p>Does nothing when {@code options.out()} is {@code null} or there are
     * no benchmarks.
     *
     * @param benchs callables used only for their {@code toString()} names
     * @param timings best timing per benchmark, aligned with {@code benchs}
     */
    private void benchmarkCompleted(
        final List<Callable<?>> benchs,
        final long[] timings)
    {
        if (options.out() == null || benchs.isEmpty()) {
            return;
        }
        // Width must be at least 1: "%-0s" is an invalid format specifier.
        int maxLength = 1;
        for (Callable<?> bench: benchs) {
            maxLength = Math.max(maxLength, bench.toString().length());
        }
        List<Result> results = new ArrayList<>(benchs.size());
        for (int i = 0; i < benchs.size(); ++i) {
            results.add(new Result(benchs.get(i), timings[i]));
        }
        Collections.sort(results);
        Result worst = results.get(results.size() - 1);
        // A Formatter width covers the whole field, so the ".ddd" fraction
        // must be counted too for the timing column to right-align.
        int timeWidth =
            Long.toString(worst.timing() / options.repeats()).length()
            + FRACTION_WIDTH;
        String format =
            "%-" + maxLength + "s taken %" + timeWidth
            + ".3f ns for iteration which is %3.2f%% of worst"
            + " result";
        options.out().println("Benchmark results:");
        final int percents = 100;
        for (Result result: results) {
            double timing = result.timing();
            options.out().println(
                String.format(
                    format,
                    result.bench(),
                    timing / options.repeats(),
                    timing * percents / worst.timing()));
        }
    }

    /**
     * A benchmark paired with its best timing.
     *
     * <p>Ordering is by timing, while {@code equals}/{@code hashCode} use the
     * benchmark only, so the natural ordering is inconsistent with equals.
     * That is harmless for its only use here, {@link Collections#sort}.
     */
    private static class Result implements Comparable<Result> {
        private final Callable<?> bench;
        private final long timing;

        Result(final Callable<?> bench, final long timing) {
            this.bench = bench;
            this.timing = timing;
        }

        public Callable<?> bench() {
            return bench;
        }

        public long timing() {
            return timing;
        }

        @Override
        public int compareTo(final Result other) {
            return Long.compare(timing, other.timing);
        }

        @Override
        public int hashCode() {
            return bench.hashCode();
        }

        @Override
        public boolean equals(final Object o) {
            return o instanceof Result && bench.equals(((Result) o).bench);
        }
    }
}

