package ru.yandex.metabase.client.impl;

import java.util.BitSet;
import java.util.stream.IntStream;

import com.google.common.collect.Interner;
import com.google.common.collect.Interners;

import ru.yandex.metabase.client.MetabasePartitions;
import ru.yandex.solomon.labels.shard.ShardKey;

/**
 * Process-wide interners used to deduplicate frequently repeated values.
 *
 * @author Egor Litvinenko
 */
public class MyInterners {

    // Guava weak interners: canonical instances are held weakly, so unused values can be collected.
    private static final Interner<String> PROJECT_CLUSTER_SERVICE_INTERNER = Interners.newWeakInterner();
    private static final Interner<ShardKey> SHARD_KEY_INTERNER = Interners.newWeakInterner();
    private static final Interner<String> FQDN_INTERNER = Interners.newWeakInterner();
    private static final BitSetInterner BIT_SET_INTERNER = new BitSetInterner();

    /**
     * Returns the shared interner for project / cluster / service strings.
     */
    public static Interner<String> pcs() {
        return PROJECT_CLUSTER_SERVICE_INTERNER;
    }

    /**
     * Returns the shared interner for FQDN strings (presumably host names — confirm with callers).
     */
    public static Interner<String> fqdn() {
        return FQDN_INTERNER;
    }

    /**
     * Returns the shared interner for {@link ShardKey} instances.
     */
    public static Interner<ShardKey> shardKey() {
        return SHARD_KEY_INTERNER;
    }

    /**
     * Returns the shared partition {@link BitSet} interner.
     */
    public static BitSetInterner bitset() {
        return BIT_SET_INTERNER;
    }

    /**
     * Canonicalizes partition bit sets so that frequently repeated shapes share one instance.
     *
     * <p>Cached shapes exist for the partition counts {@code 2^0 .. 2^MAX_PARTITION_TWO_POWER}
     * ({@code MAX_PARTITION_TWO_POWER = MetabasePartitions.MAX_PARTITIONS_POWER + 2}), but
     * {@link #intern} and the {@code of*} factories only serve counts up to {@code MAX_PARTITIONS}:
     * <ul>
     *     <li>all bits {@code [0, n)} set,</li>
     *     <li>no bits set,</li>
     *     <li>only bit {@code n - 1} set.</li>
     * </ul>
     * Any other input is returned as is by {@link #intern}, or freshly allocated by the
     * {@code of*} factories.
     *
     * <p><b>Warning:</b> cached instances are shared between all callers and {@link BitSet} is
     * mutable; a bit set obtained from this class must never be modified.
     */
    static class BitSetInterner {

        // Two powers above the configured maximum, kept as a safety margin.
        private static final int MAX_PARTITION_TWO_POWER = MetabasePartitions.MAX_PARTITIONS_POWER + 2;
        private static final int MAX_PARTITIONS = MetabasePartitions.MAX_PARTITIONS;

        // In each cache, index i holds the canonical bit set for a partition count of 2^i.
        private static final BitSet[] ALL_IS_ONE = new BitSet[MAX_PARTITION_TWO_POWER + 1];
        private static final BitSet[] ALL_IS_ZERO = new BitSet[MAX_PARTITION_TWO_POWER + 1];
        private static final BitSet[] ONLY_ONE = new BitSet[MAX_PARTITION_TWO_POWER + 1];

        static {
            IntStream.rangeClosed(0, MAX_PARTITION_TWO_POWER).forEach(power -> {
                final int size = 1 << power;
                ALL_IS_ONE[power] = allOnes(size);
                // A fresh BitSet is already all zeros; the size is only a capacity hint.
                ALL_IS_ZERO[power] = new BitSet(size);
                ONLY_ONE[power] = new BitSet(size);
                ONLY_ONE[power].set(size - 1);
            });
        }

        /**
         * Returns a canonical shared instance equal to {@code sample} when one is cached,
         * otherwise {@code sample} itself.
         *
         * @param sample          bit set to canonicalize; not modified
         * @param totalPartitions number of partitions the bit set describes
         * @return a cached instance equal to {@code sample}, or {@code sample}
         */
        public BitSet intern(BitSet sample, int totalPartitions) {
            if (totalPartitions > MAX_PARTITIONS) {
                return sample;
            }
            final int cardinality = sample.cardinality();
            // length() is (highest set bit + 1): together with the cardinality this proves the set
            // is exactly [0, totalPartitions), not merely "totalPartitions bits somewhere".
            if (cardinality == totalPartitions && sample.length() == totalPartitions) {
                if (totalPartitions == 1) {
                    return ALL_IS_ONE[0];
                }
                final int power = calcSmallPowerOfTwo(totalPartitions);
                if (power > -1) {
                    return ALL_IS_ONE[power];
                }
            }
            if (cardinality == 1) {
                if (sample.nextSetBit(0) == totalPartitions - 1) {
                    final int power = calcSmallPowerOfTwo(totalPartitions);
                    return power > -1 ? ONLY_ONE[power] : sample;
                }
                return sample;
            }
            if (cardinality == 0) {
                if (totalPartitions == 1) {
                    return ALL_IS_ZERO[0];
                }
                final int power = calcSmallPowerOfTwo(totalPartitions);
                return power > -1 ? ALL_IS_ZERO[power] : sample;
            }
            return sample;
        }

        /**
         * Returns a bit set with only {@code partitionId} set; cached when {@code partitionId + 1}
         * is a power of two, freshly allocated otherwise.
         *
         * @throws IllegalArgumentException if {@code partitionId >= MAX_PARTITIONS}
         */
        static BitSet ofOne(int partitionId) {
            if (partitionId < 1) {
                // NOTE(review): negative ids are also mapped to {0}; confirm callers never pass them.
                return ALL_IS_ONE[0];
            }
            if (partitionId >= MAX_PARTITIONS) {
                throw new IllegalArgumentException("partitionId " + partitionId + " >= max partitions " + MAX_PARTITIONS);
            }
            final int power = calcSmallPowerOfTwo(partitionId + 1);
            if (power > -1) {
                return ONLY_ONE[power];
            }
            final BitSet bitSet = new BitSet(partitionId + 1);
            bitSet.set(partitionId);
            return bitSet;
        }

        /**
         * Returns a bit set with bits {@code [0, totalPartitions)} set; cached for powers of two,
         * freshly allocated otherwise.
         *
         * @throws IllegalArgumentException if {@code totalPartitions > MAX_PARTITIONS}
         */
        static BitSet ofAll(int totalPartitions) {
            if (totalPartitions <= 1) {
                // NOTE(review): 0 and negative counts also yield {0} (one bit set); confirm intended.
                return ALL_IS_ONE[0];
            }
            if (totalPartitions > MAX_PARTITIONS) {
                throw new IllegalArgumentException("totalPartitions " + totalPartitions + " is bigger then max partitions" +
                        " " + MAX_PARTITIONS);
            }
            final int power = calcSmallPowerOfTwo(totalPartitions);
            return power > -1 ? ALL_IS_ONE[power] : allOnes(totalPartitions);
        }

        /**
         * Returns an empty bit set sized for {@code totalPartitions}; cached for powers of two,
         * freshly allocated otherwise.
         *
         * @throws IllegalArgumentException if {@code totalPartitions > MAX_PARTITIONS}
         */
        static BitSet ofAllZeros(int totalPartitions) {
            if (totalPartitions <= 1) {
                return ALL_IS_ZERO[0];
            }
            if (totalPartitions > MAX_PARTITIONS) {
                throw new IllegalArgumentException("totalPartitions " + totalPartitions + " is bigger then max partitions" +
                        " " + MAX_PARTITIONS);
            }
            final int power = calcSmallPowerOfTwo(totalPartitions);
            return power > -1 ? ALL_IS_ZERO[power] : new BitSet(totalPartitions);
        }

        /**
         * Returns {@code k} when {@code x == 2^k} for some {@code k} in
         * {@code [1, MAX_PARTITION_TWO_POWER]}, otherwise {@code -1}.
         *
         * <p>{@code x == 1} (that is, {@code 2^0}) yields {@code -1}; callers special-case a single
         * partition themselves. Computed arithmetically so it does not depend on any static table
         * being initialized first.
         */
        static int calcSmallPowerOfTwo(int x) {
            // x < 2 also rejects Integer.MIN_VALUE, which has exactly one bit set.
            if (x < 2 || Integer.bitCount(x) != 1) {
                return -1;
            }
            final int power = Integer.numberOfTrailingZeros(x);
            return power <= MAX_PARTITION_TWO_POWER ? power : -1;
        }

        /** Returns a new bit set with bits {@code [0, size)} set. */
        private static BitSet allOnes(int size) {
            final BitSet bitSet = new BitSet(size);
            bitSet.set(0, size);
            return bitSet;
        }
    }
}
