package ru.yandex.stockpile.kikimrKv;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import ru.yandex.kikimr.client.kv.inMem.KikimrKvClientInMem;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link KvTabletsMapping}: resolving the shard ids {@code 1..N} of a KV volume to tablet
 * ids through an in-memory {@link KikimrKvClientInMem}, and picking up tablets added to the volume
 * later, either via {@link KvTabletsMapping#forceReload()} or after the manual clock is advanced.
 *
 * @author Vladimir Gordiychuk
 */
public class KvTabletsMappingTest {

    private static final String VOLUME = "/test";
    private static final int INITIAL_SHARD_COUNT = 32;
    private static final int ALTERED_SHARD_COUNT = 64;
    /**
     * Upper bound on clock-advance iterations in {@link #timebaseReload()}, so that a reload which
     * never happens fails the test instead of spinning forever.
     */
    private static final int MAX_RELOAD_ATTEMPTS = 10_000;

    private ManualClock clock;
    private ManualScheduledExecutorService timer;
    private KikimrKvClientInMem kvClient;

    @Before
    public void setUp() {
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);
        kvClient = new KikimrKvClientInMem();
    }

    @After
    public void tearDown() {
        timer.shutdownNow();
    }

    @Test
    public void initResolve() {
        kvClient.createKvTablets(VOLUME, INITIAL_SHARD_COUNT).join();
        var mapping = mapping(VOLUME);
        mapping.waitForReady();
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());

        // Shard ids are expected to form the contiguous range 1..N.
        int[] expectedShardIds = IntStream.rangeClosed(1, INITIAL_SHARD_COUNT).toArray();
        assertArrayEquals(expectedShardIds, mapping.getShardIdStream().toArray());

        for (int shardId = 1; shardId <= INITIAL_SHARD_COUNT; shardId++) {
            long tabletId = mapping.getTabletId(shardId);
            assertNotEquals(0, tabletId);
            assertTrue(mapping.hasTabletId(tabletId));
        }
    }

    @Test
    public void forceReload() {
        kvClient.createKvTablets(VOLUME, INITIAL_SHARD_COUNT).join();
        var mapping = mapping(VOLUME);
        mapping.waitForReady();
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());

        var uniqueTabletIds = new LongOpenHashSet();
        long[] tablets = resolveDistinctTabletIds(mapping, INITIAL_SHARD_COUNT, uniqueTabletIds);

        kvClient.alterKvTablets(VOLUME, ALTERED_SHARD_COUNT).join();
        // Altering the volume alone is not expected to change an already loaded mapping.
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());
        mapping.forceReload().join();

        assertAlteredMapping(mapping, tablets, uniqueTabletIds);
    }

    @Test
    public void getUnknownShard() {
        kvClient.createKvTablets(VOLUME, INITIAL_SHARD_COUNT).join();
        var mapping = mapping(VOLUME);
        mapping.waitForReady();
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());
        // 42 is outside 1..INITIAL_SHARD_COUNT, so no tablet is mapped to it.
        assertEquals(KvTabletsMapping.UNKNOWN_TABLET_ID, mapping.getTabletId(42));
    }

    @Test
    public void timebaseReload() throws InterruptedException {
        kvClient.createKvTablets(VOLUME, INITIAL_SHARD_COUNT).join();
        var mapping = mapping(VOLUME);
        mapping.waitForReady();
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());

        var uniqueTabletIds = new LongOpenHashSet();
        long[] tablets = resolveDistinctTabletIds(mapping, INITIAL_SHARD_COUNT, uniqueTabletIds);

        kvClient.alterKvTablets(VOLUME, ALTERED_SHARD_COUNT).join();
        assertEquals(INITIAL_SHARD_COUNT, mapping.getShardCount());

        // Advance the manual clock until the timer-driven reload observes the new tablets. The reload
        // may finish on another executor, so give it a little real time between steps; the attempt
        // bound turns a reload that never happens into a test failure instead of a hang.
        for (int attempt = 0;
                attempt < MAX_RELOAD_ATTEMPTS && mapping.getShardCount() == INITIAL_SHARD_COUNT;
                attempt++) {
            clock.passedTime(1, TimeUnit.HOURS);
            TimeUnit.MILLISECONDS.sleep(1L);
        }

        assertAlteredMapping(mapping, tablets, uniqueTabletIds);
    }

    /**
     * Resolves the tablet ids of shards {@code 1..shardCount}, asserting each is non-zero and not
     * already present in {@code uniqueTabletIds} (ids are added to it).
     *
     * @return tablet ids indexed by {@code shardId - 1}
     */
    private static long[] resolveDistinctTabletIds(
            KvTabletsMapping mapping, int shardCount, LongOpenHashSet uniqueTabletIds)
    {
        long[] tablets = new long[shardCount];
        for (int shardId = 1; shardId <= shardCount; shardId++) {
            long tabletId = mapping.getTabletId(shardId);
            assertNotEquals(0, tabletId);
            assertTrue("duplicate tablet id " + tabletId + " for shard " + shardId, uniqueTabletIds.add(tabletId));
            tablets[shardId - 1] = tabletId;
        }
        assertEquals(tablets.length, uniqueTabletIds.size());
        return tablets;
    }

    /**
     * Asserts the mapping now has {@link #ALTERED_SHARD_COUNT} shards, that the shards present
     * before the alter keep their tablets, and that every added shard maps to a new, non-zero,
     * distinct tablet id.
     */
    private static void assertAlteredMapping(
            KvTabletsMapping mapping, long[] previousTablets, LongOpenHashSet uniqueTabletIds)
    {
        assertEquals("mapping was not reloaded", ALTERED_SHARD_COUNT, mapping.getShardCount());
        for (int shardId = 1; shardId <= previousTablets.length; shardId++) {
            assertEquals(previousTablets[shardId - 1], mapping.getTabletId(shardId));
        }

        for (int shardId = previousTablets.length + 1; shardId <= ALTERED_SHARD_COUNT; shardId++) {
            long tabletId = mapping.getTabletId(shardId);
            assertNotEquals(0, tabletId);
            assertTrue("duplicate tablet id " + tabletId + " for shard " + shardId, uniqueTabletIds.add(tabletId));
        }

        assertEquals(ALTERED_SHARD_COUNT, uniqueTabletIds.size());
    }

    private KvTabletsMapping mapping(String volumePath) {
        return new KvTabletsMapping(volumePath, kvClient, clock, ForkJoinPool.commonPool(), timer);
    }
}
