package ru.yandex.solomon.memory.layout;

import java.nio.ByteBuffer;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import javax.annotation.Nullable;

import com.google.protobuf.ByteString;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledHeapByteBuf;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import org.openjdk.jol.info.ClassLayout;
import org.openjdk.jol.util.MathUtil;
import org.openjdk.jol.vm.VM;
import org.openjdk.jol.vm.VirtualMachine;

/**
 * Static helpers that estimate how many bytes of heap various objects occupy: primitive and
 * reference arrays, strings, byte buffers and a number of JDK / fastutil collections.
 *
 * <p>Primitive sizes, header sizes and object alignment are queried once from the JOL
 * {@link VirtualMachine} of the running JVM. All results are estimates: unless a method name
 * ends with {@code WithContent}, only the container itself is counted, not the objects it
 * references.
 *
 * @author Stepan Koltsov
 */
public class MemoryCounter {
    private static final VirtualMachine vm = VM.current();

    // TODO: 4 with compressed oops
    public static final int OBJECT_POINTER_SIZE = (int) vm.sizeOfField("oop");
    public static final long INT_SIZE = vm.sizeOfField("int");
    public static final long LONG_SIZE = vm.sizeOfField("long");
    public static final long BYTE_SIZE = vm.sizeOfField("byte");
    public static final long BOOLEAN_SIZE = vm.sizeOfField("boolean");
    public static final long CHAR_SIZE = vm.sizeOfField("char");
    public static final long DOUBLE_SIZE = vm.sizeOfField("double");

    public static final int OBJECT_HEADER_SIZE = vm.objectHeaderSize();

    public static final int ARRAY_HEADER_SIZE = vm.arrayHeaderSize();

    private static final long HASH_MAP_SELF_SIZE = objectSelfSizeLayout(HashMap.class);
    private static final long ENUM_MAP_SELF_SIZE = objectSelfSizeLayout(EnumMap.class);

    /**
     * See structure of java.util.HashMap.Node: object header, an int hash and three references
     * (key, value, next). The sum is not rounded up to the object alignment.
     */
    public static final int HASH_MAP_NODE_SIZE = OBJECT_HEADER_SIZE + 4 + 3 * OBJECT_POINTER_SIZE;

    private static final long Long2ObjectOpenHashMap_SELF_SIZE = objectSelfSizeLayout(Long2ObjectOpenHashMap.class);
    private static final long Int2ObjectOpenHashMap_SELF_SIZE = objectSelfSizeLayout(Int2ObjectOpenHashMap.class);
    private static final long Int2LongOpenHashMap_SELF_SIZE = objectSelfSizeLayout(Int2LongOpenHashMap.class);
    private static final long Object2IntOpenHashMap_SELF_SIZE = objectSelfSizeLayout(Object2IntOpenHashMap.class);
    private static final long Object2ObjectOpenHashMap_SELF_SIZE = objectSelfSizeLayout(Object2ObjectOpenHashMap.class);
    private static final long LongOpenHashSet_SELF_SIZE = objectSelfSizeLayout(LongOpenHashSet.class);
    private static final long ByteBuffer_SELF_SIZE = objectSelfSizeLayout(ByteBuffer.class);
    private static final long ByteBuf_SELF_SIZE = objectSelfSizeLayout(UnpooledHeapByteBuf.class);
    // NOTE(review): the reason for the extra 8 bytes is not visible in this file - presumably it
    // accounts for fields of the concrete ByteString subclass; confirm before changing.
    private static final long ByteString_SELF_SIZE = objectSelfSizeLayout(ByteString.class) + 8;
    public static final long CompletableFuture_SELF_SIZE = objectSelfSizeLayout(CompletableFuture.class);
    private static final long String_SELF_SIZE = objectSelfSizeLayout(String.class);

    /** Load factor assumed for fastutil open-addressing tables. */
    private static final float OPEN_HASH_LOAD_FACTOR = 0.75f;

    /** Smallest expected size assumed for fastutil open-addressing maps. */
    private static final int OPEN_HASH_MIN_EXPECTED_SIZE = 16;

    /**
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(byte[] array) {
        return array == null ? 0 : arraySize(array.length, BYTE_SIZE);
    }

    /**
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(boolean[] array) {
        return array == null ? 0 : arraySize(array.length, BOOLEAN_SIZE);
    }

    /**
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(int[] array) {
        return array == null ? 0 : arraySize(array.length, INT_SIZE);
    }

    /**
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(long[] array) {
        return array == null ? 0 : arraySize(array.length, LONG_SIZE);
    }

    /**
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(double[] array) {
        return array == null ? 0 : arraySize(array.length, DOUBLE_SIZE);
    }

    /**
     * Size of the reference array only; the referenced objects are not counted.
     *
     * @return aligned size of the given array, or {@code 0} when the array is {@code null}
     */
    public static long arrayObjectSize(Object[] array) {
        return array == null ? 0 : arraySize(array.length, OBJECT_POINTER_SIZE);
    }

    /**
     * Size of the reference array plus the size reported by every element.
     * {@code null} elements are measured with {@link MemMeasurable#memorySizeOfNullable}.
     *
     * @return total size, or {@code 0} when the array itself is {@code null}
     */
    public static long arrayObjectSizeWithContent(MemMeasurable[] array) {
        if (array == null) {
            return 0;
        }

        long total = arraySize(array.length, OBJECT_POINTER_SIZE);
        for (MemMeasurable element : array) {
            // Reference arrays routinely contain unused (null) slots.
            total += MemMeasurable.memorySizeOfNullable(element);
        }
        return total;
    }

    /**
     * Size of an array with {@code length} elements of {@code scale} bytes each: array header
     * plus payload, rounded up to the VM object alignment.
     *
     * @param length number of elements
     * @param scale size of one element in bytes
     */
    public static long arraySize(long length, long scale) {
        long instanceSize = ARRAY_HEADER_SIZE + length * scale;
        return MathUtil.align(instanceSize, vm.objectAlignment());
    }

    /**
     * Size of a reference array holding exactly {@code list.size()} elements plus the size
     * reported by every element. The list object itself is not counted.
     * {@code null} elements are measured with {@link MemMeasurable#memorySizeOfNullable}.
     */
    public static <A extends MemMeasurable> long listDataSizeWithContent(List<A> list) {
        long total = arraySize(list.size(), OBJECT_POINTER_SIZE);
        for (A item : list) {
            total += MemMeasurable.memorySizeOfNullable(item);
        }
        return total;
    }

    /**
     * @return instance size of {@code clazz} as computed by JOL {@link ClassLayout}
     */
    public static long objectSelfSizeLayout(Class<?> clazz) {
        return ClassLayout.parseClass(clazz).instanceSize();
    }

    /**
     * @return memory size occupied by ByteBuffer
     *
     * NOTE: this function does not take care of buffers which share their contents.
     */
    public static long byteBufferSize(ByteBuffer buf) {
        long size = ByteBuffer_SELF_SIZE + (long) buf.capacity();
        if (buf.isDirect()) {
            // DirectByteBuffer additionally store reference to an attached object and cleaner
            return size + 2 * OBJECT_POINTER_SIZE;
        }
        return size;
    }

    /**
     * Estimate for a Netty buffer: self size of {@link UnpooledHeapByteBuf} plus the buffer
     * capacity, with two extra references added for direct buffers.
     *
     * NOTE: this function does not take care of buffers which share their contents.
     */
    public static long byteBufSize(ByteBuf buf) {
        long size = ByteBuf_SELF_SIZE + (long) buf.capacity();
        if (buf.isDirect()) {
            // Same two-reference allowance that byteBufferSize applies to direct buffers.
            return size + 2 * OBJECT_POINTER_SIZE;
        }
        return size;
    }

    /**
     * @return self size of the ByteString plus an aligned byte array of {@code bytes.size()}
     */
    public static long byteStringSize(ByteString bytes) {
        return ByteString_SELF_SIZE + arraySize(bytes.size(), BYTE_SIZE);
    }

    /**
     * Self size of the String plus an aligned byte array with one byte per char.
     *
     * NOTE(review): one byte per char presumably matches only strings stored in the compact
     * (Latin-1) form; strings with other characters would be underestimated - confirm.
     */
    public static long stringSize(String s) {
        return String_SELF_SIZE + arraySize(s.length(), BYTE_SIZE);
    }

    /**
     * Length of the backing arrays assumed for a fastutil open-addressing map holding
     * {@code size} entries: the table computed by {@link HashCommon#arraySize} for at least
     * {@link #OPEN_HASH_MIN_EXPECTED_SIZE} entries, plus one extra slot.
     */
    private static long openHashTableLength(int size) {
        int expected = Math.max(size, OPEN_HASH_MIN_EXPECTED_SIZE);
        return HashCommon.arraySize(expected, OPEN_HASH_LOAD_FACTOR) + 1L;
    }

    /**
     * Shared estimate for fastutil open-addressing maps: self size plus aligned key and value
     * arrays. Values referenced by the map are not counted.
     *
     * @param map map to measure, may be {@code null}
     * @param selfSize instance size of the concrete map class
     * @param keySize size of one key slot in bytes
     * @param valueSize size of one value slot in bytes
     * @return estimated size, or {@code 0} when {@code map} is {@code null}
     */
    private static long openHashMapSize(@Nullable Map<?, ?> map, long selfSize, long keySize, long valueSize) {
        if (map == null) {
            return 0;
        }

        long tableLength = openHashTableLength(map.size());
        return selfSize
            + arraySize(tableLength, keySize)
            + arraySize(tableLength, valueSize);
    }

    /**
     * Without values.
     */
    public static long long2ObjectOpenHashMapSize(@Nullable Long2ObjectOpenHashMap<?> map) {
        if (map == null) {
            return 0;
        }

        long tableLength = openHashTableLength(map.size());
        return Long2ObjectOpenHashMap_SELF_SIZE
            + arraySize(tableLength, LONG_SIZE)
            + arraySize(tableLength, OBJECT_POINTER_SIZE);
    }

    /**
     * Self size of {@link LongOpenHashSet} plus its aligned key array. Unlike the map
     * estimators, no minimum table size is applied here.
     *
     * @return estimated size, or {@code 0} when {@code set} is {@code null}
     */
    public static long longOpenHashSetSize(@Nullable LongSet set) {
        if (set == null) {
            return 0;
        }

        long tableLength = HashCommon.arraySize(set.size(), OPEN_HASH_LOAD_FACTOR) + 1;
        return LongOpenHashSet_SELF_SIZE
            + arraySize(tableLength, LONG_SIZE);
    }

    /**
     * Map size as in {@link #long2ObjectOpenHashMapSize} plus the size reported by every value.
     *
     * @return estimated size, or {@code 0} when {@code map} is {@code null}
     */
    public static <A extends MemMeasurable> long long2ObjectOpenHashMapSizeWithContent(@Nullable Long2ObjectOpenHashMap<A> map) {
        if (map == null) {
            return 0;
        }

        long valuesSize = 0;
        for (A value : map.values()) {
            valuesSize += MemMeasurable.memorySizeOfNullable(value);
        }
        return long2ObjectOpenHashMapSize(map) + valuesSize;
    }

    /**
     * Self size of the EnumMap, its value array (one reference slot per enum constant of
     * {@code clazz}) and the size reported by every value.
     *
     * @param clazz key enum type, used to determine the number of slots
     * @param map map to measure, may be {@code null}
     * @return estimated size, or {@code 0} when {@code map} is {@code null}
     */
    public static <K extends Enum<K>, A extends MemMeasurable> long enumMapSizeWithContent(Class<K> clazz, @Nullable EnumMap<K, A> map) {
        if (map == null) {
            return 0;
        }

        // EnumMap keeps its values in an Object[], so each slot is a reference, not an int.
        long total = ENUM_MAP_SELF_SIZE + arraySize(clazz.getEnumConstants().length, OBJECT_POINTER_SIZE);
        for (A value : map.values()) {
            // EnumMap permits null values.
            total += MemMeasurable.memorySizeOfNullable(value);
        }
        return total;
    }

    /**
     * Without values.
     */
    public static long int2ObjectOpenHashMapSize(@Nullable Int2ObjectOpenHashMap<?> map) {
        return openHashMapSize(map, Int2ObjectOpenHashMap_SELF_SIZE, INT_SIZE, OBJECT_POINTER_SIZE);
    }

    /**
     * Without values.
     */
    public static long object2IntOpenHashMapSize(@Nullable Object2IntOpenHashMap<?> map) {
        return openHashMapSize(map, Object2IntOpenHashMap_SELF_SIZE, OBJECT_POINTER_SIZE, INT_SIZE);
    }

    /**
     * Without values.
     */
    public static long object2ObjectOpenHashMapSize(@Nullable Object2ObjectOpenHashMap<?, ?> map) {
        return openHashMapSize(map, Object2ObjectOpenHashMap_SELF_SIZE, OBJECT_POINTER_SIZE, OBJECT_POINTER_SIZE);
    }

    /**
     * Estimate based on the layout of {@link Int2LongOpenHashMap}: self size plus aligned
     * int key and long value arrays.
     *
     * @return estimated size, or {@code 0} when {@code map} is {@code null}
     */
    public static long int2LongMapSize(@Nullable Int2LongMap map) {
        return openHashMapSize(map, Int2LongOpenHashMap_SELF_SIZE, INT_SIZE, LONG_SIZE);
    }

    /**
     * Without values.
     */
    public static long hashMapSize(Map<?, ?> map) {
        return hashMapSize(map.size());
    }

    /**
     * Without values.
     *
     * @param size number of entries in the map
     * @return self size of a HashMap plus {@link #HASH_MAP_NODE_SIZE} per entry
     */
    public static long hashMapSize(int size) {
        return HASH_MAP_SELF_SIZE + (long) HASH_MAP_NODE_SIZE * size;
    }

    /**
     * Estimates a hash set as a HashMap with the same number of entries; elements are not counted.
     */
    public static long hashSetSize(Set<?> set) {
        // Delegating keeps the per-node multiplication in long arithmetic.
        return hashMapSize(set.size());
    }

    /**
     * Map size as in {@link #int2ObjectOpenHashMapSize} plus the size reported by every value.
     *
     * @return estimated size, or {@code 0} when {@code map} is {@code null}
     */
    public static <A extends MemMeasurable> long int2ObjectOpenHashMapSizeWithContent(@Nullable Int2ObjectOpenHashMap<A> map) {
        if (map == null) {
            return 0;
        }

        long valuesSize = 0;
        for (A value : map.values()) {
            valuesSize += MemMeasurable.memorySizeOfNullable(value);
        }
        return int2ObjectOpenHashMapSize(map) + valuesSize;
    }

}
