package ru.yandex.msearch.util;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLongArray;

import sun.misc.Unsafe;

public class AtomicBitsArray {
    private static final int LONG_BITS = 64;
    private static final int INTEGER_BITS = 32;

    private static final boolean DEBUG = false;
   // setup to use Unsafe.compareAndSwapInt for updates
    private final long[] array;
    private final int length;
    private final int nbits;
    private final long mask;
    private final int elementsPerLong;
    private final int elementsPerLongShift;

    private static final Unsafe unsafe;
    static {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            unsafe = (Unsafe) theUnsafe.get(null);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private static final int base = unsafe.arrayBaseOffset(long[].class);
    private static final int scale = unsafe.arrayIndexScale(long[].class);

    private long rawIndex(int i) {
        if (i < 0 || i >= array.length)
            throw new IndexOutOfBoundsException("index " + i);
        return base + i * scale;
    }

    public AtomicBitsArray(int length, int nbits) {
        if (!isPowerOfTwo(nbits)) {
            throw new IllegalArgumentException(
                "nbits must be power of two, was: " + nbits);
        }
        if (nbits > LONG_BITS) {
            throw new IllegalArgumentException(
                "nbits must be in range (1 - 32), was: " + nbits);
        }
        this.length = length;
        this.nbits = nbits;
        this.mask = ((long)1 << nbits) - 1;

        this.elementsPerLong = LONG_BITS / nbits;
        this.elementsPerLongShift = log2(elementsPerLong);
//        this.elementShift 
        int numLongs = (length / elementsPerLong) + 1;
        array = new long[numLongs];
        // must perform at least one volatile write to conform to JMM
        if (length > 0) {
            unsafe.putLongVolatile(array, rawIndex(0), 0);
        }
    }

    public int maxValuePerItem() {
        return MAX_VALUE(nbits);
    }

    private static final int MAX_VALUE(int bits) {
        return (1 << bits) - 1;
    }

    private static final int log2(int n) {
        return INTEGER_BITS - 1 - Integer.numberOfLeadingZeros(n);
    }

    private static final boolean isPowerOfTwo(int i) {
        return (i & (i - 1)) == 0;
    }

    /**
     * Returns the length of the array.
     *
     * @return the length of the array
     */
    public final int length() {
        return length;
    }

    public final int sizeInBytes() {
        return array.length * 8;
    }

    public final int numBits() {
        return nbits;
    }

    /**
     * Gets the current value at position {@code i}.
     *
     * @param i the index
     * @return the current value
     */
    public final int get(int i) {
        final int longIdx = i >> elementsPerLongShift;
        final int bitsShift = (int)(i & mask) * nbits;
        final long longValue = unsafe.getLongVolatile(array, rawIndex(longIdx));
        if (DEBUG) {
            System.out.println("Get("+i+"): longIdx=" + longIdx + ", bitsShift=" + bitsShift + ", longValue=" + Long.toHexString(longValue));
        }
        return (int)((longValue >> bitsShift) & mask);
    }

    /**
     * Sets the element at position {@code i} to the given value.
     *
     * @param i the index
     * @param newValue the new value
     */
    public final void set(int i, int newValue) {
        final int longIdx = i >> elementsPerLongShift;
        final int bitsShift = (int)(i & mask) * nbits;
        final long shiftedValue = (long)(newValue & mask) << bitsShift;
        final long clearMask = ~((long)mask << bitsShift);
        if (DEBUG) {
            System.out.println("Set("+i+"): longIdx=" + longIdx + ", bitsShift=" + bitsShift + ", shiftedValue=" + Long.toHexString(shiftedValue) + ", clearMask=" + Long.toHexString(clearMask));
        }
        while (true) {
            final long longValue =
                unsafe.getLongVolatile(array, rawIndex(longIdx));
            final long newLongValue = (longValue & clearMask) | shiftedValue;
            if (unsafe.compareAndSwapLong(array, rawIndex(longIdx),
                longValue, newLongValue))
            {
                return;
            }
        }
    }

    /**
     * Atomically sets the element at position {@code i} to the given
     * value and returns the old value.
     *
     * @param i the index
     * @param newValue the new value
     * @return the previous value
     */
    public final int getAndSet(int i, int newValue) {
        while (true) {
            int current = get(i);
            if (compareAndSet(i, current, newValue))
                return current;
        }
    }

    /**
     * Atomically sets the element at position {@code i} to the given
     * updated value if the current value {@code ==} the expected value.
     *
     * @param i the index
     * @param expect the expected value
     * @param update the new value
     * @return true if successful. False return indicates that
     * the actual value was not equal to the expected value.
     */
    public final boolean compareAndSet(int i, int expect, int update) {
        final int longIdx = i >> elementsPerLongShift;
        final int bitsShift = (int)(i & mask) * nbits;
        final long shiftedValue = (long)(update & mask) << bitsShift;
        final long clearMask = ~((long)mask << bitsShift);
        final long longValue =
            unsafe.getLongVolatile(array, rawIndex(longIdx));
        final long newLongValue = (longValue & clearMask) | shiftedValue;
        return unsafe.compareAndSwapLong(array, rawIndex(longIdx),
                longValue, newLongValue);
    }

    /**
     * Atomically sets the element at position {@code i} to the given
     * updated value if the current value {@code ==} the expected value.
     *
     * <p>May <a href="package-summary.html#Spurious">fail spuriously</a>
     * and does not provide ordering guarantees, so is only rarely an
     * appropriate alternative to {@code compareAndSet}.
     *
     * @param i the index
     * @param expect the expected value
     * @param update the new value
     * @return true if successful.
     */
    public final boolean weakCompareAndSet(int i, int expect, int update) {
        return compareAndSet(i, expect, update);
    }

    /**
     * Atomically increments by one the element at index {@code i}.
     *
     * @param i the index
     * @return the previous value
     */
    public final int getAndIncrement(int i) {
        while (true) {
            int current = get(i);
            int next = current + 1;
            if (compareAndSet(i, current, next))
                return current;
        }
    }

    /**
     * Atomically decrements by one the element at index {@code i}.
     *
     * @param i the index
     * @return the previous value
     */
    public final int getAndDecrement(int i) {
        while (true) {
            int current = get(i);
            int next = current - 1;
            if (compareAndSet(i, current, next))
                return current;
        }
    }

    /**
     * Atomically adds the given value to the element at index {@code i}.
     *
     * @param i the index
     * @param delta the value to add
     * @return the previous value
     */
    public final int getAndAdd(int i, int delta) {
        while (true) {
            int current = get(i);
            int next = current + delta;
            if (compareAndSet(i, current, next))
                return current;
        }
    }

    /**
     * Atomically increments by one the element at index {@code i}.
     *
     * @param i the index
     * @return the updated value
     */
    public final int incrementAndGet(int i) {
        while (true) {
            int current = get(i);
            int next = current + 1;
            if (compareAndSet(i, current, next))
                return next;
        }
    }

    /**
     * Atomically decrements by one the element at index {@code i}.
     *
     * @param i the index
     * @return the updated value
     */
    public final int decrementAndGet(int i) {
        while (true) {
            int current = get(i);
            int next = current - 1;
            if (compareAndSet(i, current, next))
                return next;
        }
    }

    /**
     * Atomically adds the given value to the element at index {@code i}.
     *
     * @param i the index
     * @param delta the value to add
     * @return the updated value
     */
    public final int addAndGet(int i, int delta) {
        while (true) {
            int current = get(i);
            int next = current + delta;
            if (compareAndSet(i, current, next))
                return next;
        }
    }

    /**
     * Returns the String representation of the current values of array.
     * @return the String representation of the current values of array.
     */
    public String toString() {
        if (array.length > 0) // force volatile read
            get(0);
        return Arrays.toString(array);
    }

    public static void main(String[] args) {
        System.out.println("Running test");
        int length = 100;
//        DEBUG = true;
        for (int nbits = 1; nbits <=32; nbits <<= 1) {
            System.out.println("NumBits: " + nbits);
            AtomicBitsArray aba = new AtomicBitsArray(length, nbits);
            //fill
            for (int i = 0; i < length; i++) {
                aba.set(i, 1);
            }
            //get
            for (int i = 0; i < length; i++) {
                if (aba.get(i) != 1) {
//                    DEBUG = true;
                    throw new RuntimeException("Get failed at pos: " + i + ", value=" + aba.get(i));
//                    System.out.println("
                }
            }
        }
    }

}
