package ru.yandex.msearch.util;

import java.io.Closeable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.jctools.queues.MpmcArrayQueue;

import ru.yandex.unsafe.NativeMemory2;
import ru.yandex.unsafe.NativeMemory2.NativeMemoryAllocator;

/**
 * Concurrent open-addressing hash map from {@code long} keys to {@code int}
 * values whose key and value tables live in native memory obtained from
 * {@link NativeMemoryAllocator}.
 *
 * <p>Collisions are resolved by linear probing. Removing a key leaves a
 * tombstone ({@code SENTINEL}) behind so probe chains stay intact. Tombstones
 * are compacted once their count reaches a quarter of the table. The table is
 * doubled once half of its slots hold live keys.
 *
 * <p>The key values {@code 0} and {@code -1} are reserved as the "empty" and
 * "tombstone" markers and cannot be stored. Values are returned widened to
 * {@code long}; {@link #NULL_VALUE} means "absent".
 *
 * <p>Lookups, insertions and removals run under the read side of a
 * {@link ReentrantReadWriteLock} and coordinate through per-slot CAS
 * operations. Growth, tombstone compaction and {@link #close()} take the write
 * lock.
 *
 * <p>Subclasses may override {@link #keyHashCode}, {@link #keysEquals} and
 * {@link #freeKey}. Removed keys are not passed to {@link #freeKey} right away.
 * They are parked in a bounded queue and freed when the queue overflows or on
 * {@link #close()}. Presumably this is so concurrent readers still comparing
 * against a just-removed key do not touch freed memory (verify against
 * subclasses).
 */
public class NativePositiveLong2IntHashMap implements Closeable {
    /** Returned by {@link #get} and {@link #remove} when the key is absent. */
    public static final long NULL_VALUE = Long.MIN_VALUE;

    private static final NativeMemoryAllocator allocator =
        NativeMemoryAllocator.get("NativePositiveLongHashMap");

    private static final int DEFAULT_SIZE = 1024;
    // Smallest power-of-two table for which both the growth threshold
    // (length / 2) and the tombstone threshold (length / 4) are non-zero.
    private static final int MIN_CAPACITY = 4;
    // Largest power-of-two table size representable as an int.
    private static final int MAX_CAPACITY = 1 << 30;
    // Native bytes per slot: one long in the key table, one in the value table.
    private static final int BYTES_PER_SLOT = 16;
    // Key marker for a removed entry; probing continues past it.
    private static final long SENTINEL = -1;
    // Key marker for a never-used slot; probing stops here.
    private static final long NULL = 0;
    // Capacity of the deferred key-free queue (original note: "Max concurrent
    // threads").
    private static final int DELETE_QUEUE_SIZE = 256;

    // States of the rehashing flag: idle; rehash requested (put() waits before
    // taking the read lock); rehash running under the write lock.
    private static final int NO_REHASH = 0;
    private static final int NEED_REHASH = 1;
    private static final int REHASHING = 2;

    // Results of tryPut(): key already present; inserted; rehash pending.
    private static final byte FALSE = 0;
    private static final byte TRUE = 1;
    private static final byte RETRY = 2;

    // Read lock: lookups/updates on the current tables.
    // Write lock: table replacement, compaction, close.
    private final ReentrantReadWriteLock arrayLock =
        new ReentrantReadWriteLock();
    // Key table (NULL / SENTINEL / key); null after close().
    private NativeLongArray array;
    // Value table, parallel to 'array'; NULL_VALUE marks "no value".
    private NativeLongArray values;
    // Number of tombstones in the current table.
    private final AtomicInteger sentinels = new AtomicInteger(0);
    // Number of live keys in the current table.
    private final AtomicInteger arrayFill = new AtomicInteger(0);
    private final AtomicInteger rehashing = new AtomicInteger(NO_REHASH);
    // Live-key count that triggers growth (length / 2).
    private int halfSize;
    // Tombstone count that triggers compaction (length / 4).
    private int maxSentinels;
    // Current table size; always a power of two.
    private int length;
    // Removed keys awaiting freeKey().
    private final MpmcArrayQueue<Long> deleteQueue =
        new MpmcArrayQueue<>(DELETE_QUEUE_SIZE);

    /**
     * Fixed-size array of longs in native memory with volatile element
     * access and bounds checking.
     */
    private static class NativeLongArray {
        private final NativeMemory2 array;
        private final int length;

        /**
         * Allocates {@code length} longs and initializes each to {@code fill}.
         */
        public NativeLongArray(final int length, final long fill) {
            this.length = length;
            array = allocator.allocLong(length);
            if (fill == 0) {
                array.fill((byte) 0);
            } else {
                for (int i = 0; i < length; i++) {
                    array.setLong(i, fill);
                }
            }
        }

        /** Releases the native memory; the array must not be used afterwards. */
        public void free() {
            array.free();
        }

        /** Returns {@code idx} unchanged if it is in bounds, otherwise throws. */
        private final long index(long idx) {
            if (idx < 0 || idx >= length) {
                throw new IndexOutOfBoundsException("index " + idx);
            }
            return idx;
        }

        public final long get(final int idx) {
            return array.getLongVolatile(index(idx));
        }

        public final void set(final int idx, final long value) {
            array.putLongVolatile(index(idx), value);
        }

        public final boolean compareAndSet(final int idx,
            final long expect, final long update)
        {
            return array.compareAndSetLong(index(idx), expect, update);
        }

        public final int length() {
            return length;
        }
    }

    /** Creates a map with the default table size of 1024 slots. */
    public NativePositiveLong2IntHashMap() {
        this(DEFAULT_SIZE);
    }

    /**
     * Creates a map with an initial table of at least {@code length} slots.
     *
     * @param length requested slot count; it is rounded up to a power of two
     *               (minimum 4) because slot positions are computed by masking
     * @throws IllegalArgumentException if {@code length} is negative or
     *                                  exceeds 2^30
     */
    public NativePositiveLong2IntHashMap(final int length) {
        this.length = tableSizeFor(length);
        array = new NativeLongArray(this.length, NULL);
        values = new NativeLongArray(this.length, NULL_VALUE);
        halfSize = this.length >> 1;
        maxSentinels = halfSize >> 1;
    }

    /**
     * Rounds a requested capacity up to the next power of two, clamped to at
     * least {@link #MIN_CAPACITY}.
     */
    private static int tableSizeFor(final int requested) {
        if (requested < 0 || requested > MAX_CAPACITY) {
            throw new IllegalArgumentException(
                "Invalid capacity: " + requested);
        }
        if (requested <= MIN_CAPACITY) {
            return MIN_CAPACITY;
        }
        return Integer.highestOneBit(requested - 1) << 1;
    }

    /**
     * Frees every stored and every parked key via {@link #freeKey}, then
     * releases the native tables. Subsequent calls do nothing. The map must
     * not be used for other operations after closing.
     */
    @Override
    public void close() {
        arrayLock.writeLock().lock();
        try {
            if (array == null) {
                return;
            }
            for (int i = 0; i < array.length(); i++) {
                final long key = array.get(i);
                if (key != NULL && key != SENTINEL) {
                    freeKey(key);
                }
            }
            Long key;
            while ((key = deleteQueue.poll()) != null) {
                freeKey(key);
            }
            // 'array' and 'values' are always allocated and freed together.
            array.free();
            array = null;
            values.free();
            values = null;
        } finally {
            arrayLock.writeLock().unlock();
        }
    }

    /**
     * Scrambles the bits of a key hash so that keys differing only in a few
     * bits spread across the masked table positions.
     */
    private static final int hash(int h) {
        h += (h <<  15) ^ 0xffffcd7d;
        h ^= (h >>> 10);
        h += (h <<   3);
        h ^= (h >>>  6);
        h += (h <<   2) + (h << 14);
        return h ^ (h >>> 16);
    }

    /**
     * Maps a hash (or a probe position plus one) to a slot of {@code array}.
     * Relies on the table length being a power of two.
     */
    private final int posForHash(final int hashCode,
        final NativeLongArray array)
    {
        return hashCode & (array.length() - 1);
    }

    /** Hash code of a key; defaults to its low 32 bits. */
    protected int keyHashCode(final long key) {
        return (int)key;
    }

    /** Key equality; defaults to numeric equality. */
    protected boolean keysEquals(final long a, final long b) {
        return a == b;
    }

    /** Hook for releasing resources tied to a key; default does nothing. */
    protected void freeKey(final long key) {
    }

    /**
     * Inserts {@code key} mapped to {@code value} if the key is not present.
     * An existing mapping is left unchanged.
     *
     * @return {@code true} if the key was inserted, {@code false} if it was
     *         already present
     * @throws IllegalArgumentException if {@code key} is {@code 0} or
     *                                  {@code -1} (reserved slot markers)
     */
    public boolean put(final long key, final int value) {
        if (key == NULL || key == SENTINEL) {
            throw new IllegalArgumentException("Reserved key: " + key);
        }
        byte result;
        do {
            // A requested rehash is waiting for the write lock; don't compete
            // for the read lock until it has been taken.
            while (rehashing.get() == NEED_REHASH) {
                Thread.yield();
            }
            result = tryPut(key, value);
        } while (result == RETRY);
        return result == TRUE;
    }

    /**
     * Attempts to store {@code key}/{@code value} into slot {@code pos},
     * which the caller observed holding the marker {@code expectedKey}
     * ({@code NULL} or {@code SENTINEL}). Must be called under the read lock.
     *
     * <p>The value is written first. A successful CAS from {@code NULL_VALUE}
     * gives this thread exclusive ownership of the slot, because every other
     * writer also has to win that CAS first. If the key CAS then fails (the
     * marker changed between the caller's check and the value CAS, e.g. NULL
     * became SENTINEL), the value is rolled back. Otherwise the slot would
     * stay owned forever and every later put on it would spin.
     *
     * @return {@code true} if the key and value were published
     */
    private boolean claimSlot(final int pos, final long expectedKey,
        final long key, final int value)
    {
        if (!values.compareAndSet(pos, NULL_VALUE, value)) {
            return false;
        }
        if (array.compareAndSet(pos, expectedKey, key)) {
            return true;
        }
        values.set(pos, NULL_VALUE);
        return false;
    }

    /**
     * Single insertion attempt under the read lock. Triggers a rehash (after
     * releasing the lock) when this insertion brings the live-key count to
     * the growth threshold.
     *
     * @return {@code TRUE} if inserted, {@code FALSE} if the key already
     *         exists, {@code RETRY} if a rehash started and the caller must
     *         try again
     */
    private byte tryPut(final long key, final int value) {
        int fill = 0;
        arrayLock.readLock().lock();
        try {
            final int hashCode = hash(keyHashCode(key));
            int hashPos = posForHash(hashCode, array);
            while (true) {
                if (rehashing.get() != NO_REHASH) {
                    return RETRY;
                }
                while (array.get(hashPos) == NULL) {
                    if (claimSlot(hashPos, NULL, key, value)) {
                        fill = arrayFill.incrementAndGet();
                        return TRUE;
                    }
                    // Another thread owns this slot's value; let it publish
                    // its key (or roll back) before re-checking.
                    Thread.yield();
                }
                // Slots never go back to NULL while the read lock is held,
                // so 'o' is a key or SENTINEL here.
                long o = array.get(hashPos);
                if (o == SENTINEL) {
                    final int savedPos = hashPos;
                    // Scan to the end of the chain so a duplicate stored past
                    // this tombstone is detected before reusing it.
                    while (rehashing.get() == NO_REHASH) {
                        hashPos = posForHash(hashPos + 1, array);
                        o = array.get(hashPos);
                        if (o == NULL) {
                            break;
                        }
                        if (o != SENTINEL && keysEquals(o, key)) {
                            return FALSE;
                        }
                    }
                    if (rehashing.get() != NO_REHASH) {
                        continue;
                    }
                    // No duplicate found: try to reuse the tombstone.
                    while (array.get(savedPos) == SENTINEL) {
                        if (claimSlot(savedPos, SENTINEL, key, value)) {
                            fill = arrayFill.incrementAndGet();
                            return TRUE;
                        }
                        Thread.yield();
                    }
                    // The tombstone was taken by another insert; restart
                    // from the home position.
                    hashPos = posForHash(hashCode, array);
                } else {
                    if (keysEquals(o, key)) {
                        return FALSE;
                    }
                    hashPos = posForHash(hashPos + 1, array);
                }
            }
        } finally {
            // Read the threshold while still holding the lock so it belongs
            // to the table this insertion went into.
            final boolean grow = fill == halfSize;
            arrayLock.readLock().unlock();
            if (grow) {
                rehash();
            }
        }
    }

    /**
     * Inserts {@code key}/{@code value} into the first empty slot of its
     * probe chain in the given tables, without duplicate checks. Used when
     * rebuilding a table under the write lock.
     */
    private void putNoEquals(
        final long key,
        final long value,
        final NativeLongArray array,
        final NativeLongArray values)
    {
        int hashCode = hash(keyHashCode(key));
        int hashPos = posForHash(hashCode, array);
        while (true) {
            if (array.compareAndSet(hashPos, NULL, key)) {
                values.set(hashPos, value);
                return;
            }
            hashPos = posForHash(hashPos + 1, array);
        }
    }

    /**
     * Removes {@code key}, leaving a tombstone in its slot. The key itself is
     * queued for a deferred {@link #freeKey}. Triggers tombstone compaction
     * (after releasing the lock) when the tombstone count reaches its
     * threshold.
     *
     * @return the removed value, or {@link #NULL_VALUE} if absent
     */
    public final long remove(final long key) {
        boolean purge = false;
        arrayLock.readLock().lock();
        try {
            final int hashCode = hash(keyHashCode(key));
            int hashPos = posForHash(hashCode, array);
            while (true) {
                final long o = array.get(hashPos);
                if (o == NULL) {
                    return NULL_VALUE;
                }
                if (o != SENTINEL && keysEquals(o, key)) {
                    if (array.compareAndSet(hashPos, o, SENTINEL)) {
                        arrayFill.decrementAndGet();
                        purge = sentinels.incrementAndGet() == maxSentinels;
                        // Release the value slot first: inserts waiting to
                        // reuse this tombstone spin until it is NULL_VALUE.
                        final long value = values.get(hashPos);
                        values.set(hashPos, NULL_VALUE);
                        deferFree(o);
                        return value;
                    }
                    // Lost a race on this slot; re-read it without advancing.
                    continue;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
            if (purge) {
                purgeSentinels();
            }
        }
    }

    /**
     * Parks a removed key in the delete queue. While the queue is full, the
     * oldest parked keys are passed to {@link #freeKey}.
     */
    private void deferFree(final long key) {
        while (!deleteQueue.offer(key)) {
            final Long old = deleteQueue.poll();
            if (old != null) {
                freeKey(old);
            }
        }
    }

    /**
     * Rebuilds the tables under the write lock. The size doubles until the
     * live keys fill less than half of it, and tombstones are dropped.
     * The rehashing flag makes in-flight inserts return RETRY and release
     * their read locks.
     */
    private void rehash() {
        rehashing.set(NEED_REHASH);
        arrayLock.writeLock().lock();
        rehashing.set(REHASHING);
        try {
            int newLength = length;
            while (arrayFill.get() >= newLength >> 1) {
                newLength <<= 1;
            }
            final NativeLongArray newArray =
                new NativeLongArray(newLength, NULL);
            final NativeLongArray newValues =
                new NativeLongArray(newLength, NULL_VALUE);
            arrayFill.set(0);
            sentinels.set(0);
            for (int i = 0; i < array.length(); i++) {
                final long key = array.get(i);
                if (key == NULL || key == SENTINEL) {
                    continue;
                }
                putNoEquals(key, values.get(i), newArray, newValues);
                arrayFill.incrementAndGet();
            }
            final NativeLongArray toFree = array;
            final NativeLongArray valuesToFree = values;
            // Commit size-dependent state only after the new tables exist,
            // so a failed allocation leaves the map consistent.
            array = newArray;
            values = newValues;
            length = newLength;
            halfSize = newLength >> 1;
            maxSentinels = halfSize >> 1;
            toFree.free();
            valuesToFree.free();
        } finally {
            arrayLock.writeLock().unlock();
            rehashing.set(NO_REHASH);
        }
    }

    /**
     * Fills the hole at {@code delPos} by moving later entries of the same
     * run back toward their home slots, so no entry becomes unreachable once
     * the hole is empty. Must be called under the write lock.
     */
    private void closeDeletion(final int delPos) {
        int d = delPos;
        // Adapted from Knuth Section 6.4 Algorithm R

        // Look for items to swap into newly vacated slot
        // starting at index immediately following deletion,
        // and continuing until a null slot is seen, indicating
        // the end of a run of possibly-colliding keys.
        long key;
        int i = posForHash(d + 1, array);
        key = array.get(i);
        while (key != NULL) {
            // The following test triggers if the item at slot i (which
            // hashes to be at slot r) should take the spot vacated by d.
            // If so, we swap it in, and then continue with d now at the
            // newly vacated i.  This process will terminate when we hit
            // the null slot at the end of this run.
            // The test is messy because we are using a circular table.
            // Tombstones are not moved here; they stay part of the run and
            // are cleared by their own purge iteration.
            if (key != SENTINEL) {
                final int r = posForHash(hash(keyHashCode(key)), array);
                if ((i < r && (r <= d || d <= i)) || (r <= d && d <= i)) {
                    array.set(d, key);
                    values.set(d, values.get(i));
                    array.set(i, NULL);
                    values.set(i, NULL_VALUE);
                    d = i;
                }
            }
            i = posForHash(i + 1, array);
            key = array.get(i);
        }
    }

    /**
     * Clears every tombstone under the write lock, compacting the affected
     * runs with {@link #closeDeletion}, and resets the tombstone count.
     */
    private void purgeSentinels() {
        arrayLock.writeLock().lock();
        try {
            for (int i = 0; i < length; i++) {
                if (array.get(i) == SENTINEL) {
                    array.set(i, NULL);
                    values.set(i, NULL_VALUE);
                    closeDeletion(i);
                }
            }
            sentinels.set(0);
        } finally {
            arrayLock.writeLock().unlock();
        }
    }

    /** Returns whether {@code key} is currently present. */
    public final boolean contains(final long key) {
        arrayLock.readLock().lock();
        try {
            final int hashCode = hash(keyHashCode(key));
            int hashPos = posForHash(hashCode, array);
            while (true) {
                final long o = array.get(hashPos);
                if (o == NULL) {
                    return false;
                }
                if (o != SENTINEL && keysEquals(o, key)) {
                    return true;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
        }
    }

    /**
     * Returns the value mapped to {@code key}, widened to {@code long}, or
     * {@link #NULL_VALUE} if absent.
     */
    public final long get(final long key) {
        arrayLock.readLock().lock();
        try {
            final int hashCode = hash(keyHashCode(key));
            int hashPos = posForHash(hashCode, array);
            while (true) {
                final long o = array.get(hashPos);
                if (o == NULL) {
                    return NULL_VALUE;
                }
                if (o != SENTINEL && keysEquals(o, key)) {
                    final long value = values.get(hashPos);
                    // The slot may be mid-update (being removed or reused);
                    // re-read the same slot until key and value agree.
                    if (value == NULL_VALUE || array.get(hashPos) != o) {
                        continue;
                    }
                    return value;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
        }
    }

    /**
     * Returns the native memory held by the key and value tables, clamped to
     * {@link Integer#MAX_VALUE}.
     */
    public int sizeInBytes() {
        return (int) Math.min(
            Integer.MAX_VALUE,
            (long) length * BYTES_PER_SLOT);
    }

    public static void main(String[] args) throws Exception {
        final int TEST_MAX;
        if (args.length > 0) {
            TEST_MAX = Integer.parseInt(args[0]);
        } else {
            TEST_MAX = 1024 * 120;
        }
//        final int cpuCount = 1;
        final int cpuCount = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
            cpuCount,
            cpuCount,
            1,
            TimeUnit.MINUTES,
            new ArrayBlockingQueue<Runnable>(cpuCount));
        List<Future<Void>> futures = new ArrayList<Future<Void>>(cpuCount);

        final int perThread = (TEST_MAX / cpuCount);
        System.err.println("PerThread: " + perThread);
        final NativePositiveLong2IntHashMap testMap =
            new NativePositiveLong2IntHashMap(32);
        final ConcurrentHashMap<Long, Integer> stableMap =
            new ConcurrentHashMap<Long, Integer>();

        for (int i = 0; i < cpuCount; i++) {
            int start = perThread * i;
            int count = perThread;
            Future<Void> f = executor.submit(
                new TestChunk(start, count, testMap, stableMap));
            futures.add(f);
        }
        try {
            while (futures.size() > 0) {
                Iterator<Future<Void>> iter = futures.iterator();
                while (iter.hasNext()) {
                    Future<Void> f = iter.next();
                    try {
                        f.get(100, TimeUnit.MILLISECONDS);
                        iter.remove();
                    } catch (TimeoutException ign) {
                    } catch (Exception e) {
                        e.printStackTrace();
                        iter.remove();
                    }
                }
            }
            NativeMemory2.printStats();
        } finally {
            testMap.close();
            executor.shutdownNow();
        }
        NativeMemory2.printStats();
    }

    private static class TestChunk implements Callable<Void> {
        final int start;
        final int count;
        final NativePositiveLong2IntHashMap testMap;
        final ConcurrentHashMap<Long, Integer> stableMap;
        public TestChunk(final int start, final int count,
            final NativePositiveLong2IntHashMap testMap,
            final ConcurrentHashMap<Long, Integer> stableMap) {
            this.start = start + 1;
            this.count = count;
            this.testMap = testMap;
            this.stableMap = stableMap;
        }

        @Override
        public Void call() {
            long[] testData = new long[count];
            for (int i = 0; i < count; i++) {
                testData[i] = i + start;
            }
            //shuffle
            for (int i = count - 1; i > 3; i--) {
                int idx = (int)(Math.random() * (i));
                long tmp = testData[idx];
                testData[idx] = testData[i];
                testData[i] = tmp;
            }
            System.err.println(Thread.currentThread().getId() + ": test data generated");

            //sequence tests

            System.err.println(Thread.currentThread().getId() +": Add all");
            for (int i = 0; i < count; i++) {
                if (!testMap.put(testData[i], (int) testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Add fail: " + i
                        + ", value=" + testData[i]);
                }
            }
            System.err.println(Thread.currentThread().getId() + ": test all");
            for (int i = 0; i < count; i++) {
                if (!testMap.contains(testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Contains fail: " + i
                        + ", value=" + testData[i]);
                }
                if (testMap.get(testData[i]) != testData[i]) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Get fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            //remove half
            System.err.println(Thread.currentThread().getId() + ": remove half");
            for (int i = 0; i < count >> 1; i++) {
                if (testMap.remove(testData[i]) != testData[i]) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Can't remove: " + i + ", value="
                        + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }

            System.err.println(Thread.currentThread().getId() + ": test half");
            //test
            for (int i = count >> 1; i < count; i++) {
                if (!testMap.contains(testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Contains fail: " + i
                        + ", value=" + testData[i]);
                }
                if (testMap.get(testData[i]) != testData[i]) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Get fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(Thread.currentThread().getId() + ": add half");
            //add half
            for (int i = 0; i < count >> 1; i++) {
                if (!testMap.put(testData[i], (int) testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Add failed: " + i + ", value="
                        + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }

            System.err.println(Thread.currentThread().getId() + ": add conflicted");
            //add half
            for (int i = 0; i < count >> 1; i++) {
                if (testMap.put(testData[i], (int) testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Add failed: " + i + ", value="
                        + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }

            System.err.println(Thread.currentThread().getId() + ": remove half");
            //remove half
            for (int i = 0; i < count >> 1; i++) {
                if (testMap.remove(testData[i]) != testData[i]) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Souldn't remove: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }

            System.err.println(Thread.currentThread().getId() + ": remove half again");
            //remove half
            for (int i = 0; i < count >> 1; i++) {
                if (testMap.remove(testData[i]) != NULL_VALUE) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Souldn't remove: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }

            System.err.println(Thread.currentThread().getId() + ": test half");
            //test
            for (int i = count >> 1; i < count; i++) {
                if (!testMap.contains(testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Contains fail: " + i
                        + ", value=" + testData[i]);
                }
                if (testMap.get(testData[i]) != testData[i]) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Get fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(Thread.currentThread().getId() + ": test second half");
            //test
            for (int i = 0; i < count >> 1; i++) {
                if (testMap.contains(testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Contains fail: " + i
                        + ", value=" + testData[i]);
                }
                long value = testMap.get(testData[i]);
                if (value != NULL_VALUE) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Get fail: " + i
                        + ", expected=" + NULL_VALUE + ", value=" + value);
                }
            }

            //add half
            System.err.println(Thread.currentThread().getId() + ": add half");
            for (int i = 0; i < count >> 1; i++) {
                if (!testMap.put(testData[i], (int) testData[i])) {
                    throw new RuntimeException(Thread.currentThread().getId() + ": Add failed: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testMap.contains(testData[i]));
                }
            }
            System.err.println(Thread.currentThread().getId() + ": finished");
            return null;
        }
    }
}
