package ru.yandex.msearch.util;

import java.nio.charset.StandardCharsets;

import ru.yandex.collection.ChunkedIntList;

import ru.yandex.util.unicode.ByteSequence;

public class ChunkedHashMap {
    // Compile-time switch for the System.err tracing used throughout this
    // class; swap with the commented-out line below to enable tracing.
    private static final boolean debug = false;
//    private static final boolean debug = true;
    // Multiplier of the polynomial hash computed by hashCode(ByteSequence).
    private static final int HASH_PRIME = 31;
    // Number of slots the pointer table starts with.
    private static final int INITIAL_SIZE = 16;
    // Width in bytes of every int stored in the data region.
    private static final int INT_SIZE = 4;
    // Layout of one record in the data region, relative to its start pointer:
    //   [HASH_OFFSET, +4)      key hash code
    //   [KEY_OFFSET, +4)       key length
    //   [KEY_DATA_OFFSET, ...) key bytes, then a 4-byte value length,
    //                          then the value bytes
    // remove() turns a record into a tombstone by storing -1 in the key
    // length field and the real key length in the hash field.
    private static final int HASH_OFFSET = 0;
    private static final int KEY_OFFSET = 4;
    private static final int KEY_DATA_OFFSET = 8;
    // Pointer-table value marking a slot that holds no record.
    private static final int EMPTY = -1;

    // Open-addressing slot table: each slot holds a record pointer or EMPTY.
    private ChunkedIntList ptrTable;
    // Append-only record storage; replaced wholesale on defragmentation.
    private DataMemory data;
    // Scratch view over `data` used by compareKey to avoid allocations.
    private ReusableSequence reusableKey;
    // Current number of slots in ptrTable.
    private int size;
    // Number of live entries (occupied slots).
    private int fill;
    // Bytes of `data` occupied by removed records since the last defragment.
    private int wastedMemory;
    // Sum of memoryUsage() over every record ever inserted; nothing in this
    // file decreases it on removal.
    private int usage = 0;

    public ChunkedHashMap() {
        size = INITIAL_SIZE;
        ptrTable = new ChunkedIntList(size);
        for (int i = 0; i < size; i++) {
            ptrTable.add(EMPTY);
        }
        data = new DataMemory();
        reusableKey = new ReusableSequence(data);
        fill = 0;
        wastedMemory = 0;
    }

    /**
     * Computes the polynomial hash (multiplier {@code HASH_PRIME}) over the
     * bytes of {@code seq}.
     *
     * <p>When the sequence exposes a backing array through {@code bytes()}
     * the array is read directly. Sequences that return {@code null} from
     * {@code bytes()} (as {@code ReusableSequence} does) are read through
     * {@code byteAt} instead, so they can be hashed without a
     * NullPointerException. Both paths visit the same bytes in the same
     * order and therefore yield the same hash.
     *
     * @param seq sequence to hash
     * @return the hash code of the sequence contents
     */
    private final int hashCode(ByteSequence seq) {
        final int length = seq.length();
        final byte[] bytes = seq.bytes();
        int result = 0;
        if (bytes != null) {
            // Fast path: walk the backing array directly.
            final int end = seq.offset() + length;
            for (int i = seq.offset(); i < end; i++) {
                result = HASH_PRIME * result + bytes[i];
            }
        } else {
            // No backing array: byteAt(i) is relative to the sequence start.
            for (int i = 0; i < length; i++) {
                result = HASH_PRIME * result + seq.byteAt(i);
            }
        }
        return result;
    }

    /**
     * Scrambles a raw hash code with a fixed sequence of shift, add and xor
     * steps before it is reduced to a slot index.
     *
     * @param h raw hash code
     * @return the mixed hash
     */
    private static final int hash(int h) {
        int mixed = h + ((h << 15) ^ 0xffffcd7d);
        mixed = mixed ^ (mixed >>> 10);
        mixed = mixed + (mixed << 3);
        mixed = mixed ^ (mixed >>> 6);
        mixed = mixed + (mixed << 2) + (mixed << 14);
        return mixed ^ (mixed >>> 16);
    }

    /**
     * Reduces a hash to a slot index in {@code [0, size)} by taking the
     * magnitude of the remainder.
     *
     * @param hashCode hash value (may be negative)
     * @param size number of slots in the table
     * @return slot index for the hash
     */
    private final int posForHash(final int hashCode, final int size) {
        final int remainder = hashCode % size;
        return remainder < 0 ? -remainder : remainder;
    }

    /**
     * Claims a slot for {@code key} in the given table using linear probing.
     *
     * @param hashCode raw hash code of the key
     * @param key key being inserted
     * @param newDataPtr record pointer to store in the claimed slot
     * @param ptrTable table to insert into
     * @return {@code true} if an empty slot was claimed, {@code false} if an
     *     equal key is already present (the table is then left untouched)
     */
    private boolean put(
        final int hashCode,
        final ByteSequence key,
        final int newDataPtr,
        final ChunkedIntList ptrTable)
    {
        final int capacity = ptrTable.size();
        final int mixed = hash(hashCode);
        if (debug) {
            System.err.println("put: hashCode: " + hashCode);
        }
        if (debug) {
            System.err.println("put: hash: " + mixed);
        }
        int slot = posForHash(mixed, capacity);
        if (debug) {
            System.err.println("put: posForHash: " + slot);
        }
        for (;;) {
            final int occupant = ptrTable.getInt(slot);
            if (occupant == EMPTY) {
                // Free slot: the key is new.
                ptrTable.setInt(slot, newDataPtr);
                return true;
            }
            if (compareKey(hashCode, key, occupant)) {
                // Same key already stored.
                return false;
            }
            // Collision with a different key: probe the next slot.
            slot = posForHash(slot + 1, capacity);
            if (debug) {
                System.err.println("put.retry: posForHash: " + slot);
            }
        }
    }

    /**
     * Stores a record pointer in the first empty slot of the probe sequence
     * for {@code hashCode}, without comparing keys. Used when the caller
     * already knows the key is not in the target table.
     *
     * @param hashCode raw hash code of the record's key
     * @param newDataPtr record pointer to store
     * @param ptrTable table to insert into
     * @return always {@code true}
     */
    private boolean putNoEquals(
        final int hashCode,
        final int newDataPtr,
        final ChunkedIntList ptrTable)
    {
        final int capacity = ptrTable.size();
        final int mixed = hash(hashCode);
        if (debug) {
            System.err.println("putNE: hashCode: " + hashCode);
        }
        if (debug) {
            System.err.println("putNE: hash: " + mixed);
        }
        int slot = posForHash(mixed, capacity);
        if (debug) {
            System.err.println("putNE: posForHash: " + slot);
        }
        // Skip occupied slots until a free one turns up.
        while (ptrTable.getInt(slot) != EMPTY) {
            slot = posForHash(slot + 1, capacity);
            if (debug) {
                System.err.println("putNE.retry: posForHash: " + slot);
            }
        }
        ptrTable.setInt(slot, newDataPtr);
        return true;
    }

    /**
     * Inserts a key/value pair unless the key is already present.
     *
     * <p>A new record is appended to the data region and the table is
     * doubled once more than half of its slots are occupied. An existing
     * key keeps its old value.
     *
     * @param key key bytes
     * @param value value bytes
     * @return {@code true} if the pair was inserted, {@code false} if the
     *     key already existed
     */
    public synchronized boolean put(
        final ByteSequence key,
        final ByteSequence value)
    {
        final int keyHash = hashCode(key);
        // The record, if written, goes at the current end of the data region.
        final int recordPtr = data.size();
        final boolean inserted = put(keyHash, key, recordPtr, ptrTable);
        if (inserted) {
            data.writeData(recordPtr, keyHash, key, value);
            usage += memoryUsage(recordPtr);
            fill++;
        }
        final int loadLimit = size >> 1;
        if (fill > loadLimit) {
            rehash(size << 1);
        }
        return inserted;
    }

    /**
     * Looks up {@code key}, allocating a fresh result holder.
     *
     * @param key key bytes
     * @return a view of the stored value, or {@code null} if the key is absent
     */
    public ByteSequence get(final ByteSequence key) {
        final ByteSequence holder = new ReusableSequence(data);
        return get(key, holder);
    }

    /**
     * Looks up {@code key} by linear probing from its home slot.
     *
     * @param key key bytes
     * @param reuse holder to point at the value when it is a
     *     {@code ReusableSequence} over the current data region; otherwise a
     *     new holder is allocated
     * @return a view of the stored value, or {@code null} if the key is absent
     */
    public synchronized ByteSequence get(
        final ByteSequence key,
        final ByteSequence reuse)
    {
        final int keyHash = hashCode(key);
        final int mixed = hash(keyHash);
        int slot = posForHash(mixed, size);
        if (debug) {
            System.err.println("get: hashCode: " + keyHash);
        }
        if (debug) {
            System.err.println("get: hash: " + mixed);
        }
        if (debug) {
            System.err.println("get: posForHash: " + slot);
        }
        for (;;) {
            final int recordPtr = ptrTable.getInt(slot);
            if (debug) {
                System.err.println("get: dataPtr: " + recordPtr);
            }
            if (recordPtr == EMPTY) {
                // An empty slot ends the probe run: the key is not stored.
                return null;
            }
            if (compareKey(keyHash, key, recordPtr)) {
                return setValueRef(reuse, recordPtr);
            }
            slot = posForHash(slot + 1, size);
            if (debug) {
                System.err.println("get.retry: posForHash: " + slot);
            }
        }
    }

    /**
     * Removes {@code key} from the map.
     *
     * <p>The record stays in the data region as a tombstone (key length
     * field set to -1, hash field overwritten with the key length) and its
     * bytes are counted as wasted. The probe run is then repaired, and the
     * data region is compacted once wasted bytes exceed half of it.
     *
     * @param key key bytes
     * @return {@code true} if the key was present and removed
     */
    public synchronized boolean remove(final ByteSequence key) {
        final int keyHash = hashCode(key);
        final int mixed = hash(keyHash);
        int slot = posForHash(mixed, size);
        if (debug) {
            System.err.println("remove: hashCode: " + keyHash);
        }
        if (debug) {
            System.err.println("remove: hash: " + mixed);
        }
        if (debug) {
            System.err.println("remove: posForHash: " + slot);
        }
        for (;;) {
            final int recordPtr = ptrTable.getInt(slot);
            if (debug) {
                System.err.println("remove: dataPtr: " + recordPtr);
            }
            if (recordPtr == EMPTY) {
                return false;
            }
            if (!compareKey(keyHash, key, recordPtr)) {
                slot = posForHash(slot + 1, size);
                if (debug) {
                    System.err.println("remove.retry: posForHash: " + slot);
                }
                continue;
            }

            ptrTable.setInt(slot, EMPTY);
            // Measure the record before its header is overwritten below.
            wastedMemory += memoryUsage(recordPtr);
            final int keyLength = data.readInt(recordPtr + KEY_OFFSET);
            // Tombstone: -1 in the key length field, key length in the hash
            // field so a linear scan can still step over the record.
            data.writeInt(recordPtr + KEY_OFFSET, -1);
            data.writeInt(recordPtr, keyLength);
            fill--;
            closeDeletion(slot);
            if (wastedMemory > data.size() >> 1) {
                defragmentData();
            }
            return true;
        }
    }

    /**
     * Repairs the probe run after the slot at {@code delPos} was emptied.
     * Adapted from Knuth Section 6.4 Algorithm R.
     *
     * <p>Scans the slots following the hole until an empty slot ends the
     * run. An entry whose home slot means it could no longer be reached
     * across the hole is moved into the hole, and the slot it left becomes
     * the new hole.
     *
     * @param delPos index of the slot that was just vacated
     */
    private void closeDeletion(final int delPos) {
        int hole = delPos;
        int probe = posForHash(hole + 1, size);
        for (int ptr = ptrTable.getInt(probe);
            ptr != EMPTY;
            ptr = ptrTable.getInt(probe))
        {
            // Home slot of the entry currently sitting at `probe`.
            final int storedHash = data.readInt(ptr + HASH_OFFSET);
            final int home = posForHash(hash(storedHash), size);
            // The test is messy because the table is circular.
            final boolean wrapsPastHole =
                probe < home && (home <= hole || hole <= probe);
            final boolean straddlesHole = home <= hole && hole <= probe;
            if (wrapsPastHole || straddlesHole) {
                ptrTable.setInt(hole, ptr);
                ptrTable.setInt(probe, EMPTY);
                hole = probe;
            }
            probe = posForHash(probe + 1, size);
        }
    }

    /**
     * Returns the number of bytes the record at {@code ptr} occupies in the
     * data region: hash, key length, key bytes, value length, value bytes.
     *
     * @param ptr start of a live record
     * @return record size in bytes
     */
    private int memoryUsage(final int ptr) {
        final int keyLength = data.readInt(ptr + KEY_OFFSET);
        final int valueLength =
            data.readInt(ptr + KEY_DATA_OFFSET + keyLength);
        // Three int headers (hash, key length, value length) plus payloads.
        return INT_SIZE + INT_SIZE + keyLength + INT_SIZE + valueLength;
    }



    /**
     * Points a sequence at the value bytes of the record at {@code dataPtr}.
     *
     * @param out holder to reuse; it is used only when it is a
     *     {@code ReusableSequence} bound to the current data region
     * @param dataPtr start of the record
     * @return {@code out} when reusable, otherwise a new holder
     */
    private ByteSequence setValueRef(
        final ByteSequence out,
        final int dataPtr)
    {
        final int keyLength = data.readInt(dataPtr + KEY_OFFSET);
        // The value length sits right after the key bytes.
        final int valueLengthPtr = dataPtr + KEY_DATA_OFFSET + keyLength;
        final int valueLength = data.readInt(valueLengthPtr);
        final int valuePtr = valueLengthPtr + INT_SIZE;

        final ReusableSequence target;
        if (out instanceof ReusableSequence
            && ((ReusableSequence) out).data == data)
        {
            target = (ReusableSequence) out;
        } else {
            target = new ReusableSequence(data);
        }
        target.set(valuePtr, valueLength);
        return target;
    }

    /**
     * Tests whether the record at {@code dataPtr} holds {@code key}.
     *
     * <p>Stored and supplied hash codes are compared first; bytes are only
     * compared when they match. Uses the shared {@code reusableKey} scratch
     * view.
     *
     * @param hashCode raw hash code of {@code key}
     * @param key key to compare against
     * @param dataPtr start of the stored record
     * @return {@code true} if the stored key equals {@code key}
     */
    private boolean compareKey(
        final int hashCode,
        final ByteSequence key,
        final int dataPtr)
    {
        final int storedHashCode = data.readInt(dataPtr + HASH_OFFSET);
        if (debug) {
            System.err.println("compreKey: storedHashCode: " + storedHashCode);
        }
        if (storedHashCode == hashCode) {
            final int keyLength = data.readInt(dataPtr + KEY_OFFSET);
            if (debug) {
                System.err.println("compreKey: keyLength: " + keyLength);
            }
            reusableKey.set(dataPtr + KEY_DATA_OFFSET, keyLength);
            return reusableKey.equals(key);
        }
        return false;
    }



    /**
     * Rebuilds the slot table with {@code newSize} slots.
     *
     * <p>Every live record pointer is reinserted using the hash code stored
     * in its record header, so keys are neither read nor compared. The new
     * table, its size and the recomputed entry count are published together
     * at the end, so the map's fields stay mutually consistent if the
     * rebuild throws part-way.
     *
     * @param newSize number of slots of the new table
     */
    private void rehash(final int newSize) {
        if (debug) {
            System.err.println("Rehash: " + newSize);
        }
        final ChunkedIntList newPtrTable = new ChunkedIntList(newSize);
        for (int i = 0; i < newSize; i++) {
            newPtrTable.add(EMPTY);
        }
        int live = 0;
        for (int i = 0; i < size; i++) {
            final int dataPtr = ptrTable.getInt(i);
            if (dataPtr != EMPTY) {
                final int hashCode = data.readInt(dataPtr + HASH_OFFSET);
                putNoEquals(hashCode, dataPtr, newPtrTable);
                live++;
            }
        }
        ptrTable = newPtrTable;
        size = newSize;
        fill = live;
    }

    /**
     * Compacts the data region by walking the slot table: every live record
     * is copied into a new region and its slot is repointed. Removed records
     * are not referenced by any slot and are therefore left behind.
     */
    private synchronized void defragmentDataSlow() {
        if (debug) {
            System.err.println("DEFRAGMENT: size=" + data.size()
                + ", wasted: " + wastedMemory
                + ", tableSize: " + size);
        }
        final DataMemory compacted = new DataMemory();
        int slot = 0;
        while (slot < size) {
            final int livePtr = ptrTable.getInt(slot);
            if (livePtr != EMPTY) {
                ptrTable.setInt(slot, compacted.copy(data, livePtr));
            }
            slot++;
        }
        data = compacted;
        // The scratch key view must follow the new region.
        reusableKey = new ReusableSequence(compacted);
        wastedMemory = 0;
        if (debug) {
            System.err.println("DEFRAGMENT.end: size=" + data.size()
                + ", wasted: " + wastedMemory);
        }
    }

    /**
     * Finds the slot that currently holds record pointer {@code ptr} by
     * probing from the home slot of {@code hashCode}.
     *
     * @param ptr record pointer to look for
     * @param hashCode raw hash code stored in that record
     * @return index of the slot containing {@code ptr}
     * @throws RuntimeException if an empty slot is reached first
     */
    private int posForPtr(final int ptr, final int hashCode) {
        for (int slot = posForHash(hash(hashCode), size);
            ;
            slot = posForHash(slot + 1, size))
        {
            final int found = ptrTable.getInt(slot);
            if (found == EMPTY) {
                throw new RuntimeException(
                    "Failed to find slot for data offset: " + ptr
                        + ", with hashcode: " + hashCode);
            }
            if (found == ptr) {
                return slot;
            }
        }
    }

    /**
     * Compacts the data region by scanning it linearly from offset 0.
     *
     * <p>Records are visited in storage order. A record whose key length
     * field is -1 is a removed entry and is skipped; any other record is
     * copied to a new region and the slot pointing at it is redirected to
     * the copy. The scan stops once {@code fill} live records have been
     * moved, so removed records after the last live one are never read.
     * Afterwards the new region replaces the old one and the wasted-byte
     * counter is reset.
     */
    private synchronized void defragmentData() {
        if (debug) {
            System.err.println("DEFRAGMENT: size=" + data.size()
                + ", wasted: " + wastedMemory
                + ", tableSize: " + size
                + ", fill: " + fill);
        }
        final DataMemory newData = new DataMemory();
        // Read cursor into the old data region.
        int ptr = 0;
        // NOTE(review): prevNewPtr is never read in this method.
        int prevNewPtr = 0;
        // i counts live records moved so far; it only advances in the
        // else-branch below.
        for (int i = 0; i < fill;) {
            int hashCode = data.readInt(ptr);
            int keyLen = data.readInt(ptr + KEY_OFFSET);
            if (keyLen == -1) {
                //removed entry
                // For a removed record the hash field holds the real key
                // length, which is what lets the scan step over it.
                keyLen = hashCode;
                ptr += KEY_DATA_OFFSET + keyLen;
                int dataSize = data.readInt(ptr);
                ptr += INT_SIZE + dataSize;
//                System.err.println("skipping: " + ptr);
            } else {
//                System.err.println("moving ptr: " + ptr);
                // Locate the slot before copying: it still holds the old ptr.
                int slot = posForPtr(ptr, hashCode);
                int newPtr = newData.copy(data, ptr);
                // copy() appends to newData, so the growth of newData is the
                // size of the record just moved.
                int copied = newData.size() - newPtr;
                ptrTable.setInt(slot, newPtr);
                ptr += copied;
                i++;
            }
        }
//        System.err.println("ptr: " + ptr);
        data = newData;
        // The scratch key view must be rebound to the new region.
        reusableKey = new ReusableSequence(data);
        wastedMemory = 0;
        if (debug) {
            System.err.println("DEFRAGMENT.end: size=" + data.size()
                + ", wasted: " + wastedMemory);
        }
    }

    /**
     * Repositionable window over a {@code ChunkedByteArray}. It has no
     * backing {@code byte[]}, so {@link #bytes()} returns {@code null} and
     * contents must be read through {@link #byteAt(int)}.
     */
    private static final class ReusableSequence implements ByteSequence {
        private final ChunkedByteArray data;
        private int offset;
        private int length;

        ReusableSequence(final ChunkedByteArray data) {
            this.data = data;
        }

        /** Moves the window to {@code length} bytes starting at {@code offset}. */
        public void set(final int offset, final int length) {
            this.length = length;
            this.offset = offset;
        }

        /** Always {@code null}: the window is not backed by a flat array. */
        @Override
        public byte[] bytes() {
            return null;
        }

        @Override
        public int offset() {
            return this.offset;
        }

        @Override
        public int length() {
            return this.length;
        }

        /** Returns the byte at index {@code i} relative to the window start. */
        @Override
        public byte byteAt(final int i) {
            return data.getByte(offset + i);
        }

        /** Decodes the window contents as UTF-8. */
        @Override
        public String toString() {
            final byte[] copy = new byte[length];
            if (debug) {
                System.err.println("toString: l: " + length + ", o: " + offset);
            }
            int src = offset;
            for (int dst = 0; dst < length; dst++, src++) {
                copy[dst] = data.getByte(src);
            }
            return new String(copy, 0, length, StandardCharsets.UTF_8);
        }

        /**
         * Byte-wise comparison with another sequence. This is an overload,
         * not an override of {@link Object#equals(Object)}.
         */
        public boolean equals(final ByteSequence other) {
            if (debug) {
                System.err.println("equals: l: " + length);
            }
            if (debug) {
                System.err.println("equals: o.l: " + other.length());
            }
            if (length != other.length()) {
                return false;
            }
            for (int i = 0; i < length; i++) {
                if (data.getByte(offset + i) != other.byteAt(i)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Immutable {@code ByteSequence} view over a slice of a {@code byte[]}.
     * The array is referenced, not copied.
     */
    private static class BytesRef implements ByteSequence {
        private final byte[] data;
        private final int offset;
        private final int length;

        /** Wraps the whole array. */
        BytesRef(final byte[] data) {
            this(data, 0, data.length);
        }

        /** Wraps {@code length} bytes of {@code data} starting at {@code offset}. */
        BytesRef(final byte[] data, final int offset, final int length) {
            this.data = data;
            this.offset = offset;
            this.length = length;
        }

        @Override
        public byte[] bytes() {
            return this.data;
        }

        @Override
        public int offset() {
            return this.offset;
        }

        @Override
        public int length() {
            return this.length;
        }

        /** Returns the byte at index {@code i} relative to the slice start. */
        @Override
        public byte byteAt(final int i) {
            return data[offset + i];
        }

        /** Decodes the slice as UTF-8. */
        @Override
        public String toString() {
            return new String(data, offset, length, StandardCharsets.UTF_8);
        }
    }

    /**
     * Record storage on top of {@code ChunkedByteArray}, adding big-endian
     * int access and whole-record write/copy helpers.
     *
     * <p>Record layout: 4-byte hash code, 4-byte key length, key bytes,
     * 4-byte value length, value bytes.
     */
    private static class DataMemory extends ChunkedByteArray {
        /** Reads a big-endian int starting at {@code ptr}. */
        public int readInt(final int ptr) {
            return ((getByte(ptr) & 0xFF) << 24)
                | ((getByte(ptr + 1) & 0xFF) << 16)
                | ((getByte(ptr + 2) & 0xFF) <<  8)
                | (getByte(ptr + 3) & 0xFF);
        }

        /** Writes {@code value} as a big-endian int starting at {@code ptr}. */
        public void writeInt(final int ptr, final int value) {
            writeByte(ptr, (byte) (value >> 24));
            writeByte(ptr + 1, (byte) (value >> 16));
            writeByte(ptr + 2, (byte) (value >> 8));
            writeByte(ptr + 3, (byte) value);
        }

        /**
         * Stores one byte at {@code ptr}. A write at the current end appends;
         * a write past the end first pads with zero bytes up to and
         * including {@code ptr}; a write inside the array overwrites.
         */
        public void writeByte(final int ptr, final byte b) {
            if (ptr == size()) {
                addByte(b);
            } else {
                int prealloc = ptr - size();
                // Runs only when ptr is beyond the end (prealloc >= 0).
                while (prealloc-- >= 0) {
                    addByte((byte) 0);
                }
                setByte(ptr, b);
            }
        }

        /**
         * Writes a complete record at {@code ptr}.
         *
         * <p>If copying the value fails, a short description of the value
         * is printed to {@code System.err} and the original failure is
         * rethrown. The description tolerates values without a backing
         * array, so it cannot replace the original failure with a
         * NullPointerException.
         *
         * @param ptr start offset of the record
         * @param hashCode raw hash code of the key
         * @param key key bytes, read through {@code byteAt}
         * @param value value bytes, read through {@code byteAt}
         */
        public void writeData(
            final int ptr,
            final int hashCode,
            final ByteSequence key,
            final ByteSequence value)
        {
            writeInt(ptr + HASH_OFFSET, hashCode);
            final int keyLength = key.length();
            writeInt(ptr + KEY_OFFSET, keyLength);
            int dataPtr = ptr + KEY_DATA_OFFSET;
            for (int i = 0; i < keyLength; i++) {
                writeByte(dataPtr++, key.byteAt(i));
            }
            final int valueLength = value.length();
            writeInt(dataPtr, valueLength);
            dataPtr += INT_SIZE;
            try {
                for (int i = 0; i < valueLength; i++) {
                    writeByte(dataPtr++, value.byteAt(i));
                }
            } catch (Throwable e) {
                // bytes() may legitimately be null (e.g. ReusableSequence).
                final byte[] backing = value.bytes();
                System.err.println("v.l: " + value.length()
                    + ", v.bytes().l: "
                    + (backing == null ? "n/a" : String.valueOf(backing.length))
                    + ", v.off: " + value.offset());
                throw e;
            }
        }

        /**
         * Appends a copy of the record at {@code ptr} in {@code from} to the
         * end of this storage.
         *
         * @param from storage to copy from
         * @param ptr start of the record in {@code from}
         * @return start offset of the copy in this storage
         */
        public int copy(final DataMemory from, int ptr) {
            int newPtr = size();
            int oldPtr = ptr;
            final int retval = newPtr;

            writeInt(newPtr, from.readInt(ptr)); //hashcode

            int keyLen = from.readInt(oldPtr + KEY_OFFSET);
            writeInt(newPtr + KEY_OFFSET, keyLen);

            newPtr += KEY_DATA_OFFSET;
            oldPtr += KEY_DATA_OFFSET;

            for (int i = 0; i < keyLen; i++) {
                writeByte(newPtr++, from.getByte(oldPtr++));
            }
            int dataLen = from.readInt(oldPtr);
            writeInt(newPtr, dataLen);
            oldPtr += INT_SIZE;
            newPtr += INT_SIZE;
            for (int i = 0; i < dataLen; i++) {
                writeByte(newPtr++, from.getByte(oldPtr++));
            }
            return retval;
        }
    }

    /**
     * Self-test and rough timing loop: 100 rounds of inserting 100000
     * entries (reading each one back immediately), removing them all, and
     * checking that every slot of the table is empty again.
     *
     * <p>Keys and values are encoded as UTF-8 explicitly, matching the
     * UTF-8 decoding done by the sequences' {@code toString()}, so the
     * round-trip check does not depend on the platform default charset.
     *
     * @param args ignored
     */
    public static void main(final String[] args) {
        ChunkedHashMap map = new ChunkedHashMap();
        for (int t = 0; t < 100; t++) {
            long start = System.currentTimeMillis();
            int maxSize = 100000;
            for (int i = 0; i < maxSize; i++) {
                String sKey = "KEY" + i;
                String sValue = "VALUE" + i;
                BytesRef key =
                    new BytesRef(sKey.getBytes(StandardCharsets.UTF_8));
                BytesRef value =
                    new BytesRef(sValue.getBytes(StandardCharsets.UTF_8));
                if (debug) {
                    System.err.println("put: " + sKey);
                }
                if (!map.put(key, value)) {
                    throw new RuntimeException(
                        "Can't put: " + sKey + " = " + sValue);
                }
                ByteSequence rValue = map.get(key);
                if (rValue == null || !rValue.toString().equals(sValue)) {
                    throw new RuntimeException(
                        "Wrong value at: " + sKey + " <> "
                            + String.valueOf(rValue));
                }
            }
            for (int i = 0; i < maxSize; i++) {
                String sKey = "KEY" + i;
                BytesRef key =
                    new BytesRef(sKey.getBytes(StandardCharsets.UTF_8));
                if (debug) {
                    System.err.println("remove: " + sKey);
                }
                if (!map.remove(key)) {
                    throw new RuntimeException("Remove failed for: " + sKey);
                }
            }
            System.err.println("time: " + (System.currentTimeMillis() - start)
                + ", dataSize: " + map.data.size() + ", usage: " + map.usage
                + ", fill: " + map.fill);
            //check map is empty
            for (int i = 0; i < map.ptrTable.size(); i++) {
                if (map.ptrTable.get(i) != EMPTY) {
                    throw new RuntimeException("Non empty ptr for empty map: "
                        + map.ptrTable.get(i));
                }
            }
        }
    }
}
