package ru.yandex.solomon.coremon.meta.db;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Arrays;
import java.util.concurrent.locks.StampedLock;
import java.util.stream.IntStream;

import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;


/**
 * A growable table of {@code int} values indexed by 1-based shard id.
 *
 * <p>Shard {@code id} is stored at array index {@code id - 1}. The backing array grows on demand
 * whenever a shard id beyond the current size is read, written or incremented, so {@link #array()}
 * and {@link #stream()} may contain zero-valued slots for shards that were only queried.
 *
 * <p>Element reads and writes use {@link VarHandle} volatile/atomic accesses. Replacement of the
 * backing array (growth and bulk updates) happens only under the write side of a
 * {@link StampedLock}; the array never shrinks. Single-element operations first try an optimistic
 * read and fall back to the read lock if a writer invalidated the stamp.
 *
 * @author Sergey Polovko
 */
public class ShardIdToAtomicInteger {
    private static final VarHandle ARRAY_ACCESS = MethodHandles.arrayElementVarHandle(int[].class);
    private final StampedLock lock = new StampedLock();
    // Replaced only while holding the write lock; each replacement is at least as long as before.
    private int[] array;

    public ShardIdToAtomicInteger() {
        this.array = new int[0];
    }

    /**
     * Stores {@code value} for the given shard, growing the table if needed.
     *
     * @param shardId 1-based shard id
     * @param value value to store
     * @throws ArrayIndexOutOfBoundsException if {@code shardId <= 0}
     */
    public void set(int shardId, int value) {
        // Read the field exactly once while optimistic: two plain reads of a racily-updated field
        // are not guaranteed to observe the same array, so the bound check and the write must use
        // the same snapshot.
        long stamp = lock.tryOptimisticRead();
        int[] snapshot = array;
        if (snapshot.length >= shardId) {
            volatileSet(snapshot, shardId, value);
            if (lock.validate(stamp)) {
                return;
            }
        }

        // slow path: the array was too small, or it was replaced while we were writing
        stamp = readLockWithCapacity(shardId);
        try {
            volatileSet(array, shardId, value);
        } finally {
            lock.unlockRead(stamp);
        }
    }

    /**
     * Stores every (shardId, value) entry of {@code map}, growing the table to the largest key.
     *
     * @param map shard id to value mapping; keys are 1-based shard ids
     * @throws ArrayIndexOutOfBoundsException if the map contains a key {@code <= 0}
     */
    public void set(Int2IntOpenHashMap map) {
        int maxShardId = maxShardId(map);
        long stamp = lock.writeLock();
        try {
            growLocked(maxShardId);

            var it = map.int2IntEntrySet().fastIterator();
            while (it.hasNext()) {
                var entry = it.next();
                int shardId = entry.getIntKey();
                int value = entry.getIntValue();
                array[idx(shardId)] = value;
            }
        } finally {
            lock.unlockWrite(stamp);
        }
    }

    /**
     * Returns the value stored for the given shard (0 if never set). Note that reading an id
     * beyond the current size grows the table.
     *
     * @param shardId 1-based shard id
     * @return the stored value
     * @throws ArrayIndexOutOfBoundsException if {@code shardId <= 0}
     */
    public int get(int shardId) {
        // Same single-snapshot rule as in set(int, int).
        long stamp = lock.tryOptimisticRead();
        int[] snapshot = array;
        if (snapshot.length >= shardId) {
            int value = volatileGet(snapshot, shardId);
            if (lock.validate(stamp)) {
                return value;
            }
        }

        // slow path
        stamp = readLockWithCapacity(shardId);
        try {
            return volatileGet(array, shardId);
        } finally {
            lock.unlockRead(stamp);
        }
    }

    /**
     * Atomically increments the value of the given shard, growing the table if needed.
     *
     * @param shardId 1-based shard id
     * @return the value after the increment
     * @throws ArrayIndexOutOfBoundsException if {@code shardId <= 0}
     */
    public int incrementAndGet(int shardId) {
        // Holding the read lock keeps the array from being replaced during the atomic add, so the
        // increment cannot be lost into an array that is about to be discarded.
        long stamp = readLockWithCapacity(shardId);
        try {
            return (int) ARRAY_ACCESS.getAndAdd(array, idx(shardId), 1) + 1;
        } finally {
            lock.unlockRead(stamp);
        }
    }

    /** Returns a stream over a snapshot copy of all values, ordered by shard id starting at 1. */
    public IntStream stream() {
        return Arrays.stream(array());
    }

    /** Returns a copy of all values; element {@code i} holds the value of shard {@code i + 1}. */
    public int[] array() {
        long stamp = lock.tryOptimisticRead();
        int[] copy = array.clone();
        if (lock.validate(stamp)) {
            return copy;
        }

        stamp = lock.readLock();
        try {
            return array.clone();
        } finally {
            lock.unlockRead(stamp);
        }
    }

    /**
     * Acquires the read lock, first growing the array (under the write lock) if it cannot hold
     * {@code capacity} elements. The caller must release the returned stamp via
     * {@code unlockRead}.
     */
    private long readLockWithCapacity(int capacity) {
        long stamp = lock.readLock();
        // StampedLock cannot upgrade read -> write without a tryConvert race, so release and grow.
        // The array never shrinks, so this loop runs at most once more after a successful expand.
        while (array.length < capacity) {
            lock.unlockRead(stamp);
            expand(capacity);
            stamp = lock.readLock();
        }
        return stamp;
    }

    private static void volatileSet(int[] target, int shardId, int value) {
        ARRAY_ACCESS.setVolatile(target, shardId - 1, value);
    }

    private static int volatileGet(int[] source, int shardId) {
        return (int) ARRAY_ACCESS.getVolatile(source, shardId - 1);
    }

    /** Grows the array to at least {@code capacity} elements, taking the write lock. */
    private void expand(int capacity) {
        long stamp = lock.writeLock();
        try {
            growLocked(capacity);
        } finally {
            lock.unlockWrite(stamp);
        }
    }

    /** Grows the array to at least {@code capacity} elements; caller must hold the write lock. */
    private void growLocked(int capacity) {
        if (array.length < capacity) {
            array = Arrays.copyOf(array, capacity);
        }
    }

    /** Returns the largest key of {@code map}, or 0 when the map has no positive keys. */
    private int maxShardId(Int2IntOpenHashMap map) {
        int max = 0;
        var it = map.keySet().iterator();
        while (it.hasNext()) {
            max = Math.max(max, it.nextInt());
        }

        return max;
    }

    /** Maps a 1-based shard id to its array index. */
    private int idx(int shardId) {
        return shardId - 1;
    }
}
