package ru.yandex.solomon.codec.bits;

import java.util.Arrays;

import javax.annotation.Nonnull;

import io.netty.buffer.AbstractByteBuf;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.util.internal.StringUtil;

import ru.yandex.solomon.memory.layout.MemoryCounter;

import static ru.yandex.solomon.codec.bits.BitArray.div8;
import static ru.yandex.solomon.codec.bits.BitArray.div8Exact;
import static ru.yandex.solomon.codec.bits.BitArray.mod8;

/**
 * {@link BitBuf} implementation backed by a Netty {@link ByteBuf}.
 *
 * <p>Bits use little-endian bit order: the first bit written into a byte is its least
 * significant bit. Multi-byte values are stored with the little-endian {@code ByteBuf}
 * accessors ({@code writeIntLE}, {@code getLongLE}, ...).
 *
 * <p>A bit position is kept as the underlying {@link ByteBuf} byte index (its writer/reader
 * index) plus a sub-byte offset in {@code [0, 8)} ({@code writeIndexBits} /
 * {@code readIndexBits}). When {@code writeIndexBits != 0}, the byte at the buffer's writer index
 * is partially written. Its bits at and above {@code writeIndexBits} are not meaningful: several
 * writers store whole bytes/words past the logical end ("dirty" writes). Writers therefore mask
 * the already-written low bits of that byte instead of trusting the rest of it to be zero.
 *
 * @author Vladimir Gordiychuk
 */
public class NettyBitBuf extends BitBuf {
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(NettyBitBuf.class);
    private static final long BYTEBUF_SIZE = MemoryCounter.objectSelfSizeLayout(AbstractByteBuf.class);

    @Nonnull
    private ByteBuf buffer;
    /** Number of bits already written into the byte at {@code buffer.writerIndex()}, in [0, 8). */
    private int writeIndexBits;
    /** Number of bits already read from the byte at {@code buffer.readerIndex()}, in [0, 8). */
    private int readIndexBits;

    /**
     * Wraps {@code buffer} and treats its first {@code lengthBits} bits as already written.
     * Sets the buffer's writer index to the last full byte. The bit read offset starts at 0 of
     * the buffer's current reader index.
     *
     * @param buffer     backing buffer; its writer index is modified
     * @param lengthBits number of valid bits at the start of the buffer
     */
    public NettyBitBuf(@Nonnull ByteBuf buffer, long lengthBits) {
        this.buffer = buffer.writerIndex(div8Exact(lengthBits));
        this.writeIndexBits = mod8(lengthBits);
    }

    private NettyBitBuf(@Nonnull ByteBuf buffer, int writeIndexBits, int readIndexBits) {
        this.buffer = buffer;
        this.writeIndexBits = writeIndexBits;
        this.readIndexBits = readIndexBits;
    }

    /** Makes sure at least {@code minBytesCapacity} bytes can be written after the writer index. */
    @Override
    public void ensureBytesCapacity(int minBytesCapacity) {
        buffer.ensureWritable(minBytesCapacity);
    }

    /** Writes a single bit at the current write position. */
    @Override
    public void writeBit(boolean bit) {
        ensureBytesCapacity(Byte.BYTES);
        int pos = buffer.writerIndex();
        byte prev = buffer.getByte(pos);
        if (bit) {
            prev |= 1 << writeIndexBits;
        } else {
            prev &= ~(1 << writeIndexBits);
        }
        buffer.setByte(pos, prev);
        if (++writeIndexBits == Byte.SIZE) {
            writeIndexBits = 0;
            buffer.writerIndex(pos + 1);
        }
    }

    /** Writes 8 bits; the sub-byte write offset is unchanged afterwards. */
    @Override
    public void write8Bits(byte bits) {
        if (writeIndexBits == 0) {
            buffer.writeByte(bits);
            return;
        }

        // extra byte so that the spill-over into the next byte never overflows the buffer
        ensureBytesCapacity(Byte.BYTES + 1);

        int pos = buffer.writerIndex();
        int used = writeIndexBits;
        int freeBit = 8 - used;

        // keep only the already-written low bits of the current byte
        byte prev = (byte) ((0xff >>> freeBit) & buffer.getByte(pos));
        buffer.writeByte((byte) (prev | (bits & (0xff >>> used)) << used));
        // the high `used` bits of the value go to the low bits of the next (partial) byte
        buffer.setByte(pos + 1, (byte) (Byte.toUnsignedInt(bits) >>> freeBit));
    }

    /** Writes 32 bits in little-endian order; the sub-byte write offset is unchanged. */
    @Override
    public void write32Bits(int bits) {
        if (writeIndexBits == 0) {
            buffer.writeIntLE(bits);
            return;
        }
        int pos = buffer.writerIndex();
        int used = writeIndexBits;
        int freeBit = 8 - used;
        // extra byte so that the spill-over into the last byte never overflows the buffer
        ensureBytesCapacity(Integer.BYTES + 1);
        byte prev = (byte) ((0xff >>> freeBit) & buffer.getByte(pos));
        buffer.setByte(pos, (byte) (prev | (bits & (0xff >>> used)) << used));
        int value = bits >>> freeBit;
        buffer.setIntLE(pos + 1, value);
        buffer.writerIndex(pos + Integer.BYTES);
    }

    /** Writes 64 bits in little-endian order; the sub-byte write offset is unchanged. */
    @Override
    public void write64Bits(long bits) {
        if (writeIndexBits == 0) {
            buffer.writeLongLE(bits);
            return;
        }
        int pos = buffer.writerIndex();
        int used = writeIndexBits;
        int freeBit = 8 - used;
        // extra byte so that the spill-over into the last byte never overflows the buffer
        ensureBytesCapacity(Long.BYTES + 1);
        byte prev = (byte) ((0xff >>> freeBit) & buffer.getByte(pos));
        buffer.setByte(pos, (byte) (prev | (bits & (0xff >>> used)) << used));
        long value = bits >>> freeBit;
        buffer.setLongLE(pos + 1, value);
        buffer.writerIndex(pos + Long.BYTES);
    }

    /**
     * Writes the low {@code count} bits of {@code bits}.
     *
     * @throws IllegalArgumentException if {@code count > 32} or {@code bits} has a set bit at or
     *                                  above position {@code count}
     */
    @Override
    public void writeBits(int bits, int count) {
        if (count > Integer.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Integer.SIZE + " bits: " + count);
        }

        if (Integer.SIZE - Integer.numberOfLeadingZeros(bits) > count) {
            int used = Integer.SIZE - Integer.numberOfLeadingZeros(bits);
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        // extra byte so that the spill-over into the last byte never overflows the buffer
        ensureBytesCapacity(Integer.BYTES + 1);
        int pos = buffer.writerIndex();
        int used = writeIndexBits;
        if (used == 0) {
            buffer.setIntLE(pos, bits);
            buffer.writerIndex(pos + div8(count));
            writeIndexBits = mod8(count);
            return;
        }

        int freeBit = 8 - used;
        byte prev = (byte) ((0xff >>> freeBit) & buffer.getByte(pos));
        // setByte (not writeByte): the writer index is set explicitly below
        buffer.setByte(pos, (byte) (prev | (bits & (0xff >>> used)) << used));
        int value = bits >>> freeBit;
        buffer.setIntLE(pos + 1, value);
        buffer.writerIndex(pos + div8(used + count));
        writeIndexBits = mod8(used + count);
    }

    /**
     * Writes the low {@code count} bits of {@code bits}.
     *
     * @throws IllegalArgumentException if {@code count > 64} or {@code bits} has a set bit at or
     *                                  above position {@code count}
     */
    @Override
    public void writeBits(long bits, int count) {
        if (count > Long.SIZE) {
            throw new IllegalArgumentException("cannot write more than " + Long.SIZE + " bits: " + count);
        }

        if (Long.SIZE - Long.numberOfLeadingZeros(bits) > count) {
            int used = Long.SIZE - Long.numberOfLeadingZeros(bits);
            throw new IllegalArgumentException("not able write " + count + " bits for " + bits + " because used " + used + " bits");
        }

        // extra byte so that the spill-over into the last byte never overflows the buffer
        ensureBytesCapacity(Long.BYTES + 1);
        int pos = buffer.writerIndex();
        int used = writeIndexBits;
        if (used == 0) {
            buffer.setLongLE(pos, bits);
            buffer.writerIndex(pos + div8(count));
            writeIndexBits = mod8(count);
            return;
        }

        int freeBit = 8 - used;
        byte prev = (byte) ((0xff >>> freeBit) & buffer.getByte(pos));
        buffer.setByte(pos, (byte) (prev | (bits & (0xff >>> used)) << used));
        long value = bits >>> freeBit;
        buffer.setLongLE(pos + 1, value);
        buffer.writerIndex(pos + div8(used + count));
        writeIndexBits = mod8(used + count);
    }

    /**
     * Writes {@code n} in unary: {@code n} one-bits followed by a terminating zero bit. The
     * terminator is omitted when {@code n == max}.
     *
     * @throws IllegalArgumentException if {@code n < 0}, {@code n > max} or {@code max > 8}
     */
    @Override
    public void writeIntVarint1N(int n, int max) {
        if (n < 0 || n > max || max > 8) {
            throw new IllegalArgumentException("incorrect n=" + n + ", max=" + max);
        }

        ensureBytesCapacity(2);
        if (n == max) {
            writeMask(n, (byte) (0xff >>> (8 - n)));
            return;
        }

        // mask has n ones; bit n is zero and acts as the terminator
        writeMask(n + 1, (byte) (0xff >>> (8 - n)));
    }

    /**
     * Copies {@code length} bits of {@code src}, starting at bit {@code srcIndex}, to the write
     * position. Both {@code srcIndex} and the current write position must be byte aligned.
     *
     * @throws UnsupportedOperationException if either position is not byte aligned
     */
    @Override
    public void writeBits(BitBuf src, long srcIndex, long length) {
        if (mod8(srcIndex) != 0 ) {
            throw new UnsupportedOperationException("Unsupported writeBits not alighted to byte pos: " + srcIndex);
        }

        if (writeIndexBits != 0) {
            throw new UnsupportedOperationException("Unsupported writeBits to not alighted pos: " + writerIndex());
        }

        int lengthBytes = BitArray.arrayLengthForBits(length);
        ensureBytesCapacity(lengthBytes);
        if (src instanceof NettyBitBuf) {
            buffer.writeBytes(((NettyBitBuf) src).buffer, div8(srcIndex), lengthBytes);
        } else {
            // array() may be a backing array whose data starts at arrayOffset()
            buffer.writeBytes(src.array(), src.arrayOffset() + div8(srcIndex), lengthBytes);
        }
        writeIndexBits = mod8(length);
        if (writeIndexBits > 0) {
            // the last copied byte is only partially filled: keep it as the current write byte
            buffer.writerIndex(buffer.writerIndex() - 1);
        }
    }

    /**
     * Writes the low {@code size} bits of {@code mask} (at most one spill into the next byte).
     * Callers must have ensured capacity for two bytes.
     */
    private void writeMask(int size, byte mask) {
        int used = writeIndexBits;
        int freeBit = 8 - used;
        int pos = buffer.writerIndex();
        // The already-written low bits of the current byte are masked out explicitly. A plain
        // OR (`array[pos] |= mask << used`) would be enough only if bytes past the write position
        // were guaranteed to be zero, which dirty writes elsewhere do not ensure.
        buffer.setByte(pos, (byte) ((0xff >>> freeBit) & buffer.getByte(pos) | (mask << used)));
        writeIndexBits += size;
        if (freeBit < size) {
            writeIndexBits -= 8;
            buffer.writerIndex(pos + 1);
            buffer.setByte(pos + 1, (byte) (mask >>> freeBit));
        }

        if (writeIndexBits == Byte.SIZE) {
            writeIndexBits = 0;
            buffer.writerIndex(buffer.writerIndex() + 1);
        }
    }

    /** Moves the write position to the start of the next byte if it is inside a byte. */
    @Override
    public void alignToByte() {
        if (writeIndexBits == 0) {
            return;
        }

        buffer.writerIndex(buffer.writerIndex() + 1);
        writeIndexBits = 0;
    }

    /** Number of bytes holding written bits, counting a partially written last byte. */
    @Override
    public int bytesSize() {
        if (writeIndexBits > 0) {
            return buffer.writerIndex() + 1;
        }

        return buffer.writerIndex();
    }

    /** Number of bits between the read position and the write position. */
    @Override
    public long readableBits() {
        // widen before shifting: an int shift overflows for buffers of 256 MiB and more
        long bits = (long) buffer.readableBytes() << 3;
        return bits + writeIndexBits - readIndexBits;
    }

    /** Reads a single bit and advances the read position by one bit. */
    @Override
    public boolean readBit() {
        int pos = buffer.readerIndex();
        byte b = buffer.getByte(pos);
        boolean result = BitArray.isBitSet(b, readIndexBits);
        if (++readIndexBits == Byte.SIZE) {
            readIndexBits = 0;
            buffer.readerIndex(pos + 1);
        }
        return result;
    }

    /**
     * Reads {@code bitCount} bits into the low bits of an int.
     *
     * @throws IllegalArgumentException if {@code bitCount > 32}
     */
    @Override
    public int readBitsToInt(int bitCount) {
        if (bitCount > 32) {
            throw new IllegalArgumentException("bit count cannot be greater than 32: " + bitCount);
        }
        return (int) readBitsToLong(bitCount);
    }

    /**
     * Reads {@code bitCount} bits into the low bits of a long; the first bit read becomes bit 0.
     *
     * @throws IllegalArgumentException if {@code bitCount > 64}
     */
    @Override
    public long readBitsToLong(int bitCount) {
        if (bitCount == 0) {
            return 0;
        } else if (bitCount > 64) {
            throw new IllegalArgumentException("bit count cannot be greater than 64: " + bitCount);
        }

        long r = 0;
        int currentBit = 0;

        // consume the rest of the current byte (or fewer bits if that is all that is asked for)
        {
            int bitsFromFirstByte = Math.min(bitCount, 8 - readIndexBits);
            r |= ((buffer.getByte(buffer.readerIndex()) & 0xffL) >>> readIndexBits) & ((1L << bitsFromFirstByte) - 1);
            currentBit += bitsFromFirstByte;
            readIndexBits += bitsFromFirstByte;
            if (readIndexBits == Byte.SIZE) {
                readIndexBits = 0;
                buffer.readerIndex(buffer.readerIndex() + 1);
            }
        }

        // read whole bytes; only reached when the first step ended on a byte boundary
        while (bitCount - currentBit >= 8) {
            r |= (buffer.readByte() & 0xffL) << currentBit;
            currentBit += 8;
        }

        // read remaining bits from the low end of the next byte
        {
            int bitsFromLastByte = bitCount - currentBit;

            // `if` is a protection against buffer overrun
            if (bitsFromLastByte != 0) {
                r |= (buffer.getByte(buffer.readerIndex()) & ((1L << bitsFromLastByte) - 1)) << currentBit;
                readIndexBits = bitsFromLastByte;
            }

        }

        return r;
    }

    /** Reads 8 bits; the sub-byte read offset is unchanged afterwards. */
    @Override
    public byte read8Bits() {
        if (readIndexBits == 0) {
            return buffer.readByte();
        }

        byte prev = buffer.readByte();
        int x = Byte.toUnsignedInt(prev) | (Byte.toUnsignedInt(buffer.getByte(buffer.readerIndex())) << 8);
        return (byte) (x >>> (readIndexBits));
    }

    /** Reads 32 bits in little-endian order; the sub-byte read offset is unchanged afterwards. */
    @Override
    public int read32Bits() {
        if (readIndexBits == 0) {
            return buffer.readIntLE();
        }

        int pos = buffer.readerIndex();
        int used = readIndexBits;
        int free = 8 - used;

        // unread high bits of the current byte become the low bits of the result
        int aligned = Byte.toUnsignedInt(buffer.getByte(pos)) >>> used;
        int result = buffer.getIntLE(pos + 1) << free | aligned;
        buffer.readerIndex(pos + Integer.BYTES);
        return result;
    }

    /** Reads 64 bits in little-endian order; the sub-byte read offset is unchanged afterwards. */
    @Override
    public long read64Bits() {
        if (readIndexBits == 0) {
            return buffer.readLongLE();
        }

        int pos = buffer.readerIndex();
        int used = readIndexBits;
        int free = 8 - used;

        // unread high bits of the current byte become the low bits of the result
        long aligned = Byte.toUnsignedLong(buffer.getByte(pos)) >>> used;
        long result = buffer.getLongLE(pos + 1) << free | aligned;
        buffer.readerIndex(pos + Long.BYTES);
        return result;
    }

    /**
     * Reads a unary value written by {@link #writeIntVarint1N(int, int)}: counts one-bits until a
     * zero bit (which is consumed) or until {@code max} ones have been read. A non-positive
     * {@code max} returns 0 without consuming anything.
     *
     * @throws IllegalArgumentException if {@code max > 8}
     */
    @Override
    public int readIntVarint1N(int max) {
        if (max > Byte.SIZE) {
            throw new IllegalArgumentException("max " + max + " > 8");
        }

        int r = 0;
        // readBit() touches only the byte holding the current bit, so a value that ends exactly
        // at the last byte of the buffer never causes a read past the end.
        while (r < max && readBit()) {
            r++;
        }
        return r;
    }

    /**
     * Reads a 7-bits-per-byte varint (high bit = continuation) into an int. Up to 10 bytes are
     * consumed; value bits above bit 31 are discarded by the narrowing {@code |=}.
     *
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public int readIntVarint8() {
        int result = 0;
        for (int shift = 0; shift < 64; shift += 7) {
            final byte b = read8Bits();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw new IllegalStateException("malformed");
    }

    /**
     * Reads a 7-bits-per-byte varint (high bit = continuation) into a long, consuming at most
     * 10 bytes.
     *
     * @throws IllegalStateException if no terminating byte is found within 10 bytes
     */
    @Override
    public long readLongVarint8() {
        long result = 0;
        for (int shift = 0; shift < 64; shift += 7) {
            final byte b = read8Bits();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw new IllegalStateException("malformed");
    }

    /** Resets the buffer's writer index to its marked value and clears the sub-byte offset. */
    @Override
    public void resetWriterIndex() {
        buffer.resetWriterIndex();
        writeIndexBits = 0;
    }

    /** Resets the buffer's reader index to its marked value and clears the sub-byte offset. */
    @Override
    public void resetReadIndex() {
        buffer.resetReaderIndex();
        readIndexBits = 0;
    }

    /** Current write position in bits. */
    @Override
    public long writerIndex() {
        // widen before shifting to avoid int overflow on large buffers
        long bits = (long) buffer.writerIndex() << 3;
        return bits + writeIndexBits;
    }

    /** Sets the write position, in bits. */
    @Override
    public void writerIndex(long writerIndex) {
        buffer.writerIndex(div8(writerIndex));
        writeIndexBits = mod8(writerIndex);
    }

    /** Current read position in bits. */
    @Override
    public long readerIndex() {
        // widen before shifting to avoid int overflow on large buffers
        long bits = (long) buffer.readerIndex() << 3;
        return bits + readIndexBits;
    }

    /** Sets the read position, in bits. */
    @Override
    public void readerIndex(long readerIndex) {
        buffer.readerIndex(div8(readerIndex));
        readIndexBits = mod8(readerIndex);
    }

    /** Advances the read position by {@code length} bits. */
    @Override
    public void skipBits(long length) {
        long bits = readIndexBits + length;
        buffer.skipBytes(div8(bits));
        readIndexBits = mod8(bits);
    }

    /** Returns a read-only view sharing content with this buffer, with the same bit positions. */
    @Override
    public BitBuf asReadOnly() {
        return new NettyBitBuf(buffer.duplicate().asReadOnly(), writeIndexBits, readIndexBits);
    }

    /**
     * Returns a view of {@code length} bits starting at bit {@code index}, sharing content.
     *
     * @throws UnsupportedOperationException if {@code index} is not byte aligned
     */
    @Override
    public BitBuf slice(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        var sliced = buffer.slice(div8(index), BitArray.arrayLengthForBits(length));
        return new NettyBitBuf(sliced, length);
    }

    /** Returns a view sharing content with this buffer, with the same bit positions. */
    @Override
    public BitBuf duplicate() {
        return new NettyBitBuf(buffer.duplicate(), writeIndexBits, readIndexBits);
    }

    /** Copies the readable bits; the read position must be byte aligned. */
    @Override
    public BitBuf copy() {
        return copy(readerIndex(), readableBits());
    }

    /**
     * Copies {@code length} bits starting at bit {@code index} into a new buffer.
     *
     * @throws UnsupportedOperationException if {@code index} is not byte aligned
     */
    @Override
    public BitBuf copy(long index, long length) {
        if (mod8(index) != 0) {
            throw new UnsupportedOperationException("Unsupported slice not alighted to byte pos: " + index);
        }

        return new NettyBitBuf(buffer.copy(div8(index), BitArray.arrayLengthForBits(length)), length);
    }

    /** Returns a fresh copy of bytes {@code [0, bytesSize())}; see {@link #arrayOffset()}. */
    @Override
    public byte[] array() {
        byte[] array = new byte[bytesSize()];
        buffer.getBytes(0, array, 0, array.length);
        return array;
    }

    /** Always 0: {@link #array()} returns a copy starting at byte 0. */
    @Override
    public int arrayOffset() {
        return 0;
    }

    @Override
    public int refCnt() {
        return buffer.refCnt();
    }

    @Override
    public BitBuf retain() {
        buffer.retain();
        return this;
    }

    @Override
    public BitBuf retain(int increment) {
        buffer.retain(increment);
        return this;
    }

    @Override
    public BitBuf touch() {
        buffer.touch();
        return this;
    }

    @Override
    public BitBuf touch(Object hint) {
        buffer.touch(hint);
        return this;
    }

    /**
     * Allocates a new empty buffer of {@code byteCapacity} bytes from the same allocator. The new
     * buffer is direct if this one is direct, heap-based otherwise.
     */
    @Override
    public NettyBitBuf allocate(int byteCapacity) {
        var allocator = buffer.alloc();
        final ByteBuf allocated;
        if (this.buffer.isDirect()) {
            allocated = allocator.directBuffer(byteCapacity);
        } else {
            allocated = allocator.heapBuffer(byteCapacity);
        }
        return new NettyBitBuf(allocated, 0);
    }

    @Override
    public boolean release() {
        return buffer.release();
    }

    @Override
    public boolean release(int decrement) {
        return buffer.release(decrement);
    }

    /** Estimated memory footprint: this object, the ByteBuf object and the buffer capacity. */
    @Override
    public long memorySizeIncludingSelf() {
        return SELF_SIZE + BYTEBUF_SIZE + buffer.capacity();
    }

    /**
     * Compares written content. Against another {@code NettyBitBuf}, the bit positions must match
     * and the bits in {@code [0, writerIndex())} must be equal. The meaningless high bits of a
     * partial last byte are ignored. Against another {@link BitBuf}, the readable bit counts and
     * the byte arrays are compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof BitBuf)) return false;

        if (!(o instanceof NettyBitBuf)) {
            BitBuf that = (BitBuf) o;
            if (readableBits() != that.readableBits()) {
                return false;
            }

            return Arrays.equals(array(), 0, bytesSize(), that.array(), 0, that.bytesSize());
        }

        NettyBitBuf that = (NettyBitBuf) o;
        if (this.readableBits() != that.readableBits()) {
            return false;
        }

        if (writeIndexBits != that.writeIndexBits) {
            return false;
        }

        if (readIndexBits != that.readIndexBits) {
            return false;
        }

        int fullBytes = buffer.writerIndex();
        // Both sides must have the same length. Otherwise the byte comparison below would use
        // this buffer's length for the other one, making equals asymmetric and possibly indexing
        // past the other buffer.
        if (fullBytes != that.buffer.writerIndex()) {
            return false;
        }

        if (fullBytes > 0 && !ByteBufUtil.equals(buffer, 0, that.buffer, 0, fullBytes)) {
            return false;
        }

        if (writeIndexBits != 0) {
            // compare only the written low bits of the partial last byte
            int a = this.buffer.getByte(fullBytes) & ((1 << writeIndexBits) - 1);
            int b = that.buffer.getByte(fullBytes) & ((1 << writeIndexBits) - 1);
            return a == b;
        }

        return true;
    }

    /** Combines {@code writeIndexBits} with the underlying buffer's hash code. */
    @Override
    public int hashCode() {
        int hash = Integer.hashCode(writeIndexBits);
        return 31 * hash + buffer.hashCode();
    }

    /**
     * Debug representation: indices, capacity, readable bit count and the bits of up to 100 bytes
     * starting at the reader index. {@code R}/{@code W} mark the read/write positions.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder()
            .append(StringUtil.simpleClassName(this))
            .append("(ridx: ").append(buffer.readerIndex())
            .append("[").append(readIndexBits).append("]")
            .append(", widx: ").append(buffer.writerIndex())
            .append("[").append(writeIndexBits).append("]")
            .append(", cap: ").append(buffer.capacity());
        if (buffer.maxCapacity() != Integer.MAX_VALUE) {
            sb.append('/').append(buffer.maxCapacity());
        }
        sb.append(", [bits: ").append(readableBits()).append(": ");

        int from = buffer.readerIndex();
        int size = bytesSize();
        // window of at most 100 bytes relative to the reader index (written without overflow)
        int to = from + Math.min(size - from, 100);
        for (int index = from; index < to; index++) {
            byte b = buffer.getByte(index);
            for (int bit = 0; bit < Byte.SIZE; bit++) {
                if (buffer.readerIndex() == index && bit == readIndexBits) {
                    sb.append("R");
                }

                if (buffer.writerIndex() == index && bit == writeIndexBits) {
                    sb.append("W");
                }

                if (BitArray.isBitSet(b, bit)) {
                    sb.append("1");
                } else {
                    sb.append("0");
                }
            }
            sb.append(" ");
        }
        if (size > to) {
            sb.append("...");
        }
        sb.append("])");
        return sb.toString();
    }

    @Override
    public boolean isDirect() {
        return buffer.isDirect();
    }
}
