package ru.yandex.solomon.codec.bits;

import java.util.Arrays;

import javax.annotation.Nonnull;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.PooledByteBufAllocator;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.solomon.memory.layout.MemoryCounter;

import static ru.yandex.solomon.codec.bits.BitArray.div8;
import static ru.yandex.solomon.codec.bits.BitArray.mod8;

/**
 * @author Vladimir Gordiychuk
 */
public class PooledHeapBitBuf extends BitBuf {
    // Self size of this class as reported by MemoryCounter; added to the buffer size in memorySizeIncludingSelf().
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(PooledHeapBitBuf.class);
    // Allocator used by the capacity constructor to obtain the backing heap buffer.
    public static final ByteBufAllocator ALLOCATOR = PooledByteBufAllocator.DEFAULT;

    // Backing Netty buffer; capacity changes and reference counting are delegated to it.
    private ByteBuf buf;
    // Cached buf.array(); re-read after the buffer capacity is changed.
    @Nonnull
    private byte[] array;
    // Cached buf.arrayOffset(); every access to `array` is shifted by this value.
    private int offset;
    // Write position, measured in bits.
    private long writeIndex;
    // Read position, measured in bits.
    private long readIndex;

    public PooledHeapBitBuf(int capacity) {
        this.buf = ALLOCATOR.heapBuffer(Math.min(256, capacity));
        this.array = buf.array();
        this.offset = buf.arrayOffset();
        this.writeIndex = 0;
        this.readIndex = 0;
    }

    /**
     * Copy constructor: creates an independent buffer holding the same bits and the same
     * read/write positions as {@code copy}.
     *
     * <p>All bytes from index 0 up to the written byte count are copied, so the preserved
     * {@code readIndex}/{@code writeIndex} stay valid even when part of the source has already
     * been read. The source buffer's own indices are not modified.
     *
     * @param copy buffer to copy; only {@link PooledHeapBitBuf} is supported
     * @throws UnsupportedOperationException if {@code copy} is any other implementation
     */
    public PooledHeapBitBuf(BitBuf copy) {
        if (!(copy instanceof PooledHeapBitBuf)) {
            throw new UnsupportedOperationException("Unsupported yet");
        }

        PooledHeapBitBuf heap = (PooledHeapBitBuf) copy;
        // Size the copy by the written bytes, not by the readable bits: bit indices are
        // absolute positions inside the byte array and are carried over unchanged below.
        this.buf = heap.buf.copy(0, heap.writerByteIndex());
        this.array = this.buf.array();
        this.offset = this.buf.arrayOffset();
        this.writeIndex = heap.writeIndex;
        this.readIndex = heap.readIndex;
    }

    /**
     * Wraps an existing heap buffer that already holds {@code lengthBits} written bits.
     * The read position starts at 0.
     *
     * @param buffer     backing buffer; its array and array offset are used directly
     * @param lengthBits number of bits already written, used as the write position
     */
    public PooledHeapBitBuf(@Nonnull ByteBuf buffer, long lengthBits) {
        this(buffer, 0L, lengthBits);
    }

    public PooledHeapBitBuf(@Nonnull ByteBuf buffer, long readerIndex, long writerIndex) {
        this.buf = buffer;
        this.array = buffer.array();
        this.offset = buffer.arrayOffset();
        this.writeIndex = writerIndex;
        this.readIndex = readerIndex;
    }

    /**
     * Makes sure more than {@code minBytesCapacity} bytes are available starting at the
     * byte that contains the current write position.
     *
     * @param minBytesCapacity number of bytes required; must not be negative
     * @throws IllegalArgumentException if {@code minBytesCapacity} is negative
     */
    @Override
    public void ensureBytesCapacity(int minBytesCapacity) {
        if (minBytesCapacity >= 0) {
            int writeBytePos = div8(writeIndex);
            ensureBytesCapacity(writeBytePos, minBytesCapacity);
            return;
        }

        throw new IllegalArgumentException("negative minBytesCapacity: " + minBytesCapacity);
    }

    /**
     * Grows the backing buffer unless more than {@code minBytesCapacity} bytes are already
     * available after byte position {@code pos}.
     *
     * <p>The new capacity is the largest of: 1.5x the current capacity, the exact capacity
     * needed, and 16 bytes. The cached array and offset are refreshed after resizing.
     *
     * @param pos              byte position from which free space is measured
     * @param minBytesCapacity number of bytes required after {@code pos}
     */
    public void ensureBytesCapacity(int pos, int minBytesCapacity) {
        int currentCapacity = buf.capacity();
        int available = currentCapacity - pos;
        if (available > minBytesCapacity) {
            return;
        }

        int exactCapacity = currentCapacity + (minBytesCapacity - available);
        int grownCapacity = Math.addExact(currentCapacity, currentCapacity >> 1);
        int targetCapacity = Math.max(Math.max(grownCapacity, exactCapacity), 16);
        buffer().capacity(targetCapacity);
        // Resizing may move the data to another backing array.
        array = buf.array();
        offset = buf.arrayOffset();
    }

    /**
     * Writes a single bit at the write position and advances it by one.
     * Bits are stored starting from the least significant bit of each byte.
     *
     * @param bit value to write
     */
    @Override
    public void writeBit(boolean bit) {
        int bytePos = div8(writeIndex);
        int bitInByte = mod8(writeIndex);
        ensureBytesCapacity(bytePos, 1);

        int target = offset + bytePos;
        int mask = 1 << bitInByte;
        if (!bit) {
            array[target] &= ~mask;
        } else {
            array[target] |= mask;
        }
        writeIndex++;
    }

    /**
     * Writes 8 bits at the write position, which does not have to be byte aligned.
     *
     * @param bits value to write
     */
    @Override
    public void write8Bits(byte bits) {
        int bytePos = div8(writeIndex);
        // One spare byte so an unaligned write can spill into the following byte.
        ensureBytesCapacity(bytePos, Byte.BYTES + 1);

        int target = offset + bytePos;
        int bitShift = mod8(writeIndex);
        if (bitShift == 0) {
            array[target] = bits;
        } else {
            int spare = 8 - bitShift;
            // Keep only the bits already written in the current byte.
            byte kept = (byte) ((0xff >>> spare) & array[target]);
            array[target] = (byte) (kept | (bits & (0xff >>> bitShift)) << bitShift);
            array[target + 1] = (byte) (Byte.toUnsignedInt(bits) >>> spare);
        }
        writeIndex += 8;
    }

    /**
     * Writes 32 bits at the write position, least significant byte first.
     * The write position does not have to be byte aligned.
     *
     * @param bits value to write
     */
    @Override
    public void write32Bits(int bits) {
        int bytePos = div8(writeIndex);
        // Spare bytes so an unaligned write can spill past the last whole byte.
        ensureBytesCapacity(bytePos, Integer.BYTES + 2);

        int base = offset + bytePos;
        int bitShift = mod8(writeIndex);
        int rest = bits;
        if (bitShift != 0) {
            int spare = 8 - bitShift;
            // Fill the free high bits of the current byte, keeping its already written low bits.
            byte kept = (byte) ((0xff >>> spare) & array[base]);
            array[base] = (byte) (kept | (bits & (0xff >>> bitShift)) << bitShift);
            rest = bits >>> spare;
            base++;
        }

        for (int i = 0; i < Integer.BYTES; i++) {
            array[base + i] = (byte) (rest >>> (8 * i));
        }
        writeIndex += Integer.SIZE;
    }

    /**
     * Writes 64 bits at the write position, least significant byte first.
     * The write position does not have to be byte aligned.
     *
     * @param bits value to write
     */
    @Override
    public void write64Bits(long bits) {
        int bytePos = div8(writeIndex);
        // Spare bytes so an unaligned write can spill past the last whole byte.
        ensureBytesCapacity(bytePos, Long.BYTES + 2);

        int base = offset + bytePos;
        int bitShift = mod8(writeIndex);
        long rest = bits;
        if (bitShift != 0) {
            int spare = 8 - bitShift;
            // Fill the free high bits of the current byte, keeping its already written low bits.
            byte kept = (byte) ((0xff >>> spare) & array[base]);
            array[base] = (byte) (kept | (bits & (0xff >>> bitShift)) << bitShift);
            rest = bits >>> spare;
            base++;
        }

        for (int i = 0; i < Long.BYTES; i++) {
            array[base + i] = (byte) (rest >>> (8 * i));
        }
        writeIndex += Long.SIZE;
    }

    /**
     * Writes the {@code count} low bits of {@code bits} at the write position, least
     * significant bit first.
     *
     * <p>Every byte covered by the written range is overwritten, including bytes whose
     * value is zero, so stale content of a reused backing array never leaks into the
     * written bits.
     *
     * @param bits  value to write; must fit into {@code count} bits
     * @param count number of bits to write, from 0 to 32
     * @throws IllegalArgumentException if {@code count} is negative or greater than 32,
     *                                  or if {@code bits} needs more than {@code count} bits
     */
    @Override
    public void writeBits(int bits, int count) {
        if (count < 0) {
            throw new IllegalArgumentException("negative bit count: " + count);
        }

        if (count > Integer.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Integer.SIZE + " bits: " + count);
        }

        if (Integer.SIZE - Integer.numberOfLeadingZeros(bits) > count) {
            int used = Integer.SIZE - Integer.numberOfLeadingZeros(bits);
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        int pos = div8(writeIndex);
        int used = mod8(writeIndex);
        int freeBit = 8 - used;

        // extra byte is not to worry about overflow while writing last byte
        ensureBytesCapacity(pos, Integer.BYTES + 2);
        // Keep the bits already written in the current byte and fill its free high bits.
        byte prev = (byte) ((0xff >>> freeBit) & array[offset + pos]);
        array[offset + pos] = (byte) (prev | (bits & (0xff >>> used)) << used);

        // Drive the loop by the number of bits left rather than by the remaining value:
        // a zero high part must still overwrite the bytes it covers.
        int value = bits >>> freeBit;
        for (int remaining = count - freeBit; remaining > 0; remaining -= 8) {
            ++pos;
            array[offset + pos] = (byte) value;
            value >>>= 8;
        }

        writeIndex += count;
    }

    /**
     * Writes the {@code count} low bits of {@code bits} at the write position, least
     * significant bit first.
     *
     * <p>Every byte covered by the written range is overwritten, including bytes whose
     * value is zero, so stale content of a reused backing array never leaks into the
     * written bits.
     *
     * @param bits  value to write; must fit into {@code count} bits
     * @param count number of bits to write, from 0 to 64
     * @throws IllegalArgumentException if {@code count} is negative or greater than 64,
     *                                  or if {@code bits} needs more than {@code count} bits
     */
    @Override
    public void writeBits(long bits, int count) {
        if (count < 0) {
            throw new IllegalArgumentException("negative bit count: " + count);
        }

        if (count > Long.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Long.SIZE + " bits: " + count);
        }

        if (Long.SIZE - Long.numberOfLeadingZeros(bits) > count) {
            int used = Long.SIZE - Long.numberOfLeadingZeros(bits);
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        int pos = div8(writeIndex);
        int used = mod8(writeIndex);
        int freeBit = 8 - used;

        // extra byte is not to worry about overflow while writing last byte
        ensureBytesCapacity(pos, Long.BYTES + 2);
        // Keep the bits already written in the current byte and fill its free high bits.
        byte prev = (byte) ((0xff >>> freeBit) & array[offset + pos]);
        array[offset + pos] = (byte) (prev | (bits & (0xff >>> used)) << used);

        // Drive the loop by the number of bits left rather than by the remaining value:
        // a zero high part must still overwrite the bytes it covers.
        long value = bits >>> freeBit;
        for (int remaining = count - freeBit; remaining > 0; remaining -= 8) {
            ++pos;
            array[offset + pos] = (byte) value;
            value >>>= 8;
        }

        writeIndex += count;
    }

    /**
     * Writes {@code n} as a run of {@code n} one-bits. The run is followed by a single
     * zero bit unless {@code n == max}, in which case no terminator is written.
     *
     * @param n   value to write, from 0 to {@code max}
     * @param max largest encodable value; at most 8
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, max]} or {@code max > 8}
     */
    @Override
    public void writeIntVarint1N(int n, int max) {
        boolean valid = n >= 0 && n <= max && max <= 8;
        if (!valid) {
            throw new IllegalArgumentException("incorrect n=" + n + ", max=" + max);
        }

        int bytePos = div8(writeIndex);
        ensureBytesCapacity(bytePos, 2);

        byte ones = (byte) (0xff >>> (8 - n));
        // The terminating zero bit is only needed when the run is shorter than max.
        int size = (n == max) ? n : n + 1;
        writeMask(bytePos, size, ones);
    }

    /**
     * Appends {@code length} bits of {@code src}, starting at bit {@code srcIndex}, at the
     * write position. Both {@code srcIndex} and the current write position must be byte
     * aligned; whole bytes are copied.
     *
     * @param src      buffer to read from; its indices are not advanced
     * @param srcIndex bit position in {@code src} to start copying from; must be a multiple of 8
     * @param length   number of bits to copy
     * @throws UnsupportedOperationException if {@code srcIndex} or the write position is not
     *                                       byte aligned
     */
    @Override
    public void writeBits(BitBuf src, long srcIndex, long length) {
        if (mod8(srcIndex) != 0 ) {
            throw new UnsupportedOperationException("Unsupported writeBits not alighted to byte pos: " + srcIndex);
        }

        if (mod8(writeIndex) != 0) {
            throw new UnsupportedOperationException("Unsupported writeBits to not alighted pos: " + writeIndex);
        }

        ensureBytesCapacity(div8(length) + 1);
        int byteCount = BitArray.arrayLengthForBits(length);
        if (src instanceof PooledHeapBitBuf) {
            var that = (PooledHeapBitBuf) src;
            this.buffer().writeBytes(that.buffer(), div8(srcIndex), byteCount);
            writeIndex += length;
            return;
        }

        // The source data starts at src.arrayOffset() inside src.array(), exactly as this
        // class addresses its own array; the offset must be applied to the source position.
        System.arraycopy(src.array(), src.arrayOffset() + div8(srcIndex), array, offset + div8(writeIndex), byteCount);
        writeIndex += length;
    }

    /**
     * Advances the write position to the next byte boundary, clearing the unused high bits
     * of the current byte. Does nothing when the write position is already byte aligned, so
     * no byte beyond the written range is touched.
     */
    @Override
    public void alignToByte() {
        int used = mod8(writeIndex);
        if (used == 0) {
            // Already aligned: the byte at div8(writeIndex) is not part of the written data
            // and may not even belong to this buffer.
            return;
        }

        int pos = div8(writeIndex);
        int freeBit = 8 - used;
        array[offset + pos] = (byte) ((0xff >>> freeBit) & array[offset + pos]);
        writeIndex += freeBit;
    }

    /**
     * @return number of bytes needed to hold the readable bits, as computed by
     *         {@code BitArray.arrayLengthForBits(readableBits())}
     */
    @Override
    public int bytesSize() {
        return BitArray.arrayLengthForBits(readableBits());
    }

    /**
     * @return number of bits between the read position and the write position
     */
    @Override
    public long readableBits() {
        return writeIndex - readIndex;
    }

    /**
     * Reads one bit and advances the read position by one.
     *
     * @return {@code true} if the bit is set
     * @throws IllegalArgumentException if no readable bit is left
     */
    @Override
    public boolean readBit() {
        checkReadable(1);
        long bitPos = readIndex;
        readIndex = bitPos + 1;
        byte holder = array[offset + div8(bitPos)];
        return BitArray.isBitSet(holder, mod8(bitPos));
    }

    /**
     * Reads {@code bitCount} bits into the low bits of an {@code int}.
     *
     * @param bitCount number of bits to read; at most 32
     * @return the bits read
     * @throws IllegalArgumentException if {@code bitCount} is greater than 32
     */
    @Override
    public int readBitsToInt(int bitCount) {
        if (bitCount <= 32) {
            long bits = readBitsToLong(bitCount);
            return (int) bits;
        }
        throw new IllegalArgumentException("bit count cannot be greater than 32: " + bitCount);
    }

    /**
     * Reads {@code bitCount} bits starting at the read position into the low bits of a
     * {@code long} (first bit read becomes bit 0) and advances the read position.
     *
     * @param bitCount number of bits to read; 0 returns 0 without touching the buffer
     * @return the bits read
     * @throws IllegalArgumentException if {@code bitCount} is greater than 64 or exceeds
     *                                  the readable bits
     */
    @Override
    public long readBitsToLong(int bitCount) {
        if (bitCount == 0) {
            return 0;
        } else if (bitCount > 64) {
            throw new IllegalArgumentException("bit count cannot be greater than 64: " + bitCount);
        }
        checkReadable(bitCount);

        long r = 0;
        int pos = div8(readIndex);
        int used = mod8(readIndex);
        // number of result bits filled so far
        int currentBit = 0;
        readIndex += bitCount;

        // align to byte: take the unread high bits of the current byte
        {
            int freeBit = 8 - used;
            int bitsFromFirstByte = Math.min(bitCount, freeBit);
            r |= ((array[offset + pos] & 0xffL) >>> used) & ((1L << bitsFromFirstByte) - 1);
            if (bitCount <= freeBit) {
                // everything requested was inside the first byte
                return r;
            }
            currentBit += bitsFromFirstByte;
        }

        // read whole bytes
        while (bitCount - currentBit >= 8) {
            r |= (array[offset + ++pos] & 0xffL) << currentBit;
            currentBit += 8;
        }

        // read remaining bits
        {
            int bitsFromLastByte = bitCount - currentBit;

            // `if` is a protection against buffer overrun
            if (bitsFromLastByte != 0) {
                r |= (array[offset + ++pos] & ((1L << bitsFromLastByte) - 1)) << currentBit;
            }
        }

        return r;
    }

    /**
     * Reads 8 bits from the read position, which does not have to be byte aligned.
     *
     * @return the bits read
     * @throws IllegalArgumentException if fewer than 8 bits are readable
     */
    @Override
    public byte read8Bits() {
        checkReadable(Byte.SIZE);
        int base = offset + div8(readIndex);
        int bitShift = mod8(readIndex);

        byte result;
        if (bitShift != 0) {
            // Combine the current and the next byte, then drop the bits already consumed.
            int low = Byte.toUnsignedInt(array[base]);
            int high = Byte.toUnsignedInt(array[base + 1]) << 8;
            result = (byte) ((low | high) >>> bitShift);
        } else {
            result = array[base];
        }
        readIndex += Byte.SIZE;
        return result;
    }

    /**
     * Reads 32 bits from the read position, least significant byte first.
     * The read position does not have to be byte aligned.
     *
     * @return the bits read
     * @throws IllegalArgumentException if fewer than 32 bits are readable
     */
    @Override
    public int read32Bits() {
        checkReadable(Integer.SIZE);
        int base = offset + div8(readIndex);
        int bitShift = mod8(readIndex);
        // An unaligned value spans one extra byte.
        int lastByte = (bitShift == 0) ? Integer.BYTES - 1 : Integer.BYTES;

        int result = Byte.toUnsignedInt(array[base]) >>> bitShift;
        for (int i = 1; i <= lastByte; i++) {
            result |= Byte.toUnsignedInt(array[base + i]) << (i * 8 - bitShift);
        }
        readIndex += Integer.SIZE;
        return result;
    }

    /**
     * Reads 64 bits from the read position, least significant byte first.
     * The read position does not have to be byte aligned.
     *
     * @return the bits read
     * @throws IllegalArgumentException if fewer than 64 bits are readable
     */
    @Override
    public long read64Bits() {
        checkReadable(Long.SIZE);
        int base = offset + div8(readIndex);
        int bitShift = mod8(readIndex);
        // An unaligned value spans one extra byte.
        int lastByte = (bitShift == 0) ? Long.BYTES - 1 : Long.BYTES;

        long result = Byte.toUnsignedLong(array[base]) >>> bitShift;
        for (int i = 1; i <= lastByte; i++) {
            result |= Byte.toUnsignedLong(array[base + i]) << (i * 8 - bitShift);
        }
        readIndex += Long.SIZE;
        return result;
    }

    /**
     * Reads a value stored as a run of one-bits: counts consecutive set bits until a zero
     * bit is read or {@code max} set bits have been counted.
     *
     * @param max largest value that can be returned; at most 8
     * @return number of set bits counted
     * @throws IllegalArgumentException if {@code max} is greater than 8
     */
    @Override
    public int readIntVarint1N(int max) {
        if (max > Byte.SIZE) {
            throw new IllegalArgumentException("max " + max + " > 8");
        }

        // TODO: can do faster
        int ones = 0;
        while (ones != max) {
            if (!readBit()) {
                break;
            }
            ones++;
        }
        return ones;
    }

    /**
     * Reads a byte-oriented varint: each byte carries 7 payload bits, low group first, and
     * a set high bit means another byte follows. Payload bits beyond the 32nd are discarded.
     *
     * @return the decoded value
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public int readIntVarint8() {
        int value = 0;
        int shift = 0;
        while (shift < 64) {
            final byte next = read8Bits();
            // Shift as long so groups past bit 31 are dropped by the narrowing cast.
            value = (int) (value | ((long) (next & 0x7F) << shift));
            if (next >= 0) {
                // continuation bit is clear
                return value;
            }
            shift += 7;
        }
        throw new IllegalStateException("malformed");
    }

    /**
     * Reads a byte-oriented varint: each byte carries 7 payload bits, low group first, and
     * a set high bit means another byte follows.
     *
     * @return the decoded value
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public long readLongVarint8() {
        long value = 0;
        int shift = 0;
        while (shift < 64) {
            final byte next = read8Bits();
            long payload = next & 0x7F;
            value |= payload << shift;
            if (next >= 0) {
                // continuation bit is clear
                return value;
            }
            shift += 7;
        }
        throw new IllegalStateException("malformed");
    }

    /**
     * Moves the write position back to bit 0. The buffer content and the read position
     * are left untouched.
     */
    @Override
    public void resetWriterIndex() {
        writeIndex = 0;
    }

    /**
     * Moves the read position back to bit 0.
     */
    @Override
    public void resetReadIndex() {
        readIndex = 0;
    }

    /**
     * @return current write position in bits
     */
    @Override
    public long writerIndex() {
        return writeIndex;
    }

    /**
     * Sets the write position.
     *
     * @param writerIndex new write position in bits
     * @throws IndexOutOfBoundsException if the byte containing the position lies beyond
     *                                   the buffer capacity
     */
    @Override
    public void writerIndex(long writerIndex) {
        int bytePos = div8(writerIndex);
        int capacity = buf.capacity();
        if (bytePos > capacity) {
            throw new IndexOutOfBoundsException("capacity: " + capacity + " index "+ bytePos);
        }
        writeIndex = writerIndex;
    }

    /**
     * @return current read position in bits
     */
    @Override
    public long readerIndex() {
        return readIndex;
    }

    /**
     * Sets the read position.
     *
     * @param readerIndex new read position in bits; must not exceed the write position
     * @throws IndexOutOfBoundsException if {@code readerIndex} is past the write position
     */
    @Override
    public void readerIndex(long readerIndex) {
        if (readerIndex <= writeIndex) {
            readIndex = readerIndex;
            return;
        }
        throw new IndexOutOfBoundsException("ridx(" + readerIndex + ") >= widx(" + writeIndex + ")");
    }

    /**
     * Advances the read position by {@code length} bits without reading them.
     *
     * @param length number of bits to skip
     * @throws IndexOutOfBoundsException if the new read position would pass the write position
     */
    @Override
    public void skipBits(long length) {
        long target = readIndex + length;
        if (target <= writeIndex) {
            readIndex = target;
            return;
        }
        throw new IndexOutOfBoundsException("ridx(" + target + ") >= widx(" + writeIndex + ")");
    }

    /**
     * @return a {@code ReadOnlyPooledHeapBitBuf} over a duplicate of the backing buffer,
     *         starting with this buffer's current read and write positions
     */
    @Override
    public BitBuf asReadOnly() {
        return new ReadOnlyPooledHeapBitBuf(buffer().duplicate(), readIndex, writeIndex);
    }

    /**
     * Returns a view over bits {@code [index, index + length)} that uses the same backing
     * buffer as this instance.
     *
     * @param index  first bit of the view; must be a multiple of 8
     * @param length number of bits in the view
     * @return a buffer whose read position is {@code index} and write position is
     *         {@code index + length}
     * @throws UnsupportedOperationException if {@code index} is not byte aligned
     * @throws IndexOutOfBoundsException     if the range ends past the write position
     */
    @Override
    public BitBuf slice(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        long end = index + length;
        if (end > writeIndex) {
            throw new IndexOutOfBoundsException(end + " >= widx(" + writeIndex + ")");
        }

        ByteBuf shared = buffer();
        return new PooledHeapBitBuf(shared, index, end);
    }

    /**
     * Returns a new instance over the same backing buffer with the same read and write
     * positions. A read-only instance produces a read-only duplicate.
     *
     * @return the duplicate
     */
    @Override
    public BitBuf duplicate() {
        ByteBuf shared = buffer();
        boolean readOnly = this instanceof ReadOnlyPooledHeapBitBuf;
        if (!readOnly) {
            return new PooledHeapBitBuf(shared, readIndex, writeIndex);
        }
        return new ReadOnlyPooledHeapBitBuf(shared, readIndex, writeIndex);
    }

    /**
     * Copies the readable bits, i.e. the range from the read position to the write position.
     *
     * @return result of {@code copy(readIndex, readableBits())}
     */
    @Override
    public BitBuf copy() {
        return copy(readIndex, readableBits());
    }

    /**
     * Copies {@code length} bits starting at bit {@code index} into a new buffer with its
     * own backing storage.
     *
     * @param index  first bit to copy; must be a multiple of 8
     * @param length number of bits to copy
     * @return a buffer holding the copied bits, with read position 0 and write position
     *         {@code length}
     * @throws UnsupportedOperationException if {@code index} is not byte aligned
     */
    @Override
    public BitBuf copy(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        ByteBuf source = buffer();
        int firstByte = div8(index);
        int byteCount = BitArray.arrayLengthForBits(length);
        ByteBuf copied = source.copy(firstByte, byteCount);
        return new PooledHeapBitBuf(copied, length);
    }

    /**
     * @return the cached backing array; data of this buffer starts at {@link #arrayOffset()}
     */
    @Override
    public byte[] array() {
        return array;
    }

    /**
     * @return index inside {@link #array()} of this buffer's first byte
     */
    @Override
    public int arrayOffset() {
        return offset;
    }

    /**
     * @return reference count of the backing buffer
     */
    @Override
    public int refCnt() {
        return buf.refCnt();
    }

    /**
     * Calls {@code retain()} on the backing buffer.
     *
     * @return this instance
     */
    @Override
    public PooledHeapBitBuf retain() {
        buf.retain();
        return this;
    }

    /**
     * Calls {@code retain(increment)} on the backing buffer.
     *
     * @param increment value passed through to the backing buffer
     * @return this instance
     */
    @Override
    public PooledHeapBitBuf retain(int increment) {
        buf.retain(increment);
        return this;
    }

    /**
     * Calls {@code touch()} on the backing buffer.
     *
     * @return this instance
     */
    @Override
    public PooledHeapBitBuf touch() {
        buf.touch();
        return this;
    }

    /**
     * Calls {@code touch(hint)} on the backing buffer.
     *
     * @param hint value passed through to the backing buffer
     * @return this instance
     */
    @Override
    public PooledHeapBitBuf touch(Object hint) {
        buf.touch(hint);
        return this;
    }

    /**
     * Creates a new empty buffer whose heap storage comes from the allocator of this
     * instance's backing buffer.
     *
     * @param byteCapacity capacity passed to {@code heapBuffer}
     * @return a new buffer with read and write positions at 0
     */
    @Override
    public BitBuf allocate(int byteCapacity) {
        return new PooledHeapBitBuf(buf.alloc().heapBuffer(byteCapacity), 0, 0);
    }

    /**
     * @return {@code isDirect()} of the backing buffer
     */
    @Override
    public boolean isDirect() {
        return buf.isDirect();
    }

    /**
     * Calls {@code release()} on the backing buffer.
     *
     * @return the value returned by the backing buffer
     */
    @Override
    public boolean release() {
        return buf.release();
    }

    /**
     * Calls {@code release(decrement)} on the backing buffer.
     *
     * @param decrement value passed through to the backing buffer
     * @return the value returned by the backing buffer
     */
    @Override
    public boolean release(int decrement) {
        return buf.release(decrement);
    }

    /**
     * Verifies that at least {@code expect} bits can still be read.
     *
     * @param expect number of bits the caller is about to read
     * @throws IllegalArgumentException if fewer bits are readable
     */
    private void checkReadable(long expect) {
        long remaining = readableBits();
        if (remaining >= expect) {
            return;
        }
        throw new IllegalArgumentException("expected at least: " + expect + ", remaining: " + remaining);
    }

    /**
     * Writes {@code size} bits taken from {@code mask} at the write position and advances
     * it. The caller must have ensured room for the current byte and the one after it.
     *
     * @param pos  byte position containing the write position
     * @param size number of bits to write
     * @param mask bit pattern to write, starting from its lowest bit
     */
    private void writeMask(int pos, int size, byte mask) {
        int bitShift = mod8(writeIndex);
        int spare = 8 - bitShift;
        int target = offset + pos;
        // TODO: uncomment it as only fixed dirty write
        //array[offset + pos] |= mask << used;
        // Keep the bits already written in the current byte, then add the new ones above them.
        int kept = (0xff >>> spare) & array[target];
        array[target] = (byte) (kept | (mask << bitShift));
        if (size > spare) {
            // The pattern does not fit into the current byte; the rest goes to the next one.
            array[target + 1] = (byte) (mask >>> spare);
        }
        writeIndex += size;
    }

    /**
     * @return {@code SELF_SIZE} plus the size of the backing buffer as reported by
     *         {@code MemoryCounter.byteBufSize}
     */
    @Override
    public long memorySizeIncludingSelf() {
        return SELF_SIZE + MemoryCounter.byteBufSize(buf);
    }

    /**
     * Compares this buffer with another {@link BitBuf}.
     *
     * <p>Against another {@code PooledHeapBitBuf}: the write positions must match, the whole
     * written bytes must match, and the written bits of the last partial byte must match.
     * Against any other {@code BitBuf}: the readable bit counts must match and the first
     * {@code bytesSize()} bytes of both backing arrays must be identical.
     *
     * @param o object to compare with
     * @return {@code true} if the buffers are considered equal
     */
    @Override
    @SuppressWarnings("EqualsWrongThing")
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }

        if (o instanceof PooledHeapBitBuf) {
            PooledHeapBitBuf other = (PooledHeapBitBuf) o;
            if (writeIndex != other.writeIndex) {
                return false;
            }

            int wholeBytes = div8(writeIndex);
            if (!ByteBufUtil.equals(this.buffer(), 0, other.buffer(), 0, wholeBytes)) {
                return false;
            }

            int tailBits = mod8(writeIndex);
            if (tailBits == 0) {
                return true;
            }
            // Only the written bits of the last byte take part in the comparison.
            int tailMask = (1 << tailBits) - 1;
            int mine = this.array[offset + wholeBytes] & tailMask;
            int theirs = other.array[other.offset + wholeBytes] & tailMask;
            return mine == theirs;
        }

        if (o instanceof BitBuf) {
            BitBuf other = (BitBuf) o;
            if (readableBits() != other.readableBits()) {
                return false;
            }

            int myFrom = this.arrayOffset();
            int otherFrom = other.arrayOffset();
            return Arrays.equals(
                this.array(), myFrom, myFrom + this.bytesSize(),
                other.array(), otherFrom, otherFrom + other.bytesSize());
        }

        return false;
    }

    /**
     * Hash code consistent with {@link #equals(Object)} for two {@code PooledHeapBitBuf}
     * instances: it covers the write position, the whole written bytes, and only the
     * written bits of the last partial byte. Bits of that byte beyond the write position
     * are ignored here exactly as {@code equals} ignores them.
     *
     * @return hash code of the written bits
     */
    @Override
    public int hashCode() {
        int fullBytes = div8(writeIndex);
        int hash = 31 * Long.hashCode(writeIndex)
            + Cf.ByteArray.hashCodeOfRange(array, offset, offset + fullBytes);

        int tailBits = mod8(writeIndex);
        if (tailBits != 0) {
            // Mask out bits that were never written; they may hold stale data.
            int tail = array[offset + fullBytes] & ((1 << tailBits) - 1);
            hash = 31 * hash + tail;
        }
        return hash;
    }

    /**
     * Renders the written bits as groups of eight {@code 0}/{@code 1} characters, lowest
     * bit of each byte first. At most the first 100 bytes are printed; {@code "..."} is
     * appended when {@code bytesSize()} exceeds 100.
     *
     * @return textual form of the buffer, or {@code "[]"} when nothing has been written
     */
    @Override
    public String toString() {
        if (writeIndex == 0) {
            return "[]";
        }

        StringBuilder out = new StringBuilder();
        out.append("[bits: ");
        out.append(writeIndex);
        out.append(": ");

        long shownBits = Math.min(writeIndex, 100 * 8);
        for (int bit = 0; bit < shownBits; bit++) {
            boolean set = BitArray.isBitSet(array[offset + div8(bit)], mod8(bit));
            out.append(set ? "1" : "0");

            // separate bytes with a space
            if ((bit + 1) % 8 == 0) {
                out.append(" ");
            }
        }

        if (bytesSize() > 100) {
            out.append("...");
        }
        return out.append("]").toString();
    }

    /**
     * Returns the backing buffer after setting its writer index to the number of bytes
     * covered by the current write position.
     *
     * @return the backing buffer itself (not a copy)
     */
    public ByteBuf buffer() {
        return buf.writerIndex(writerByteIndex());
    }

    /**
     * @return number of bytes needed to hold all bits up to the write position, as computed
     *         by {@code BitArray.arrayLengthForBits(writeIndex)}
     */
    public int writerByteIndex() {
        return BitArray.arrayLengthForBits(writeIndex);
    }
}
