package ru.yandex.solomon.coremon.meta.db;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ForkJoinPool;

import com.google.common.collect.ImmutableSet;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.solomon.metrics.client.StockpileClientStub;
import ru.yandex.stockpile.client.shard.StockpileShardId;

import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThat;


/**
 * Tests for {@link StockpileMetricIdProviderStatsAware}: how the shard ids it hands out for a (shard, host)
 * stream are spread over Stockpile shards when the provider is given a per-shard metrics limit.
 *
 * <p>All providers created within one test share the same {@code stats} counters (number of metrics per
 * Stockpile shard id) and the same Stockpile client stub, so each test can inspect the resulting distribution.
 *
 * @author Sergey Polovko
 */
public class StockpileMetricIdProviderStatsAwareTest {

    private ShardIdToAtomicInteger stats;
    private StockpileClientStub stockpile;

    @Before
    public void setUp() {
        stats = new ShardIdToAtomicInteger();
        stockpile = new StockpileClientStub(ForkJoinPool.commonPool());
    }

    @After
    public void tearDown() {
        stockpile.close();
    }

    /** While the stream stays within the limit, every metric is assigned to one and the same shard. */
    @Test
    public void underLimit() {
        final int limit = 10;
        var provider = provider("shard", "some-host", limit);

        Set<Integer> shardIds = collectShardIds(provider, limit);
        assertEquals(1, shardIds.size());

        int shardId = shardIds.iterator().next();
        int metricsCount = stats.get(shardId);
        assertEquals(limit, metricsCount);
    }

    /** Metrics above the limit go to a second shard: one shard ends up with {@code limit}, the other with the rest. */
    @Test
    public void fewAboveLimit() {
        final int limit = 10;
        final int overflow = 3;
        var provider = provider("shard", "some-host", limit);

        Set<Integer> shardIds = collectShardIds(provider, limit + overflow);

        Set<Integer> actualSizes = new HashSet<>();
        for (int shardId : shardIds) {
            actualSizes.add(stats.get(shardId));
        }
        assertEquals(ImmutableSet.of(limit, overflow), actualSizes);
    }

    /** Writing {@code limit * totalShards} metrics uses most shards and never puts 2x the limit into one of them. */
    @Test
    public void manyAboveLimit() {
        final int limit = 10;
        final int totalShards = stockpile.getTotalShardsCount();
        var provider = provider("shard", "some-host", limit);

        Set<Integer> shardIds = collectShardIds(provider, limit * totalShards);

        // fill at least 3/4 of all shards
        Assert.assertTrue(
                "only " + shardIds.size() + " of " + totalShards + " shards were used",
                shardIds.size() > 0.75 * totalShards);

        for (int shardId : shardIds) {
            int actualSize = stats.get(shardId);

            // do not exceed limit x2
            Assert.assertTrue(
                    "shard " + shardId + " got " + actualSize + " metrics with limit " + limit,
                    actualSize < 2 * limit);
        }
    }

    /** When every shard is over the limit, the load is still spread close to uniformly over all shards. */
    @Test
    public void exceedLimitInAllShards() {
        final int limit = 10;
        final int count = limit * 1_000_000;
        var provider = provider("shard", "some-host", limit);

        Set<Integer> shardIds = collectShardIds(provider, count);

        // we fill all shards
        assertEquals(stockpile.getTotalShardsCount(), shardIds.size());

        double ideal = (double) count / stockpile.getTotalShardsCount();
        IntSummaryStatistics actual = stats.stream().summaryStatistics();

        // avg within 0.1%, min and max within 10% of the uniform distribution
        assertCloseToIdeal(ideal, actual, 0.001, 0.1);
    }

    /** A provider re-created for the same stream (restart) continues from the shards the previous one filled. */
    @Test
    public void restoreStreamAfterRestart() {
        final int limit = 3;
        final String shardId = "some-shard";
        final String hostname = "some-host";

        // log of seen shard ids: first run followed by the second run
        List<Integer> log = new ArrayList<>();
        appendShardIds(log, provider(shardId, hostname, limit), limit + 1);
        // second run: a fresh provider for the same stream, sharing the stats written by the first one
        appendShardIds(log, provider(shardId, hostname, limit), limit + 1);

        //    limit
        // /----------\
        // +---+---+---+---+---+---+---+---+
        // | 1 | 1 | 1 | 2 | 1 | 2 | 2 | 3 |
        // +---+---+---+---+---+---+---+---+
        //                 ^
        //              restart

        System.out.println(log);
        assertEquals(2 * limit + 2, log.size());

        List<Integer> firstRun = log.subList(0, limit + 1);
        List<Integer> secondRun = log.subList(limit + 1, log.size());

        // first N values are the same
        int firstId = firstRun.get(0);
        for (int i = 1; i < limit; i++) {
            assertEquals(firstId, firstRun.get(i).intValue());
        }

        // after exceeding the limit we get another value
        int secondId = firstRun.get(limit);
        assertNotEquals(firstId, secondId);

        // after restart we try the first value, find out that it exceeds
        // its limit and go further
        assertEquals(firstId, secondRun.get(0).intValue());

        // so the next N-1 ids must be the same as secondId
        for (int i = 1; i < limit; i++) {
            assertEquals(secondId, secondRun.get(i).intValue());
        }

        // after exceeding the limit we get another value
        int thirdId = secondRun.get(limit);
        assertNotEquals(firstId, thirdId);
        assertNotEquals(secondId, thirdId);
    }

    /** Streams of different (shard, host) pairs rarely land on the same shard at the same step. */
    @Test
    public void differentStreams() {
        final int limit = 1;
        final int count = 1_000;

        StockpileMetricIdProviderStatsAware[] providers = {
                provider("s1", "h1", limit),
                provider("s1", "h2", limit),
                provider("s2", "h1", limit),
                provider("s2", "h2", limit),
        };

        int intersections = 0;
        int[] ids = new int[providers.length];
        for (int i = 0; i < count; i++) {
            for (int p = 0; p < providers.length; p++) {
                ids[p] = providers[p].shardId(Labels.empty());
            }
            intersections += countEqualPairs(ids);
        }

        System.out.println("intersections: " + intersections);
        Assert.assertTrue(
                "too many intersections: " + intersections + " in " + count + " rounds",
                (double) intersections / count < 0.01); // less than 1%
    }

    /** Simulates a heavy real-world shard (10k hosts + 2 aggregates) and checks the load stays near-uniform. */
    @Test
    public void realHeavyShard() {
        String shardId = "yt_hahn_node_rpc_9e5bb418";
        final int limit = 16_000;
        final int hostCount = 10_000;
        final int metricsPerHost = 5_000;

        var aggr1 = provider(shardId, "Aggr", limit);
        var aggr2 = provider(shardId, "Aggr_DC_SAS", limit);

        for (int i = 1; i <= hostCount; i++) {
            var host = provider(shardId, String.format("n%04d-sas", i), limit);

            // let's assume each host produces 5k metrics and we have to aggregate each of them twice
            for (int metric = 0; metric < metricsPerHost; metric++) {
                host.shardId(Labels.empty());
                aggr1.shardId(Labels.empty());
                aggr2.shardId(Labels.empty());
            }
        }

        IntSummaryStatistics actual = stats.stream().summaryStatistics();
        // each metric is requested by its host and by both aggregates
        assertEquals(3L * hostCount * metricsPerHost, actual.getSum()); // total metrics

        double ideal = (double) actual.getSum() / stockpile.getTotalShardsCount();

        // avg within 0.01%, min and max within 10% of the uniform distribution
        assertCloseToIdeal(ideal, actual, 0.0001, 0.1);
    }

    /** After the shard count grows, returned ids stay within the new range and every shard gets some metrics. */
    @Test
    public void changeNumberOfShards() {
        final int metricsPerStep = 100_000;
        final int limit = 5;

        stockpile.setShardCount(32);
        var provider = provider("myShardId", "myHostId", limit);
        requestLabeledShardIds(provider, metricsPerStep, 32);

        stockpile.setShardCount(128);
        requestLabeledShardIds(provider, metricsPerStep, 128);

        IntSummaryStatistics actual = stats.stream().summaryStatistics();
        assertEquals(2L * metricsPerStep, actual.getSum()); // total metrics
        assertNotEquals(0, actual.getMin());

        double ideal = (double) actual.getSum() / stockpile.getTotalShardsCount();
        printSummary(ideal, actual);
    }

    private StockpileMetricIdProviderStatsAware provider(String shardId, String host, int limit) {
        return new StockpileMetricIdProviderStatsAware(shardId.hashCode(), host, limit, stats, stockpile);
    }

    /** Requests a shard id {@code count} times for an unlabeled metric and returns the distinct ids seen. */
    private static Set<Integer> collectShardIds(StockpileMetricIdProviderStatsAware provider, int count) {
        Set<Integer> shardIds = new HashSet<>();
        for (int i = 0; i < count; i++) {
            shardIds.add(provider.shardId(Labels.empty()));
        }
        return shardIds;
    }

    /** Requests a shard id {@code count} times for an unlabeled metric, appending every id to {@code log}. */
    private static void appendShardIds(
            List<Integer> log,
            StockpileMetricIdProviderStatsAware provider,
            int count)
    {
        for (int i = 0; i < count; i++) {
            log.add(provider.shardId(Labels.empty()));
        }
    }

    /**
     * Requests shard ids for {@code count} metrics with distinct {@code name} labels and checks that each id
     * passes {@link StockpileShardId#validate} and lies within {@code [1, maxShardId]}.
     */
    private static void requestLabeledShardIds(
            StockpileMetricIdProviderStatsAware provider,
            int count,
            int maxShardId)
    {
        for (int i = 1; i <= count; i++) {
            var labels = Labels.of("name", String.format("n-%06d-sas", i));
            int shardId = provider.shardId(labels);
            StockpileShardId.validate(shardId);
            assertThat(shardId, greaterThanOrEqualTo(1));
            assertThat(shardId, lessThanOrEqualTo(maxShardId));
        }
    }

    /** Counts index pairs {@code a < b} where {@code ids[a] == ids[b]}. */
    private static int countEqualPairs(int[] ids) {
        int pairs = 0;
        for (int a = 0; a < ids.length; a++) {
            for (int b = a + 1; b < ids.length; b++) {
                if (ids[a] == ids[b]) {
                    pairs++;
                }
            }
        }
        return pairs;
    }

    /** Prints the per-shard metric count summary next to the ideal (uniform) per-shard count. */
    private static void printSummary(double ideal, IntSummaryStatistics actual) {
        System.out.println("avg: " + actual.getAverage());
        System.out.println("min: " + actual.getMin());
        System.out.println("max: " + actual.getMax());
        System.out.println("ideal: " + ideal);
    }

    /**
     * Prints the summary, then asserts that the average per-shard count deviates from {@code ideal} by less than
     * {@code avgTolerance}, and min/max by less than {@code minMaxTolerance}. Tolerances are relative to
     * {@code ideal} (0.1 means 10%).
     */
    private static void assertCloseToIdeal(
            double ideal,
            IntSummaryStatistics actual,
            double avgTolerance,
            double minMaxTolerance)
    {
        printSummary(ideal, actual);
        assertRelativeDeviation("avg", ideal, actual.getAverage(), avgTolerance);
        assertRelativeDeviation("min", ideal, actual.getMin(), minMaxTolerance);
        assertRelativeDeviation("max", ideal, actual.getMax(), minMaxTolerance);
    }

    /** Asserts {@code |ideal - value| / ideal < tolerance}, reporting the numbers on failure. */
    private static void assertRelativeDeviation(String what, double ideal, double value, double tolerance) {
        double deviation = Math.abs(ideal - value) / ideal;
        Assert.assertTrue(
                String.format("%s = %s deviates from ideal %s by %.6f (allowed < %.6f)",
                        what, value, ideal, deviation, tolerance),
                deviation < tolerance);
    }
}
