package ru.yandex.solomon.codec.compress;

import java.util.NoSuchElementException;
import java.util.function.Function;

import ru.yandex.solomon.codec.bits.BitArray;
import ru.yandex.solomon.codec.bits.BitBuf;

import static ru.yandex.solomon.codec.compress.FrameEncoder.FRAME_FOOTER_SIZE;
import static ru.yandex.solomon.codec.compress.FrameEncoder.FRAME_HEADER_SIZE;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Iterates over frames stored back to back in a {@link BitBuf}.
 *
 * <p>All positions handled here are bit indices in the underlying buffer. A frame starts with a
 * header occupying {@code FRAME_HEADER_SIZE} bits: a 32-bit payload length in bytes followed by
 * an 8-bit value in {@code [0, 8]} that is combined with the byte count via
 * {@code BitArray.toBits} (presumably the number of used bits in the last payload byte). When the
 * payload length is positive, the payload is followed by a footer: a 32-bit footer length in
 * bytes occupying {@code FRAME_FOOTER_SIZE} bits, then the footer content. A payload length of
 * zero denotes a frame without a footer whose payload is every bit still readable after the
 * header; {@link #next()} never advances past such a frame.
 *
 * <p>The iterator moves the buffer's reader index as it parses headers and footers.
 *
 * @author Vladimir Gordiychuk
 */
public class FrameIterator {
    /** Sentinel stored in the position/size fields while the value is absent. */
    private static final long NONE = -1;

    private final BitBuf buffer;

    /** Bit index of the current frame's header, or {@link #NONE} before the first frame. */
    private long headerIdx = NONE;
    /** Payload size of the current frame, in bits. */
    private long payloadBits = NONE;
    /** Bit index of the current frame's footer-size field, or {@link #NONE} without a footer. */
    private long footerIdx = NONE;
    /** Size of the current frame's footer content in bits, or {@link #NONE} without a footer. */
    private long footerBits = NONE;

    /**
     * Creates an iterator over the frames in {@code buffer}. The first call to {@link #next()}
     * starts at the buffer's reader index at the time of that call.
     *
     * @param buffer buffer holding the encoded frames
     */
    public FrameIterator(BitBuf buffer) {
        this.buffer = buffer;
    }

    private boolean hasNext() {
        if (headerIdx == NONE) {
            // Nothing visited yet: any readable bits begin the first frame.
            return buffer.readableBits() > 0;
        }
        // A footer-less frame is terminal; otherwise another frame follows if data remains past it.
        return footerIdx != NONE && buffer.writerIndex() > frameEndIndex();
    }

    /**
     * Advances to the next frame and loads its metadata.
     *
     * @return {@code true} if a frame was loaded, {@code false} if there is no further frame (in
     *     which case the iterator's fields are left untouched)
     * @throws IllegalStateException if the frame header or the footer-size field is corrupted
     */
    public boolean next() {
        if (!hasNext()) {
            return false;
        }

        headerIdx = (headerIdx == NONE) ? buffer.readerIndex() : frameEndIndex();
        readFrameMeta();
        return true;
    }

    /** Parses the header at {@code headerIdx} and, for frames with a footer, the footer size. */
    private void readFrameMeta() {
        buffer.readerIndex(headerIdx);
        int payloadBytes = buffer.read32Bits();
        if (payloadBytes < 0) {
            throw corrupted();
        }

        int bitsInLastByte = buffer.read8Bits();
        if (bitsInLastByte < 0 || bitsInLastByte > Byte.SIZE) {
            throw corrupted();
        }

        if (payloadBytes == 0) {
            // Footer-less frame: the payload is everything still readable after the header.
            payloadBits = buffer.readableBits();
            footerIdx = NONE;
            footerBits = NONE;
            return;
        }

        payloadBits = BitArray.toBits(payloadBytes, bitsInLastByte);
        footerIdx = headerIdx + FRAME_HEADER_SIZE + (long) payloadBytes * Byte.SIZE;
        footerBits = readFooterBits();
    }

    /**
     * Positions the iterator on the frame whose header starts at {@code headerIdx} and loads its
     * metadata; a subsequent {@link #next()} advances to the frame that follows it.
     *
     * @param headerIdx bit index of a frame header in the buffer
     * @throws IllegalStateException if the frame header or the footer-size field is corrupted
     */
    public void moveTo(long headerIdx) {
        this.headerIdx = headerIdx;
        readFrameMeta();
    }

    /** Reads the 32-bit footer length (in bytes) stored at {@code footerIdx}; returns it in bits. */
    private long readFooterBits() {
        buffer.readerIndex(footerIdx);
        int footerBytes = buffer.read32Bits();
        if (footerBytes < 0) {
            throw corrupted();
        }
        return (long) footerBytes * Byte.SIZE;
    }

    /** Builds the exception reported for a malformed header or footer-size field. */
    private IllegalStateException corrupted() {
        return new IllegalStateException("frame " + headerIdx + " corrupted: " + buffer);
    }

    /** Bit index just past the footer content, where the next frame's header begins. */
    private long frameEndIndex() {
        return footerIdx + FRAME_FOOTER_SIZE + footerBits;
    }

    /** Bit index of the first footer content bit, just after the footer-size field. */
    private long footerContentIndex() {
        return footerIdx + FRAME_FOOTER_SIZE;
    }

    /**
     * Returns whether the current frame carries a footer.
     *
     * @return {@code true} if the current frame has a footer
     */
    public boolean hasFooter() {
        return footerIdx != NONE;
    }

    /**
     * Returns the bit index of the current frame's header.
     *
     * @return header bit index
     */
    public long headerIndex() {
        return headerIdx;
    }

    /**
     * Returns the bit index where the current frame's payload starts.
     *
     * @return payload bit index
     */
    public long payloadIndex() {
        return headerIdx + FRAME_HEADER_SIZE;
    }

    /**
     * Returns the bit index of the first footer content bit (after the footer-size field). The
     * value is only meaningful when {@link #hasFooter()} is {@code true}.
     *
     * @return footer content bit index
     */
    public long footerIndex() {
        return footerContentIndex();
    }

    /**
     * Returns the payload size of the current frame in bits.
     *
     * @return payload size in bits
     */
    public long payloadBits() {
        return payloadBits;
    }

    /**
     * Returns the total size of the current frame in bits: header and payload, plus the
     * footer-size field and footer content when a footer is present.
     *
     * @return frame size in bits
     */
    public long frameSize() {
        return footerIdx != NONE
                ? frameEndIndex() - headerIdx
                : FRAME_HEADER_SIZE + payloadBits;
    }

    /**
     * Positions the buffer at the start of the current frame's footer content and lets {@code fn}
     * decode it. The buffer's reader index is left wherever {@code fn} stopped.
     *
     * @param fn decoder applied to the buffer
     * @param <T> type produced by {@code fn}
     * @return the value returned by {@code fn}
     * @throws NoSuchElementException if the current frame has no footer
     * @throws IllegalStateException if {@code fn} consumed more bits than the footer holds
     */
    public <T> T readFooter(Function<BitBuf, T> fn) {
        if (footerIdx == NONE) {
            throw new NoSuchElementException();
        }

        buffer.readerIndex(footerContentIndex());
        T result = fn.apply(buffer);
        long consumedBits = buffer.readerIndex() - footerContentIndex();
        if (consumedBits > footerBits) {
            throw new IllegalStateException("Consumer read from footer " + consumedBits + " where footer size " + footerBits);
        }
        return result;
    }
}
