package ru.yandex.msearch.util;

import java.nio.charset.StandardCharsets;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * A hash set of non-null keys that supports concurrent {@link #put},
 * {@link #remove} and {@link #contains} calls.
 *
 * <p>Keys are stored in a power-of-two sized {@link AtomicReferenceArray} and
 * located by linear probing. Individual slots are claimed and released with
 * CAS while the caller holds the read side of {@code arrayLock}; operations
 * that restructure the table ({@link #rehash()} grows it,
 * {@link #purgeSentinels()} compacts it) hold the write side.
 *
 * <p>{@link #remove} replaces a key with {@code SENTINEL} instead of
 * {@code null} so probe chains running through that slot stay unbroken.
 * Sentinels are never reused by {@link #put}; they are reclaimed in bulk once
 * their number reaches a quarter of the table length, and dropped by every
 * rehash. Unless the set was created in the private no-rehash mode, the table
 * is grown when the number of live keys reaches half of its length.
 *
 * <p>Keys must implement {@code hashCode}/{@code equals} consistently and must
 * not change while stored.
 */
public class ConcurrentHashSet<T> implements Iterable<T> {
    private static final int DEFAULT_SIZE = 1024;
    // Smallest table length; keeps maxSentinels >= 1 so purging can trigger.
    private static final int MIN_SIZE = 4;
    // Largest power-of-two length an array can have.
    private static final int MAX_SIZE = 1 << 30;
    // Placeholder left in a slot by remove() so probe chains stay unbroken.
    private static final Object SENTINEL = new Object();

    // States of `rehashing`; puts back off while a rehash is pending/running.
    private static final int NO_REHASH = 0;
    private static final int NEED_REHASH = 1;
    private static final int REHASHING = 2;

    // Results of tryPut().
    private static final byte FALSE = 0;
    private static final byte TRUE = 1;
    private static final byte RETRY = 2;

    private final ReentrantReadWriteLock arrayLock =
        new ReentrantReadWriteLock();
    // Slot contents: null, SENTINEL or a key. The reference itself is
    // replaced only under the write lock.
    private AtomicReferenceArray<Object> array;
    // Number of SENTINEL slots (bumped right after each successful remove).
    private final AtomicInteger sentinels = new AtomicInteger(0);
    // Number of live keys.
    private final AtomicInteger arrayFill = new AtomicInteger(0);
    private final AtomicInteger rehashing = new AtomicInteger(NO_REHASH);
    // When true the table never grows (used only by the self test in main).
    private final boolean noRehash;
    // The fields below change only under the write lock, so they are read
    // while holding either side of arrayLock.
    private int halfSize;
    private int maxSentinels;
    private int length;

    /** Creates a set with the default table length (1024). */
    public ConcurrentHashSet() {
        this(DEFAULT_SIZE);
    }

    /**
     * Creates a set with the given initial table length.
     *
     * @param length requested table length; rounded up to a power of two and
     *     to at least 4, because slot positions are computed by masking
     * @throws IllegalArgumentException if {@code length} is negative
     */
    public ConcurrentHashSet(final int length) {
        this(length, false);
    }

    private ConcurrentHashSet(final int length, final boolean noRehash) {
        if (length < 0) {
            throw new IllegalArgumentException("Negative length: " + length);
        }
        this.length = tableSizeFor(length);
        this.noRehash = noRehash;
        array = new AtomicReferenceArray<>(this.length);
        halfSize = this.length >> 1;
        maxSentinels = halfSize >> 1;
    }

    /** Rounds {@code requested} up to a power of two in [MIN_SIZE, MAX_SIZE]. */
    private static int tableSizeFor(final int requested) {
        if (requested <= MIN_SIZE) {
            return MIN_SIZE;
        }
        if (requested >= MAX_SIZE) {
            return MAX_SIZE;
        }
        return Integer.highestOneBit(requested - 1) << 1;
    }

    /** Scrambles a key's hashCode so that masking by the table length mixes all bits. */
    private static int hash(int h) {
        h += (h << 15) ^ 0xffffcd7d;
        h ^= (h >>> 10);
        h += (h << 3);
        h ^= (h >>> 6);
        h += (h << 2) + (h << 14);
        return h ^ (h >>> 16);
    }

    /** Maps a hash (or a probe position) to a slot; requires a power-of-two length. */
    private static int posForHash(final int hashCode,
        final AtomicReferenceArray<?> array)
    {
        return hashCode & (array.length() - 1);
    }

    /**
     * Adds {@code key} unless an equal key is already present.
     *
     * @return {@code true} if the key was added, {@code false} if it was present
     * @throws NullPointerException if {@code key} is null
     */
    public boolean put(final T key) {
        Objects.requireNonNull(key, "key");
        byte result;
        do {
            // Stay off the read lock while a rehash is waiting for the write lock.
            while (rehashing.get() == NEED_REHASH) {
                Thread.yield();
            }
            result = tryPut(key);
        } while (result == RETRY);
        return result == TRUE;
    }

    /** One insertion attempt; returns RETRY if a rehash started meanwhile. */
    private byte tryPut(final T key) {
        // Hash before locking: a throwing hashCode() must not leak the lock.
        final int hashCode = hash(key.hashCode());
        boolean grow = false;
        arrayLock.readLock().lock();
        try {
            int hashPos = posForHash(hashCode, array);
            while (true) {
                if (rehashing.get() != NO_REHASH) {
                    return RETRY;
                }
                if (array.compareAndSet(hashPos, null, key)) {
                    // halfSize is read under the lock, so it belongs to this array.
                    grow = arrayFill.incrementAndGet() == halfSize && !noRehash;
                    return TRUE;
                }
                // Slots cannot become null under the read lock, so o != null.
                // SENTINEL slots are skipped, never reused: between structural
                // changes a slot only moves null -> key -> SENTINEL, so threads
                // inserting equal keys always race for the same first null
                // slot and at most one CAS wins. Reusing a sentinel would let a
                // concurrent inserter place a duplicate further down the chain.
                final Object o = array.get(hashPos);
                if (o != SENTINEL && o.equals(key)) {
                    return FALSE;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
            if (grow) {
                rehash();
            }
        }
    }

    /** Inserts {@code key} into {@code target} without duplicate checks; write lock only. */
    private static void putNoEquals(final Object key,
        final AtomicReferenceArray<Object> target)
    {
        int hashPos = posForHash(hash(key.hashCode()), target);
        while (!target.compareAndSet(hashPos, null, key)) {
            hashPos = posForHash(hashPos + 1, target);
        }
    }

    /**
     * Removes the key equal to {@code key}, if present.
     *
     * @return {@code true} if a key was removed
     * @throws NullPointerException if {@code key} is null
     */
    public final boolean remove(final T key) {
        Objects.requireNonNull(key, "key");
        final int hashCode = hash(key.hashCode());
        boolean purge = false;
        arrayLock.readLock().lock();
        try {
            int hashPos = posForHash(hashCode, array);
            while (true) {
                final Object o = array.get(hashPos);
                if (o == null) {
                    return false;
                }
                if (o != SENTINEL && o.equals(key)) {
                    if (array.compareAndSet(hashPos, o, SENTINEL)) {
                        arrayFill.decrementAndGet();
                        // maxSentinels is read under the lock, matching this array.
                        purge = sentinels.incrementAndGet() == maxSentinels;
                        return true;
                    }
                    // Lost a race for this slot; re-examine it.
                    continue;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
            if (purge) {
                purgeSentinels();
            }
        }
    }

    /**
     * Grows the table (doubling until live keys fill less than half of it)
     * and re-inserts all live keys, dropping every sentinel.
     */
    private void rehash() {
        rehashing.set(NEED_REHASH);
        arrayLock.writeLock().lock();
        try {
            rehashing.set(REHASHING);
            final int live = arrayFill.get();
            if (live < halfSize) {
                // Another thread already grew the table, or removals brought
                // the fill back under the threshold.
                return;
            }
            int newLength = length;
            while (live >= newLength >> 1 && newLength < MAX_SIZE) {
                newLength <<= 1;
            }
            final AtomicReferenceArray<Object> newArray =
                new AtomicReferenceArray<>(newLength);
            int fill = 0;
            for (int i = 0; i < array.length(); i++) {
                final Object o = array.get(i);
                if (o != null && o != SENTINEL) {
                    putNoEquals(o, newArray);
                    fill++;
                }
            }
            array = newArray;
            length = newLength;
            halfSize = newLength >> 1;
            maxSentinels = halfSize >> 1;
            arrayFill.set(fill);
            sentinels.set(0);
        } finally {
            // Clear the flag before unlocking so readers queued on the lock
            // proceed instead of retrying.
            rehashing.set(NO_REHASH);
            arrayLock.writeLock().unlock();
        }
    }

    /**
     * Re-packs the probe run after the now-null slot {@code delPos}, so that
     * every key stays reachable from its home slot. Must hold the write lock.
     */
    private void closeDeletion(final int delPos) {
        int d = delPos;
        // Adapted from Knuth Section 6.4 Algorithm R

        // Look for items to swap into newly vacated slot
        // starting at index immediately following deletion,
        // and continuing until a null slot is seen, indicating
        // the end of a run of possibly-colliding keys.
        int i = posForHash(d + 1, array);
        Object key = array.get(i);
        while (key != null) {
            // The following test triggers if the item at slot i (which
            // hashes to be at slot r) should take the spot vacated by d.
            // If so, we swap it in, and then continue with d now at the
            // newly vacated i.  This process will terminate when we hit
            // the null slot at the end of this run.
            // The test is messy because we are using a circular table.
            // Sentinels are left in place; they count as occupied slots.
            if (key != SENTINEL) {
                final int r = posForHash(hash(key.hashCode()), array);
                if ((i < r && (r <= d || d <= i)) || (r <= d && d <= i)) {
                    array.set(d, key);
                    array.set(i, null);
                    d = i;
                }
            }
            i = posForHash(i + 1, array);
            key = array.get(i);
        }
    }

    /** Turns every SENTINEL slot back into null, re-packing probe runs. */
    private void purgeSentinels() {
        arrayLock.writeLock().lock();
        try {
            for (int i = 0; i < array.length(); i++) {
                if (array.get(i) == SENTINEL) {
                    array.set(i, null);
                    closeDeletion(i);
                }
            }
            sentinels.set(0);
        } finally {
            arrayLock.writeLock().unlock();
        }
    }

    /**
     * @return {@code true} if a key equal to {@code key} is present
     * @throws NullPointerException if {@code key} is null
     */
    public final boolean contains(final T key) {
        Objects.requireNonNull(key, "key");
        final int hashCode = hash(key.hashCode());
        arrayLock.readLock().lock();
        try {
            int hashPos = posForHash(hashCode, array);
            while (true) {
                final Object o = array.get(hashPos);
                if (o == null) {
                    return false;
                }
                if (o != SENTINEL && o.equals(key)) {
                    return true;
                }
                hashPos = posForHash(hashPos + 1, array);
            }
        } finally {
            arrayLock.readLock().unlock();
        }
    }

    /** Estimated footprint of the slot array, at 8 bytes per slot. */
    public int sizeInBytes() {
        arrayLock.readLock().lock();
        try {
            return length * 8;
        } finally {
            arrayLock.readLock().unlock();
        }
    }

    /**
     * Returns an iterator over the array current at call time. It takes no
     * locks while iterating, so it may or may not reflect concurrent changes,
     * and it is detached from the set after a rehash. {@code remove()} is not
     * supported.
     */
    @Override
    public Iterator<T> iterator() {
        final AtomicReferenceArray<Object> snapshot;
        arrayLock.readLock().lock();
        try {
            snapshot = array;
        } finally {
            arrayLock.readLock().unlock();
        }
        return new KeyIterator<T>(snapshot);
    }

    /** Walks the slots in index order, skipping empty and removed ones. */
    private static final class KeyIterator<E> implements Iterator<E> {
        private final AtomicReferenceArray<Object> slots;
        private int pos = 0;
        // Next key to return; null when not yet looked up.
        private E next;

        KeyIterator(final AtomicReferenceArray<Object> slots) {
            this.slots = slots;
        }

        @Override
        @SuppressWarnings("unchecked") // only keys and SENTINEL are ever stored
        public boolean hasNext() {
            while (next == null && pos < slots.length()) {
                final Object candidate = slots.get(pos++);
                // Never expose SENTINEL, even if it sits in the last slot.
                if (candidate != null && candidate != SENTINEL) {
                    next = (E) candidate;
                }
            }
            return next != null;
        }

        @Override
        public E next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            final E result = next;
            next = null;
            return result;
        }
    }

    /** Byte-slice key used by the self test in {@link #main}. */
    private static class BytesRef {
        private static final int HASH_PRIME = 31;
        private final byte[] data;
        private final int offset;
        private final int length;

        BytesRef(final byte[] data, final int offset, final int length) {
            this.data = data;
            this.offset = offset;
            this.length = length;
        }

        BytesRef(final byte[] data) {
            this(data, 0, data.length);
        }

        public byte[] bytes() {
            return data;
        }

        public int offset() {
            return offset;
        }

        public int length() {
            return length;
        }

        public byte byteAt(final int i) {
            return data[i + offset];
        }

        @Override
        public String toString() {
            return new String(data, offset, length, StandardCharsets.UTF_8);
        }

        @Override
        public int hashCode() {
            int result = 0;
            final int end = offset + length;
            for (int i = offset; i < end; i++) {
                result = HASH_PRIME * result + data[i];
            }
            return result;
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof BytesRef)) {
                return false;
            }
            final BytesRef other = (BytesRef) o;
            if (length != other.length) {
                return false;
            }
            for (int i = 0; i < length; i++) {
                if (data[offset + i] != other.data[other.offset + i]) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Ad-hoc self test: a single-threaded pass over {@link BytesRef} keys,
     * then a multi-threaded stress test over {@code Long} keys.
     *
     * @param args optional; {@code args[0]} is the total number of keys used
     *     by the stress test (default {@code 1024 * 120})
     * @throws Exception if any check fails
     */
    public static void main(String[] args) throws Exception {
        sequentialTest();
        concurrentTest(args);
    }

    /** Fills a fixed-size, non-growing set completely, then drains it. */
    private static void sequentialTest() {
        final int maxSize = 4096;
        final ConcurrentHashSet<BytesRef> set =
            new ConcurrentHashSet<>(maxSize, true);
        for (int i = 0; i < maxSize; i++) {
            final BytesRef key = testKey(i);
            if (!set.put(key)) {
                throw new RuntimeException("Can't add key: " + key);
            }
        }
        for (int i = 0; i < maxSize; i++) {
            final BytesRef key = testKey(i);
            if (!set.contains(key)) {
                throw new RuntimeException("Contains fail for key: " + key);
            }
        }
        for (int i = 0; i < maxSize; i++) {
            final BytesRef key = testKey(i);
            System.err.println("Removing: " + key);
            if (!set.remove(key)) {
                throw new RuntimeException("Remove failed for key: " + key);
            }
            // Keys not removed yet must stay reachable, also across purges.
            for (int j = i + 1; j < maxSize; j++) {
                final BytesRef key2 = testKey(j);
                if (!set.contains(key2)) {
                    throw new RuntimeException(
                        "Contains fail for key2: " + key2);
                }
            }
        }
    }

    /** Returns the UTF-8 bytes of {@code "KEY" + i} wrapped as a key. */
    private static BytesRef testKey(final int i) {
        return new BytesRef(("KEY" + i).getBytes(StandardCharsets.UTF_8));
    }

    /** Runs one {@link TestChunk} per CPU on disjoint key ranges of a shared set. */
    private static void concurrentTest(final String[] args)
        throws InterruptedException
    {
        final int testMax =
            args.length > 0 ? Integer.parseInt(args[0]) : 1024 * 120;
        final int cpuCount = Runtime.getRuntime().availableProcessors();
        final ThreadPoolExecutor executor = new ThreadPoolExecutor(
            cpuCount,
            cpuCount,
            1,
            TimeUnit.MINUTES,
            new ArrayBlockingQueue<Runnable>(cpuCount));
        final List<Future<Void>> futures = new ArrayList<>(cpuCount);

        final int perThread = testMax / cpuCount;
        System.err.println("PerThread: " + perThread);
        final ConcurrentHashSet<Long> testSet = new ConcurrentHashSet<>(32);
        final ConcurrentHashMap<Long, Long> stableSet =
            new ConcurrentHashMap<>();

        int failures = 0;
        try {
            for (int i = 0; i < cpuCount; i++) {
                futures.add(executor.submit(
                    new TestChunk(perThread * i, perThread, testSet, stableSet)));
            }
            for (Future<Void> f : futures) {
                try {
                    f.get();
                } catch (ExecutionException e) {
                    e.getCause().printStackTrace();
                    failures++;
                }
            }
        } finally {
            executor.shutdownNow();
        }
        if (failures > 0) {
            throw new IllegalStateException(
                failures + " of " + cpuCount + " test chunks failed");
        }
    }

    /** Exercises put/contains/remove on the key range [start, start + count). */
    private static class TestChunk implements Callable<Void> {
        final int start;
        final int count;
        final ConcurrentHashSet<Long> testSet;
        // Not consulted by the current checks.
        final ConcurrentHashMap<Long, Long> stableSet;

        public TestChunk(final int start, final int count,
            final ConcurrentHashSet<Long> testSet,
            final ConcurrentHashMap<Long, Long> stableSet) {
            this.start = start;
            this.count = count;
            this.testSet = testSet;
            this.stableSet = stableSet;
        }

        @Override
        public Void call() {
            final String id = Thread.currentThread().getId() + ": ";
            final long[] testData = new long[count];
            for (int i = 0; i < count; i++) {
                testData[i] = i + start;
            }
            // Fisher-Yates shuffle.
            for (int i = count - 1; i > 0; i--) {
                final int idx = (int) (Math.random() * (i + 1));
                final long tmp = testData[idx];
                testData[idx] = testData[i];
                testData[i] = tmp;
            }
            System.err.println(id + "test data generated");
            final int half = count >> 1;

            System.err.println(id + "Add all");
            for (int i = 0; i < count; i++) {
                if (!testSet.put(testData[i])) {
                    throw new RuntimeException(id + "Add fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(id + "test all");
            for (int i = 0; i < count; i++) {
                if (!testSet.contains(testData[i])) {
                    throw new RuntimeException(id + "Contains fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(id + "remove half");
            for (int i = 0; i < half; i++) {
                if (!testSet.remove(testData[i])) {
                    throw new RuntimeException(id + "Can't remove: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }

            System.err.println(id + "test half");
            for (int i = half; i < count; i++) {
                if (!testSet.contains(testData[i])) {
                    throw new RuntimeException(id + "Contains fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(id + "add half");
            for (int i = 0; i < half; i++) {
                if (!testSet.put(testData[i])) {
                    throw new RuntimeException(id + "Add failed: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }

            System.err.println(id + "add conflicted");
            for (int i = 0; i < half; i++) {
                if (testSet.put(testData[i])) {
                    throw new RuntimeException(id + "Duplicate add succeeded: "
                        + i + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }

            System.err.println(id + "remove half");
            for (int i = 0; i < half; i++) {
                if (!testSet.remove(testData[i])) {
                    throw new RuntimeException(id + "Couldn't remove: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }

            System.err.println(id + "remove half again");
            for (int i = 0; i < half; i++) {
                if (testSet.remove(testData[i])) {
                    throw new RuntimeException(id + "Shouldn't remove: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }

            System.err.println(id + "test half");
            for (int i = half; i < count; i++) {
                if (!testSet.contains(testData[i])) {
                    throw new RuntimeException(id + "Contains fail: " + i
                        + ", value=" + testData[i]);
                }
            }

            System.err.println(id + "test second half");
            for (int i = 0; i < half; i++) {
                if (testSet.contains(testData[i])) {
                    throw new RuntimeException(id + "Unexpectedly contains: "
                        + i + ", value=" + testData[i]);
                }
            }

            System.err.println(id + "add half");
            for (int i = 0; i < half; i++) {
                if (!testSet.put(testData[i])) {
                    throw new RuntimeException(id + "Add failed: " + i
                        + ", value=" + testData[i] + ", contains="
                        + testSet.contains(testData[i]));
                }
            }
            System.err.println(id + "finished");
            return null;
        }
    }
}
