package ru.yandex.stockpile.server.shard.cache;

import java.util.EnumSet;
import java.util.List;
import java.util.function.ToIntFunction;
import java.util.function.ToLongBiFunction;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.util.collection.enums.EnumMapToLong;

/**
 * Splits a fixed total cache budget between shards in proportion to per-shard weights.
 *
 * @author Stepan Koltsov
 */
@ParametersAreNonnullByDefault
public class CacheWeights {

    /** Single-valued weight key that lets {@link #updateWeightsOneGroup} reuse {@link #updateWeights}. */
    private enum SingletonEnum { S }

    /**
     * Distributes {@code totalCacheSize} between shards proportionally to a single weight per shard.
     *
     * <p>Equivalent to {@link #updateWeights} with exactly one weight key whose key weight is 1.
     *
     * @param shards         shards to distribute the cache between
     * @param shardWeight    weight of each shard
     * @param totalCacheSize total cache budget to split
     * @param cacheSizes     output array; {@code cacheSizes[i]} receives the share of {@code shards.get(i)}.
     *                       Must have the same length as {@code shards}
     * @throws IllegalArgumentException if {@code cacheSizes.length != shards.size()}
     */
    static <S> void updateWeightsOneGroup(
        List<S> shards,
        ToLongFunction<S> shardWeight,
        long totalCacheSize,
        long[] cacheSizes)
    {
        updateWeights(
            shards,
            SingletonEnum.class,
            s -> 1,
            (s, singletonEnum) -> shardWeight.applyAsLong(s),
            totalCacheSize,
            cacheSizes);
    }

    /**
     * Distributes {@code totalCacheSize} between shards using several weight dimensions (the constants of
     * {@code W}).
     *
     * <p>For every key {@code w}, shard {@code i} owns the fraction
     * {@code shardWeight(i, w) / sum_j shardWeight(j, w)} of that key. Keys whose total over all shards is
     * zero are ignored. A shard's cache size is {@code totalCacheSize} times the {@code keyWeight}-weighted
     * average of its fractions over the remaining keys, truncated toward zero.
     *
     * <p>If {@code shards} is empty, nothing happens. If the key weights of all keys with a non-zero total
     * sum to zero (including the case where no key has a non-zero total), {@code cacheSizes} is left
     * unchanged.
     *
     * @param shards         shards to distribute the cache between
     * @param weightKeyClass enum class whose constants are the weight dimensions
     * @param keyWeight      relative importance of each weight dimension
     * @param shardWeight    weight of a shard along a given dimension
     * @param totalCacheSize total cache budget to split
     * @param cacheSizes     output array; {@code cacheSizes[i]} receives the share of {@code shards.get(i)}.
     *                       Must have the same length as {@code shards}
     * @throws IllegalArgumentException if {@code cacheSizes.length != shards.size()}
     */
    public static <S, W extends Enum<W>> void updateWeights(
        List<S> shards,
        Class<W> weightKeyClass,
        ToIntFunction<W> keyWeight,
        ToLongBiFunction<S, W> shardWeight,
        long totalCacheSize,
        long[] cacheSizes)
    {
        if (cacheSizes.length != shards.size()) {
            throw new IllegalArgumentException(
                "cacheSizes.length (" + cacheSizes.length + ") != shards.size() (" + shards.size() + ")");
        }

        if (shards.isEmpty()) {
            return;
        }

        List<EnumMapToLong<W>> shardAddendums = shards.stream()
            .map(s -> shardAddendums(s, weightKeyClass, shardWeight))
            .collect(Collectors.toList());

        EnumMapToLong<W> sum = EnumMapToLong.sum(shardAddendums, weightKeyClass);

        // Only keys with a non-zero total take part; for the others a shard's fraction would be 0/0.
        EnumSet<W> nonZeroAddendums = EnumSet.noneOf(weightKeyClass);
        // long so that summing many int key weights cannot overflow
        long nonZeroWeightSum = 0;
        for (W addendum : weightKeyClass.getEnumConstants()) {
            if (sum.get(addendum) != 0) {
                nonZeroAddendums.add(addendum);
                nonZeroWeightSum += keyWeight.applyAsInt(addendum);
            }
        }

        if (nonZeroWeightSum == 0) {
            return;
        }

        for (int i = 0; i < shards.size(); ++i) {
            EnumMapToLong<W> shardAddendum = shardAddendums.get(i);

            // We get long overflows for some addendums, e.g. StockpileCacheManager.WeightAddendum.SIZE,
            // so everything is computed in double precision. The accumulator is a double too: a long
            // accumulator would silently truncate every partial sum via the implicit narrowing cast of +=.
            double shardCacheSizeMultipliedByWeight = 0;
            for (W addendum : nonZeroAddendums) {
                int addendumWeight = keyWeight.applyAsInt(addendum);

                long shardWeightForAddendum = shardAddendum.get(addendum);
                long totalForAddendum = sum.get(addendum);

                double shardPartInAddendum = ((double) shardWeightForAddendum) / totalForAddendum;

                // Widen to double before multiplying so that totalCacheSize * addendumWeight cannot overflow long.
                shardCacheSizeMultipliedByWeight += (double) totalCacheSize * addendumWeight * shardPartInAddendum;
            }

            // Single explicit truncation toward zero, as the original integer division did.
            cacheSizes[i] = (long) (shardCacheSizeMultipliedByWeight / nonZeroWeightSum);
        }
    }

    /**
     * Builds the per-key weights of a single shard.
     *
     * @param shard       shard to evaluate
     * @param weightClass enum class whose constants are the weight keys
     * @param getWeight   weight of {@code shard} for a given key
     * @return an {@link EnumMapToLong} built from {@code getWeight.applyAsLong(shard, key)}
     */
    private static <S, W extends Enum<W>> EnumMapToLong<W> shardAddendums(
        S shard, Class<W> weightClass, ToLongBiFunction<S, W> getWeight)
    {
        return new EnumMapToLong<>(weightClass, e -> getWeight.applyAsLong(shard, e));
    }
}
