package ru.yandex.solomon.codec.bits;

import java.util.Arrays;

import javax.annotation.Nonnull;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.solomon.memory.layout.MemoryCounter;

import static ru.yandex.solomon.codec.bits.BitArray.div8;
import static ru.yandex.solomon.codec.bits.BitArray.mod8;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Growable {@link BitBuf} backed by a heap {@code byte[]}.
 *
 * <p>Bit {@code i} is stored in {@code array[div8(i)]} at bit position {@code mod8(i)}, i.e. the
 * least significant bit of a byte is filled first. Readable bits are {@code [readIndex, writeIndex)}.
 * Bits of the backing array at or beyond {@code writeIndex} may hold stale data (for example after
 * {@link #resetWriterIndex()}), so writers must not rely on them being zero.
 *
 * <p>Reference counting is a no-op: {@link #refCnt()} is always 1 and {@link #release()} never frees
 * anything. Instances are not synchronized.
 *
 * @author Vladimir Gordiychuk
 */
public class HeapBitBuf extends BitBuf {
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(HeapBitBuf.class);

    /** Backing storage; replaced by a larger copy when the buffer grows. */
    @Nonnull
    private byte[] array;
    /** Position, in bits, where the next write goes. */
    private long writeIndex;
    /** Position, in bits, where the next read comes from. */
    private long readIndex;

    /**
     * Creates an empty buffer with a 16-byte initial capacity.
     */
    public HeapBitBuf() {
        this.array = new byte[16];
        this.writeIndex = 0;
    }

    /**
     * Creates an independent copy of {@code copy}: the backing array is cloned and the reader and
     * writer indexes are carried over.
     *
     * @throws UnsupportedOperationException if {@code copy} is not a {@code HeapBitBuf}
     */
    public HeapBitBuf(BitBuf copy) {
        if (copy instanceof HeapBitBuf) {
            HeapBitBuf heap = (HeapBitBuf) copy;
            this.array = heap.array.clone();
            this.writeIndex = heap.writeIndex;
            this.readIndex = heap.readIndex;
        } else {
            throw new UnsupportedOperationException("Unsupported yet");
        }
    }

    /**
     * Wraps {@code array} without copying it; bits {@code [0, lengthBits)} become readable.
     * The size is validated by {@code BitArray.checkSize}.
     */
    public HeapBitBuf(@Nonnull byte[] array, long lengthBits) {
        BitArray.checkSize(array, lengthBits);
        this.array = array;
        this.writeIndex = lengthBits;
        this.readIndex = 0;
    }

    /**
     * Wraps {@code array} without copying it, with explicit reader and writer indexes (in bits).
     *
     * @throws IllegalArgumentException if {@code readerIndex > writerIndex}
     */
    public HeapBitBuf(@Nonnull byte[] array, long readerIndex, long writerIndex) {
        BitArray.checkSize(array, writerIndex);
        if (readerIndex > writerIndex) {
            throw new IllegalArgumentException("readerIdx " + readerIndex + " > writerIdx " + writerIndex);
        }
        this.array = array;
        this.writeIndex = writerIndex;
        this.readIndex = readerIndex;
    }

    /**
     * Ensures at least {@code minBytesCapacity} bytes are available starting from the byte that
     * contains {@code writeIndex}.
     *
     * @throws IllegalArgumentException if {@code minBytesCapacity} is negative
     */
    @Override
    public void ensureBytesCapacity(int minBytesCapacity) {
        if (minBytesCapacity < 0) {
            throw new IllegalArgumentException("negative minBytesCapacity: " + minBytesCapacity);
        }

        ensureBytesCapacity(div8(writeIndex), minBytesCapacity);
    }

    /**
     * Ensures indices {@code [pos, pos + minBytesCapacity)} of the backing array are valid.
     * When growing, the new length is the largest of 1.5x the current length, the required
     * length and 16.
     */
    public void ensureBytesCapacity(int pos, int minBytesCapacity) {
        int capacity = array.length - pos;
        if (capacity <= minBytesCapacity) {
            int minRequiredCapacity = array.length + (minBytesCapacity - capacity);
            int nextExpCapacity = Math.addExact(array.length, (array.length >> 1));
            int newCapacity = Math.max(Math.max(nextExpCapacity, minRequiredCapacity), 16);
            array = Arrays.copyOf(array, newCapacity);
        }
    }

    /** Writes a single bit, explicitly setting or clearing it so stale data is overwritten. */
    @Override
    public void writeBit(boolean bit) {
        int pos = div8(writeIndex);
        int used = mod8(writeIndex);
        ensureBytesCapacity(pos, 1);
        if (bit) {
            array[pos] |= 1 << used;
        } else {
            array[pos] &= ~(1 << used);
        }
        ++writeIndex;
    }

    /** Writes all 8 bits of {@code bits}, least significant bit first. */
    @Override
    public void write8Bits(byte bits) {
        int pos = div8(writeIndex);
        // extra byte is not to worry about overflow while writing last byte
        ensureBytesCapacity(pos, Byte.BYTES + 1);
        int used = mod8(writeIndex);
        if (used == 0) {
            array[pos] = bits;
            writeIndex += 8;
            return;
        }

        int freeBit = 8 - used;
        // keep the `used` bits already written in this byte, put the low `freeBit` bits above them
        byte prev = (byte) ((0xff >>> freeBit) & array[pos]);
        array[pos] = (byte) (prev | (bits & (0xff >>> used)) << used);
        array[pos + 1] = (byte) (Byte.toUnsignedInt(bits) >>> freeBit);
        writeIndex += 8;
    }

    /** Writes all 32 bits of {@code bits}, least significant bit first. */
    @Override
    public void write32Bits(int bits) {
        int pos = div8(writeIndex);
        // extra bytes are not to worry about overflow while writing last byte
        ensureBytesCapacity(pos, Integer.BYTES + 2);

        int used = mod8(writeIndex);
        if (used == 0) {
            array[pos] = (byte) (bits);
            array[pos + 1] = (byte) (bits >>> 8);
            array[pos + 2] = (byte) (bits >>> 16);
            array[pos + 3] = (byte) (bits >>> 24);
            writeIndex += Integer.SIZE;
            return;
        }

        int freeBit = 8 - used;

        byte prev = (byte) ((0xff >>> freeBit) & array[pos]);
        array[pos] = (byte) (prev | (bits & (0xff >>> used)) << used);
        int value = bits >>> freeBit;
        array[pos + 1] = (byte) (value);
        array[pos + 2] = (byte) (value >>> 8);
        array[pos + 3] = (byte) (value >>> 16);
        array[pos + 4] = (byte) (value >>> 24);
        writeIndex += Integer.SIZE;
    }

    /** Writes all 64 bits of {@code bits}, least significant bit first. */
    @Override
    public void write64Bits(long bits) {
        int pos = div8(writeIndex);
        // extra bytes are not to worry about overflow while writing last byte
        ensureBytesCapacity(pos, Long.BYTES + 2);
        int used = mod8(writeIndex);
        if (used == 0) {
            array[pos] = (byte) (bits);
            array[pos + 1] = (byte) (bits >>> 8);
            array[pos + 2] = (byte) (bits >>> 16);
            array[pos + 3] = (byte) (bits >>> 24);
            array[pos + 4] = (byte) (bits >>> 32);
            array[pos + 5] = (byte) (bits >>> 40);
            array[pos + 6] = (byte) (bits >>> 48);
            array[pos + 7] = (byte) (bits >>> 56);
            writeIndex += Long.SIZE;
            return;
        }

        int freeBit = 8 - used;
        byte prev = (byte) ((0xff >>> freeBit) & array[pos]);
        array[pos] = (byte) (prev | (bits & (0xff >>> used)) << used);
        long value = bits >>> freeBit;
        array[pos + 1] = (byte) (value);
        array[pos + 2] = (byte) (value >>> 8);
        array[pos + 3] = (byte) (value >>> 16);
        array[pos + 4] = (byte) (value >>> 24);
        array[pos + 5] = (byte) (value >>> 32);
        array[pos + 6] = (byte) (value >>> 40);
        array[pos + 7] = (byte) (value >>> 48);
        array[pos + 8] = (byte) (value >>> 56);
        writeIndex += Long.SIZE;
    }

    /**
     * Writes the low {@code count} bits of {@code bits}.
     *
     * @throws IllegalArgumentException if {@code count > 32}, or if {@code bits} has a set bit at or
     *     above position {@code count} (this also rejects negative {@code count})
     */
    @Override
    public void writeBits(int bits, int count) {
        if (count > Integer.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Integer.SIZE + " bits: " + count);
        }

        int used = Integer.SIZE - Integer.numberOfLeadingZeros(bits);
        if (used > count) {
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        // extra bytes are not to worry about overflow while writing last byte
        writeLowBits(Integer.toUnsignedLong(bits), count, Integer.BYTES + 2);
    }

    /**
     * Writes the low {@code count} bits of {@code bits}.
     *
     * @throws IllegalArgumentException if {@code count > 64}, or if {@code bits} has a set bit at or
     *     above position {@code count} (this also rejects negative {@code count})
     */
    @Override
    public void writeBits(long bits, int count) {
        if (count > Long.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Long.SIZE + " bits: " + count);
        }

        int used = Long.SIZE - Long.numberOfLeadingZeros(bits);
        if (used > count) {
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        // extra bytes are not to worry about overflow while writing last byte
        writeLowBits(bits, count, Long.BYTES + 2);
    }

    /**
     * Writes the low {@code count} bits of {@code bits} at {@code writeIndex} and advances it.
     * Callers must already have verified that no bit at or above {@code count} is set.
     *
     * @param reserveBytes bytes to reserve from the current write byte before writing
     */
    private void writeLowBits(long bits, int count, int reserveBytes) {
        int pos = div8(writeIndex);
        int used = mod8(writeIndex);
        int freeBit = 8 - used;

        ensureBytesCapacity(pos, reserveBytes);
        // keep the `used` bits already written in this byte, put the low `freeBit` bits above them
        byte prev = (byte) ((0xff >>> freeBit) & array[pos]);
        array[pos] = (byte) (prev | (bits & (0xff >>> used)) << used);

        // Store every following byte the `count` bits span, even when it is zero: the array may
        // hold stale data past writeIndex, and skipping zero bytes would leave it in the range.
        long value = bits >>> freeBit;
        for (int remaining = count - freeBit; remaining > 0; remaining -= Byte.SIZE) {
            array[++pos] = (byte) value;
            value >>>= Byte.SIZE;
        }

        writeIndex += count;
    }

    /**
     * Writes {@code n} one-bits followed by a terminating zero bit; the zero is omitted when
     * {@code n == max}.
     *
     * @throws IllegalArgumentException unless {@code 0 <= n <= max <= 8}
     */
    @Override
    public void writeIntVarint1N(int n, int max) {
        if (n < 0 || n > max || max > 8) {
            throw new IllegalArgumentException("incorrect n=" + n + ", max=" + max);
        }

        int pos = div8(writeIndex);
        ensureBytesCapacity(pos, 2);
        if (n == max) {
            writeMask(pos, n, (byte) (0xff >>> (8 - n)));
            return;
        }

        // the mask has only n low ones, so bit n written here is the terminating zero
        writeMask(pos, n + 1, (byte) (0xff >>> (8 - n)));
    }

    /**
     * Copies {@code length} bits of {@code src}, starting at {@code srcIndex}, to the end of this
     * buffer. Both {@code srcIndex} and this buffer's {@code writeIndex} must be byte-aligned.
     *
     * @throws UnsupportedOperationException if either position is not byte-aligned
     */
    @Override
    public void writeBits(BitBuf src, long srcIndex, long length) {
        if (mod8(srcIndex) != 0 ) {
            throw new UnsupportedOperationException("Unsupported writeBits not alighted to byte pos: " + srcIndex);
        }

        if (mod8(writeIndex) != 0) {
            throw new UnsupportedOperationException("Unsupported writeBits to not alighted pos: " + writeIndex);
        }

        ensureBytesCapacity(div8(length) + 1);
        // src.array() may be shared with an offset, so positions are relative to arrayOffset()
        System.arraycopy(src.array(), src.arrayOffset() + div8(srcIndex), array, div8(writeIndex),
                BitArray.arrayLengthForBits(length));
        writeIndex += length;
    }

    /** Rounds {@code writeIndex} up to a multiple of 8; the skipped bits are not written. */
    @Override
    public void alignToByte() {
        writeIndex = (writeIndex + 7) / 8 * 8;
    }

    /** Returns the byte length for the readable bits, as computed by {@code BitArray.arrayLengthForBits}. */
    @Override
    public int bytesSize() {
        return BitArray.arrayLengthForBits(readableBits());
    }

    /** Returns {@code writeIndex - readIndex}. */
    @Override
    public long readableBits() {
        return writeIndex - readIndex;
    }

    /**
     * Reads one bit.
     *
     * @throws IllegalArgumentException if no bit is readable
     */
    @Override
    public boolean readBit() {
        checkReadable(1);
        return BitArray.isBitSet(array, readIndex++);
    }

    /**
     * Reads {@code bitCount} bits (at most 32) into the low bits of an {@code int}.
     *
     * @throws IllegalArgumentException if {@code bitCount} is out of range or not enough bits are readable
     */
    @Override
    public int readBitsToInt(int bitCount) {
        if (bitCount > 32) {
            throw new IllegalArgumentException("bit count cannot be greater than 32: " + bitCount);
        }
        return (int) readBitsToLong(bitCount);
    }

    /**
     * Reads {@code bitCount} bits (0..64) into the low bits of a {@code long}; higher bits are zero.
     *
     * @throws IllegalArgumentException if {@code bitCount} is negative, greater than 64, or more
     *     bits than are readable
     */
    @Override
    public long readBitsToLong(int bitCount) {
        if (bitCount == 0) {
            return 0;
        } else if (bitCount < 0) {
            throw new IllegalArgumentException("bit count cannot be negative: " + bitCount);
        } else if (bitCount > 64) {
            throw new IllegalArgumentException("bit count cannot be greater than 64: " + bitCount);
        }
        checkReadable(bitCount);

        long r = 0;
        int pos = div8(readIndex);
        int used = mod8(readIndex);
        int currentBit = 0;
        readIndex += bitCount;

        // align to byte
        {
            int freeBit = 8 - used;
            int bitsFromFirstByte = Math.min(bitCount, freeBit);
            r |= ((array[pos] & 0xffL) >>> used) & ((1L << bitsFromFirstByte) - 1);
            if (bitCount <= freeBit) {
                return r;
            }
            currentBit += bitsFromFirstByte;
        }

        // read whole bytes
        while (bitCount - currentBit >= 8) {
            r |= (array[++pos] & 0xffL) << currentBit;
            currentBit += 8;
        }

        // read remaining bits
        {
            int bitsFromLastByte = bitCount - currentBit;

            // `if` is a protection against buffer overrun
            if (bitsFromLastByte != 0) {
                r |= (array[++pos] & ((1L << bitsFromLastByte) - 1)) << currentBit;
            }
        }

        return r;
    }

    /**
     * Reads 8 bits.
     *
     * @throws IllegalArgumentException if fewer than 8 bits are readable
     */
    @Override
    public byte read8Bits() {
        checkReadable(Byte.SIZE);
        byte r;
        if (mod8(readIndex) == 0) {
            r = array[div8(readIndex)];
        } else {
            // the value straddles two bytes: join them, then shift the wanted 8 bits down
            int x = Byte.toUnsignedInt(array[div8(readIndex)]) | (Byte.toUnsignedInt(array[div8(readIndex) + 1]) << 8);
            r = (byte) (x >>> (mod8(readIndex)));
        }
        readIndex += Byte.SIZE;
        return r;
    }

    /**
     * Reads 32 bits, least significant bit first.
     *
     * @throws IllegalArgumentException if fewer than 32 bits are readable
     */
    @Override
    public int read32Bits() {
        checkReadable(Integer.SIZE);
        int r;
        if (mod8(readIndex) == 0) {
            r = 0;
            for (int i = 0; i < Integer.BYTES; ++i) {
                r |= Byte.toUnsignedInt(array[div8(readIndex) + i]) << (i * 8);
            }
        } else {
            // unaligned: the 32 bits span 5 bytes
            r = Byte.toUnsignedInt(array[div8(readIndex)]) >>> mod8(readIndex);
            for (int i = 1; i < Integer.BYTES + 1; ++i) {
                r |= Byte.toUnsignedInt(array[div8(readIndex) + i]) << (i * 8 - mod8(readIndex));
            }
        }
        readIndex += Integer.SIZE;
        return r;
    }

    /**
     * Reads 64 bits, least significant bit first.
     *
     * @throws IllegalArgumentException if fewer than 64 bits are readable
     */
    @Override
    public long read64Bits() {
        checkReadable(Long.SIZE);
        long r;
        if (mod8(readIndex) == 0) {
            r = 0;
            for (int i = 0; i < Long.BYTES; ++i) {
                r |= Byte.toUnsignedLong(array[div8(readIndex) + i]) << (i * 8);
            }
        } else {
            // unaligned: the 64 bits span 9 bytes
            r = Byte.toUnsignedLong(array[div8(readIndex)]) >>> mod8(readIndex);
            for (int i = 1; i < Long.BYTES + 1; ++i) {
                r |= Byte.toUnsignedLong(array[div8(readIndex) + i]) << (i * 8 - mod8(readIndex));
            }
        }
        readIndex += Long.SIZE;
        return r;
    }

    /**
     * Counts consecutive one-bits, stopping at {@code max}; the terminating zero bit is consumed
     * unless {@code max} was reached. Counterpart of {@link #writeIntVarint1N(int, int)}.
     *
     * @throws IllegalArgumentException if {@code max > 8}
     */
    @Override
    public int readIntVarint1N(int max) {
        if (max > Byte.SIZE) {
            throw new IllegalArgumentException("max " + max + " > 8");
        }

        int r = 0;

        // TODO: can do faster
        while (r != max && readBit()) {
            ++r;
        }
        return r;
    }

    /**
     * Reads a base-128 varint: 7 payload bits per byte, low groups first, with the high bit set on
     * every byte but the last. Payload bits above 32 are dropped by the narrowing {@code |=}.
     *
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public int readIntVarint8() {
        int result = 0;
        for (int shift = 0; shift < 64; shift += 7) {
            final byte b = read8Bits();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw new IllegalStateException("malformed");
    }

    /**
     * Reads a base-128 varint: 7 payload bits per byte, low groups first, with the high bit set on
     * every byte but the last.
     *
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public long readLongVarint8() {
        long result = 0;
        for (int shift = 0; shift < 64; shift += 7) {
            final byte b = read8Bits();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw new IllegalStateException("malformed");
    }

    /** Sets {@code writeIndex} to 0; the array contents are left untouched. */
    @Override
    public void resetWriterIndex() {
        writeIndex = 0;
    }

    /** Sets {@code readIndex} to 0. */
    @Override
    public void resetReadIndex() {
        readIndex = 0;
    }

    @Override
    public long writerIndex() {
        return writeIndex;
    }

    /**
     * Moves the writer index.
     *
     * @throws IndexOutOfBoundsException if {@code writerIndex} is negative or beyond the capacity in bits
     */
    @Override
    public void writerIndex(long writerIndex) {
        // compare in bits: a whole-byte comparison would accept an index that lies inside the
        // (non-existent) byte right after the end of the array
        if (writerIndex < 0 || writerIndex > (long) array.length * Byte.SIZE) {
            throw new IndexOutOfBoundsException("capacity: " + array.length + " bytes, writerIndex: " + writerIndex + " bits");
        }
        this.writeIndex = writerIndex;
    }

    @Override
    public long readerIndex() {
        return readIndex;
    }

    /**
     * Moves the reader index.
     *
     * @throws IndexOutOfBoundsException if {@code readerIndex} is negative or greater than {@code writeIndex}
     */
    @Override
    public void readerIndex(long readerIndex) {
        if (readerIndex < 0) {
            throw new IndexOutOfBoundsException("negative ridx(" + readerIndex + ")");
        }
        if (readerIndex > writeIndex) {
            throw new IndexOutOfBoundsException("ridx(" + readerIndex + ") >= widx(" + writeIndex + ")");
        }
        this.readIndex = readerIndex;
    }

    /**
     * Advances the reader index by {@code length} bits. A negative {@code length} moves it back.
     *
     * @throws IndexOutOfBoundsException if the new index is negative or past {@code writeIndex}
     */
    @Override
    public void skipBits(long length) {
        long next = readIndex + length;
        if (next < 0) {
            throw new IndexOutOfBoundsException("negative ridx(" + next + ")");
        }
        if (next > writeIndex) {
            throw new IndexOutOfBoundsException("ridx(" + next + ") >= widx(" + writeIndex + ")");
        }
        this.readIndex = next;
    }

    /** Returns a read-only view sharing this buffer's array and indexes. */
    @Override
    public BitBuf asReadOnly() {
        return new ReadOnlyHeapBitBuf(array, readIndex, writeIndex);
    }

    /**
     * Returns a view sharing this buffer's array whose readable bits are
     * {@code [index, index + length)}.
     *
     * @throws UnsupportedOperationException if {@code index} is not byte-aligned
     * @throws IndexOutOfBoundsException if the range ends past {@code writeIndex}
     */
    @Override
    public BitBuf slice(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        if (index + length > writeIndex) {
            throw new IndexOutOfBoundsException((index + length) + " >= widx(" + writeIndex + ")");
        }

        HeapBitBuf result = new HeapBitBuf(array, length);
        result.writeIndex = index + length;
        result.readIndex = index;
        return result;
    }

    /** Returns a view sharing this buffer's array and indexes; read-only buffers stay read-only. */
    @Override
    public BitBuf duplicate() {
        if (this instanceof ReadOnlyHeapBitBuf) {
            return new ReadOnlyHeapBitBuf(array, readIndex, writeIndex);
        }

        return new HeapBitBuf(array, readIndex, writeIndex);
    }

    /** Copies the readable bits into a new buffer with its own array. */
    @Override
    public BitBuf copy() {
        return copy(readIndex, readableBits());
    }

    /**
     * Copies {@code length} bits starting at byte-aligned {@code index} into a new buffer with its
     * own array.
     *
     * @throws UnsupportedOperationException if {@code index} is not byte-aligned
     */
    @Override
    public BitBuf copy(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        int size = BitArray.arrayLengthForBits(length);
        int from = div8(index);
        byte[] bytes = Arrays.copyOfRange(array, from, from + size);
        return new HeapBitBuf(bytes, length);
    }

    /** Returns the backing array itself (not a copy). */
    @Override
    public byte[] array() {
        return array;
    }

    @Override
    public int arrayOffset() {
        return 0;
    }

    @Override
    public int refCnt() {
        return 1;
    }

    @Override
    public HeapBitBuf retain() {
        return this;
    }

    @Override
    public HeapBitBuf retain(int increment) {
        return this;
    }

    @Override
    public HeapBitBuf touch() {
        return this;
    }

    @Override
    public HeapBitBuf touch(Object hint) {
        return this;
    }

    /** Creates a new empty {@code HeapBitBuf} with {@code byteCapacity} bytes of storage. */
    @Override
    public BitBuf allocate(int byteCapacity) {
        return new HeapBitBuf(new byte[byteCapacity], 0);
    }

    @Override
    public boolean isDirect() {
        return false;
    }

    @Override
    public boolean release() {
        return false;
    }

    @Override
    public boolean release(int decrement) {
        return false;
    }

    /** Throws if fewer than {@code expect} bits are readable. */
    private void checkReadable(long expect) {
        if (expect > readableBits()) {
            throw new IllegalArgumentException("expected at least: " + expect + ", remaining: " + readableBits());
        }
    }

    /**
     * Writes the low {@code size} bits of {@code mask} (at most 8) at {@code writeIndex}, starting
     * in byte {@code pos}, and advances {@code writeIndex} by {@code size}.
     */
    private void writeMask(int pos, int size, byte mask) {
        int used = mod8(writeIndex);
        int freeBit = 8 - used;
        // TODO: uncomment it as only fixed dirty write
        //array[pos] |= mask << used;
        // keep the `used` bits already written, overwrite the rest of the byte with the mask
        array[pos] = (byte) ((0xff >>> freeBit) & array[pos] | (mask << used));
        if (freeBit < size) {
            array[pos + 1] = (byte) (mask >>> freeBit);
        }
        writeIndex += size;
    }

    @Override
    public long memorySizeIncludingSelf() {
        return SELF_SIZE + MemoryCounter.arrayObjectSize(array);
    }

    /**
     * Two {@code HeapBitBuf}s are equal when they have the same {@code writeIndex} and the same bits
     * in {@code [0, writeIndex)}; {@code readIndex} and bits past {@code writeIndex} are ignored.
     * Other {@code BitBuf} implementations are compared by readable bit count and array bytes.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof BitBuf)) return false;

        if (!(o instanceof HeapBitBuf)) {
            BitBuf that = (BitBuf) o;
            if (readableBits() != that.readableBits()) {
                return false;
            }

            return Arrays.equals(array, 0, bytesSize(), that.array(), that.arrayOffset(), that.arrayOffset() + that.bytesSize());
        }

        HeapBitBuf that = (HeapBitBuf) o;

        if (that.writeIndex != writeIndex) {
            return false;
        }

        int fullBytes = div8(writeIndex);
        if (!Cf.ByteArray.equalPrefixes(this.array, that.array, fullBytes)) {
            return false;
        }

        if (mod8(writeIndex) != 0) {
            // compare only the written low bits of the last, partially filled byte
            int a = this.array[fullBytes] & ((1 << mod8(writeIndex)) - 1);
            int b = that.array[fullBytes] & ((1 << mod8(writeIndex)) - 1);
            return a == b;
        }
        return true;
    }

    /**
     * Hashes exactly what {@link #equals(Object)} compares between {@code HeapBitBuf}s: the
     * {@code writeIndex} and the bits in {@code [0, writeIndex)}. Neither {@code readIndex} nor
     * stale bits past {@code writeIndex} are included.
     */
    @Override
    public int hashCode() {
        int fullBytes = div8(writeIndex);
        int hash = 31 * Long.hashCode(writeIndex) + Cf.ByteArray.hashCodeOfRange(array, 0, fullBytes);
        int tailBits = mod8(writeIndex);
        if (tailBits != 0) {
            hash = 31 * hash + (array[fullBytes] & ((1 << tailBits) - 1));
        }
        return hash;
    }

    /**
     * Renders bits {@code [0, writeIndex)} (at most the first 800) as 0/1 characters, grouped by
     * byte, with {@code "..."} appended when the output was truncated.
     */
    @Override
    public String toString() {
        if (writeIndex == 0) {
            return "[]";
        }

        final int maxBits = 100 * 8;
        StringBuilder sb = new StringBuilder();
        sb.append("[bits: ").append(writeIndex).append(": ");

        for (int index = 0; index < Math.min(writeIndex, maxBits); index++) {
            if (BitArray.isBitSet(array, index)) {
                sb.append("1");
            } else {
                sb.append("0");
            }

            if ((index + 1) % 8 == 0) {
                sb.append(" ");
            }
        }
        // base the suffix on the same range that was printed (which ignores readIndex)
        if (writeIndex > maxBits) {
            sb.append("...");
        }
        sb.append("]");
        return sb.toString();
    }
}
