package ru.yandex.chemodan.ydb.dao;

import com.yandex.ydb.table.values.PrimitiveValue;
import org.junit.Test;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.misc.random.Random2;
import ru.yandex.misc.test.Assert;

/**
 * Tests for the partitioning and hashing helpers in {@link YdbUtils}.
 *
 * @author yashunsky
 */
public class YdbUtilsTest {
    /** Tolerated relative deviation of a shard's load from the perfect (perfectly even) load. */
    private static final float ALLOWED_NON_UNIFORM = 0.11f;

    /**
     * Checks that requesting 16 partitions yields the expected uint32 values: spaced
     * {@code 0x10000000} apart, starting at {@code 0x08000000}.
     * {@link Long#toHexString} drops leading zeros, hence the 7-digit first value.
     */
    @Test
    public void testTimestampPartitions() {
        ListF<String> hexes = YdbUtils.getTimestampPartitions(16).map(PrimitiveValue::getUint32).map(Long::toHexString);
        ListF<String> expected = Cf.list(
                "8000000", "18000000", "28000000", "38000000",
                "48000000", "58000000", "68000000", "78000000",
                "88000000", "98000000", "a8000000", "b8000000",
                "c8000000", "d8000000", "e8000000", "f8000000"
        );
        Assert.equals(expected, hexes);
    }

    /**
     * Checks that {@link YdbUtils#getHashValue} spreads both highly regular input (a shared
     * prefix followed by a counter) and random strings evenly across shards.
     */
    @Test
    public void testUniform() {
        int shardsCount = 128;
        int recordsCount = 100000;

        ListF<String> nonUniformData = Cf.range(0, recordsCount).map(id -> "Some nonuniform data " + id);
        // NOTE(review): Random2.R is presumably a shared, unseeded generator (verify), so this input
        // differs on every run. If the hash behaves like a uniform random function, each shard's
        // load is roughly Binomial(100000, 1/128), i.e. sigma ~ 28 records around ~781. The 11% band
        // is then only about 3 sigma wide across 128 shards, so occasional spurious failures are
        // possible. Consider seeding the generator or widening the band.
        ListF<String> randomData = Cf.range(0, recordsCount).map(id -> Random2.R.nextString(100));

        assertUniform(shardsCount, nonUniformData, ALLOWED_NON_UNIFORM, "non uniform");
        assertUniform(shardsCount, randomData, ALLOWED_NON_UNIFORM, "random");
    }

    /**
     * Hashes every element of {@code data}, assigns each hash to one of {@code shardsCount} equal
     * ranges of the unsigned 32-bit hash space, and asserts that every shard's load lies strictly
     * between {@code perfectLoad * (1 - allowedNonUniform)} and
     * {@code perfectLoad * (1 + allowedNonUniform)}.
     *
     * @param shardsCount       number of shards; expected to be positive
     * @param data              values to hash
     * @param allowedNonUniform tolerated relative deviation from the perfect load
     * @param message           label included in assertion failure messages
     */
    private void assertUniform(int shardsCount, ListF<String> data, float allowedNonUniform, String message) {
        int recordsCount = data.length();
        int perfectLoad = recordsCount / shardsCount;
        MapF<Integer, Integer> sharded = data
                .map(YdbUtils::getHashValue)
                .map(hash -> getShard(hash, shardsCount))
                .groupBy(shard -> shard)
                .mapValues(ListF::length);

        // Shards that received no records are reported as 0 so they still count towards minLoad.
        ListF<Integer> shardsLoad = Cf.range(0, shardsCount).map(shardId -> sharded.getOrElse(shardId, 0));

        int minLoad = shardsLoad.min();
        int maxLoad = shardsLoad.max();
        String details = message + ": min=" + minLoad + ", perfect=" + perfectLoad + ", max=" + maxLoad;

        Assert.gt(minLoad, (int) (perfectLoad * (1 - allowedNonUniform)), details);
        Assert.lt(maxLoad, (int) (perfectLoad * (1 + allowedNonUniform)), details);
    }

    /**
     * Maps a hash, interpreted as an unsigned 32-bit value, to a shard index in
     * {@code [0, shardsCount)} by splitting the hash space into {@code shardsCount} equal ranges.
     *
     * <p>Computed as {@code floor(unsignedHash * shardsCount / 2^32)} in 64-bit arithmetic.
     * The product is below {@code 2^63} for any positive int {@code shardsCount}, so it cannot
     * overflow. The result is always strictly less than {@code shardsCount}, unlike dividing by a
     * range size truncated to {@code int}, which can yield an extra out-of-range shard.
     */
    private int getShard(int hash, int shardsCount) {
        return (int) ((Integer.toUnsignedLong(hash) * shardsCount) >>> 32);
    }
}
