package ru.yandex.solomon.util.async;

import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link InFlightLimiter}.
 *
 * <p>Each operation handed to the limiter is an asynchronous delayed task built by
 * {@link #runAfter}. The tests inspect the limiter's in-flight ({@code getCurrent()}) and
 * queued ({@code getWaitingCount()}) counters while tasks are pending and after they finish.
 *
 * @author Sergey Polovko
 */
public class InFlightLimiterTest {

    /** Upper bound for waiting until the limiter's counters drop back to zero. */
    private static final long DRAIN_TIMEOUT_MILLIS = 5_000;

    /**
     * With a limit of 2, submitting three operations from a single thread must leave two
     * in flight and one waiting; once all of them have completed the limiter must be empty.
     */
    @Test
    public void singleThread() throws InterruptedException {
        var limiter = new InFlightLimiter(2);
        var latch = new CountDownLatch(3);

        limiter.run(() -> runAfter(1000, latch::countDown));
        limiter.run(() -> runAfter(2000, latch::countDown));
        limiter.run(() -> runAfter(3000, latch::countDown));

        Assert.assertEquals(2, limiter.getCurrent());
        Assert.assertEquals(1, limiter.getWaitingCount());

        // The queued third task cannot start before a slot is released (~1s) and then takes 3s,
        // so completion is expected around 4s; leave generous headroom for slow machines.
        Assert.assertTrue("operations did not complete in time", latch.await(10, TimeUnit.SECONDS));

        assertDrained(limiter);
    }

    /**
     * Several threads concurrently submit many short operations through a limiter of 2.
     * All operations must eventually run, and the limiter must end up empty.
     */
    @Test
    public void multiThreaded() throws Exception {
        final int threads = 4;
        final int iterations = 100;

        var limiter = new InFlightLimiter(2);
        var latch = new CountDownLatch(threads * iterations);

        var executor = Executors.newFixedThreadPool(threads);
        try {
            var futures = new ArrayList<Future<?>>(threads);
            for (int i = 0; i < threads; i++) {
                futures.add(executor.submit(() -> {
                    for (int j = 0; j < iterations; j++) {
                        // With a limit of 2 the total runtime grows roughly as
                        // threads * iterations * meanDelay / 2, so keep delays short to stay
                        // well within the timeouts below.
                        long delayMillis = ThreadLocalRandom.current().nextLong(1, 20);
                        limiter.run(() -> runAfter(delayMillis, latch::countDown));
                    }
                }));
            }

            // await all submitting threads done
            executor.shutdown();
            Assert.assertTrue(
                "submitting threads did not finish in time",
                executor.awaitTermination(15, TimeUnit.SECONDS));

            // rethrow exceptions if any
            for (Future<?> f : futures) {
                f.get();
            }
        } finally {
            // No-op after a clean termination; interrupts stuck workers if something above failed.
            executor.shutdownNow();
        }

        // await all async operations
        Assert.assertTrue("operations did not complete in time", latch.await(15, TimeUnit.SECONDS));

        assertDrained(limiter);
    }

    /**
     * Asserts that the limiter has no in-flight and no waiting operations.
     * Polls for up to {@link #DRAIN_TIMEOUT_MILLIS} before asserting.
     *
     * <p>Polling is needed because each task's latch countdown runs inside the task, i.e. before
     * the future returned by {@link #runAfter} completes. The limiter (presumably) releases its
     * slot only when it observes that future completing, so its counters may briefly lag
     * behind the latch.
     */
    private static void assertDrained(InFlightLimiter limiter) throws InterruptedException {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(DRAIN_TIMEOUT_MILLIS);
        while ((limiter.getCurrent() != 0 || limiter.getWaitingCount() != 0)
                && System.nanoTime() - deadline < 0) {
            Thread.sleep(10);
        }
        Assert.assertEquals(0, limiter.getCurrent());
        Assert.assertEquals(0, limiter.getWaitingCount());
    }

    /**
     * Returns a future that runs {@code fn} after {@code delayMillis} milliseconds.
     * The task is scheduled through {@link CompletableFuture#delayedExecutor(long, TimeUnit)},
     * and the future completes once {@code fn} has returned.
     */
    private static CompletableFuture<Void> runAfter(long delayMillis, Runnable fn) {
        return CompletableFuture.runAsync(fn, CompletableFuture.delayedExecutor(delayMillis, TimeUnit.MILLISECONDS));
    }
}
