package ru.yandex.solomon.coremon.meta.db;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.junit.Test;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.misc.test.Assert;

import static org.junit.Assert.assertEquals;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Tests for {@link ShardIdToAtomicInteger}: single-threaded set/get, increment and default value,
 * plus concurrent set and concurrent increment scenarios.
 *
 * @author Vladimir Gordiychuk
 */
public class ShardIdToAtomicIntegerTest {

    /**
     * Every value stored with {@code set} is read back unchanged by {@code get} for shard ids
     * in {@code [1, 10_000)}.
     */
    @Test
    public void setGet() {
        var map = new ShardIdToAtomicInteger();
        for (int index = 1; index < 10_000; index++) {
            map.set(index, index * 10);
        }

        for (int index = 1; index < 10_000; index++) {
            Assert.assertEquals(index * 10, map.get(index));
        }
    }

    /**
     * A single {@code incrementAndGet} per shard id leaves each counter at exactly 1.
     */
    @Test
    public void incrementGet() {
        var map = new ShardIdToAtomicInteger();
        for (int index = 1; index < 10_000; index++) {
            map.incrementAndGet(index);
        }

        for (int index = 1; index < 10_000; index++) {
            assertEquals(1, map.get(index));
        }
    }

    /**
     * Shard ids that were never written read as 0.
     */
    @Test
    public void get() {
        var map = new ShardIdToAtomicInteger();
        for (int index = 1; index < 10_000; index++) {
            assertEquals(0, map.get(index));
        }
    }

    /**
     * Values stored with {@code set} from a parallel stream are all visible afterwards.
     */
    @Test
    public void concurrentSet() {
        var map = new ShardIdToAtomicInteger();
        // forEach is the terminal operation for side effects; no dummy reduction is needed.
        IntStream.range(1, 10_000)
                .parallel()
                .forEach(idx -> map.set(idx, idx * 10));

        for (int index = 1; index < 10_000; index++) {
            assertEquals(index * 10, map.get(index));
        }
    }

    /**
     * Several threads randomly increment shard counters at the same time. The final value of each
     * counter must equal the total number of increments the threads recorded locally for that shard.
     */
    @Test
    public void concurrentIncrement() {
        var shardCount = 5_000;
        var threadCount = 4;
        var map = new ShardIdToAtomicInteger();
        CyclicBarrier barrier = new CyclicBarrier(threadCount);
        Supplier<int[]> fn = () -> {
            try {
                int[] result = new int[shardCount];
                var random = ThreadLocalRandom.current();
                // Bounded wait: if some worker never starts, fail instead of hanging forever.
                barrier.await(1, TimeUnit.MINUTES);

                for (int index = 0; index < 100_000; index++) {
                    int idx = random.nextInt(1, result.length);
                    map.incrementAndGet(idx);
                    result[idx]++;
                }
                return result;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        };

        // A dedicated pool guarantees threadCount simultaneously running workers. The common
        // ForkJoinPool may have fewer threads than barrier parties, which would deadlock on await().
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        try {
            var futures = IntStream.range(0, threadCount)
                    .mapToObj(ignore -> CompletableFuture.supplyAsync(fn, executor))
                    .collect(Collectors.toList());

            var resultPerFuture = CompletableFutures.allOf(futures).join();
            int[] expected = new int[shardCount];
            for (var stats : resultPerFuture) {
                for (int index = 0; index < expected.length; index++) {
                    expected[index] += stats[index];
                }
            }

            for (int index = 1; index < shardCount; index++) {
                assertEquals(expected[index], map.get(index));
            }
        } finally {
            executor.shutdownNow();
        }
    }
}
