package ru.yandex.webmaster3.storage.util.yt;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

import io.airlift.compress.snappy.SnappyDecompressor;
import org.apache.commons.io.IOUtils;

import ru.yandex.webmaster3.core.util.ByteStreamUtil;

/**
 * Input stream that decodes a block-based "Snap" container read from the wrapped stream.
 *
 * <p>Layout, as parsed by this class:
 * <ul>
 *     <li>stream header (10 bytes): the ASCII signature {@code "Snap"}, a little-endian 32-bit format
 *     version (must be {@code 1}) and a little-endian unsigned 16-bit block size, which is used as the
 *     capacity of the decoded-data buffer;</li>
 *     <li>a sequence of blocks, each starting with a 3-byte header: little-endian unsigned 16-bit
 *     payload size followed by a one-byte flag ({@code 1} means the payload is Snappy-compressed,
 *     anything else means the payload is stored raw);</li>
 *     <li>the stream ends on a block whose payload size is {@code 0}, or when the wrapped stream is
 *     exhausted exactly at a block boundary.</li>
 * </ul>
 *
 * <p>The header is parsed lazily on the first read. {@code skip} and {@code available} operate on decoded
 * data; mark/reset is not supported. Instances keep mutable decoding state and are not synchronized.
 *
 * @author aherman
 */
class YtSnappyCompressedStream extends FilterInputStream {
    private static final int FORMAT_VERSION = 1;

    private static final int SIZE_OF_UINT32 = 4;
    private static final int SIZE_OF_UINT16 = 2;
    private static final int SIZE_OF_UINT8 = 1;

    /** Exclusive upper bound of an unsigned 16-bit value. */
    private static final int UINT16_LIMIT = 1 << (SIZE_OF_UINT16 * 8);

    private static final int HEADER_SIGNATURE_OFFSET = 0;
    private static final int HEADER_VERSION_OFFSET = HEADER_SIGNATURE_OFFSET + 4;
    private static final int HEADER_BLOCK_SIZE_OFFSET = HEADER_VERSION_OFFSET + SIZE_OF_UINT32;
    private static final int HEADER_SIZE = HEADER_BLOCK_SIZE_OFFSET + SIZE_OF_UINT16;

    private static final int BLOCK_DATA_SIZE_OFFSET = 0;
    private static final int BLOCK_COMPRESSED_OFFSET = BLOCK_DATA_SIZE_OFFSET + SIZE_OF_UINT16;
    private static final int BLOCK_HEADER_SIZE = BLOCK_COMPRESSED_OFFSET + SIZE_OF_UINT8;

    /** Created once and reused for every compressed block. */
    private final SnappyDecompressor decompressor = new SnappyDecompressor();

    /** Block size declared in the stream header; capacity of {@link #buffer}. */
    private int defaultBlockSize;

    /** Index of the next decoded byte to hand out from {@link #buffer}. */
    private int bufferPosition;
    /** Number of valid decoded bytes currently held in {@link #buffer}. */
    private int bufferSize;
    /** Decoded payload of the current block; allocated once the stream header has been read. */
    private byte[] buffer;

    private State state = State.HEADER;

    /**
     * @param parent stream positioned at the beginning of the container header
     */
    public YtSnappyCompressedStream(InputStream parent) {
        super(parent);
    }

    /**
     * Reads one decoded byte.
     *
     * @return the byte as a value in {@code 0..255}, or {@code -1} at the end of the stream
     * @throws IOException if the container is corrupted or the wrapped stream fails
     */
    @Override
    public int read() throws IOException {
        if (!hasMoreData()) {
            return -1;
        }
        return readByte();
    }

    /**
     * Equivalent to {@code read(b, 0, b.length)}.
     */
    @Override
    public int read(byte[] b) throws IOException {
        return read(b, 0, b.length);
    }

    /**
     * Copies decoded bytes into {@code b}. At most the remainder of the current block is returned by
     * a single call, so fewer than {@code len} bytes may be delivered even when more data follows.
     *
     * @return number of bytes copied, {@code 0} when {@code len} is {@code 0},
     *         or {@code -1} at the end of the stream
     * @throws IOException if the container is corrupted or the wrapped stream fails
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        if (len == 0) {
            // InputStream contract: a zero-length request reads nothing and reports 0, even at EOF.
            return 0;
        }
        if (!hasMoreData()) {
            return -1;
        }
        return readBuffer(b, off, len);
    }

    /**
     * Skips decoded bytes. The inherited implementation would skip raw container bytes of the wrapped
     * stream and break block parsing, so skipping is done by advancing through decoded blocks.
     *
     * @return number of decoded bytes actually skipped; less than {@code n} only at the end of the stream
     * @throws IOException if the container is corrupted or the wrapped stream fails
     */
    @Override
    public long skip(long n) throws IOException {
        long skipped = 0;
        while (skipped < n && hasMoreData()) {
            int step = (int) Math.min((long) (bufferSize - bufferPosition), n - skipped);
            bufferPosition += step;
            skipped += step;
        }
        return skipped;
    }

    /**
     * @return number of already decoded bytes that can be read without touching the wrapped stream
     */
    @Override
    public int available() throws IOException {
        return bufferSize - bufferPosition;
    }

    /**
     * @return always {@code false}: decoder state is not restored when the wrapped stream is rewound
     */
    @Override
    public boolean markSupported() {
        return false;
    }

    /**
     * Does nothing, mark/reset is not supported.
     */
    @Override
    public synchronized void mark(int readlimit) {
        // intentionally empty, see markSupported()
    }

    /**
     * @throws IOException always, mark/reset is not supported
     */
    @Override
    public synchronized void reset() throws IOException {
        throw new IOException("mark/reset not supported");
    }

    /**
     * Advances the decoder until decoded bytes are available or the stream has ended: parses the stream
     * header on first use and loads further blocks whenever the current one is fully consumed.
     *
     * @return {@code true} if at least one decoded byte is waiting in the buffer
     */
    private boolean hasMoreData() throws IOException {
        while (state != State.END) {
            if (state == State.HEADER) {
                state = readStreamHeader();
            } else if (bufferPosition < bufferSize) {
                return true;
            } else {
                // Current block exhausted (or decoded to nothing), move on to the next one.
                state = readBlock();
            }
        }
        return false;
    }

    /** Returns the next decoded byte as an unsigned value; the caller guarantees one is available. */
    private int readByte() {
        return buffer[bufferPosition++] & 0xFF;
    }

    /** Copies up to {@code len} bytes from the remainder of the current block into {@code b}. */
    private int readBuffer(byte[] b, int off, int len) {
        int copySize = Math.min(bufferSize - bufferPosition, len);
        System.arraycopy(buffer, bufferPosition, b, off, copySize);
        bufferPosition += copySize;
        return copySize;
    }

    /**
     * Reads and validates the stream header and allocates the decoded-data buffer.
     *
     * @return {@link State#BLOCK}
     * @throws IOException if the header is truncated, the signature or version does not match,
     *                     or the block size is out of the unsigned 16-bit range
     */
    private State readStreamHeader() throws IOException {
        byte[] header = new byte[HEADER_SIZE];
        if (IOUtils.read(in, header) != HEADER_SIZE) {
            throw new IOException("Corrupted stream, unable to read global header");
        }
        validateSignature(header, HEADER_SIGNATURE_OFFSET);
        validateFormatVersion(header, HEADER_VERSION_OFFSET);
        defaultBlockSize = ByteStreamUtil.readUInt16LE(header, HEADER_BLOCK_SIZE_OFFSET);
        if (defaultBlockSize < 0 || defaultBlockSize >= UINT16_LIMIT) {
            throw new IOException("Corrupted stream, illegal block size: " + defaultBlockSize);
        }
        buffer = new byte[defaultBlockSize];
        return State.BLOCK;
    }

    /** Checks that the four bytes at {@code offset} spell {@code "Snap"}. */
    private void validateSignature(byte[] header, int offset) throws IOException {
        boolean matches = header[offset] == 'S'
                && header[offset + 1] == 'n'
                && header[offset + 2] == 'a'
                && header[offset + 3] == 'p';
        if (!matches) {
            throw new IOException("Corrupted stream, signature mismatch");
        }
    }

    /** Checks that the little-endian int at {@code offset} equals the only supported format version. */
    private void validateFormatVersion(byte[] header, int offset) throws IOException {
        int version = ByteStreamUtil.readIntLE(header, offset);
        if (version != FORMAT_VERSION) {
            throw new IOException("Unknown format version: " + version);
        }
    }

    /**
     * Reads the next block into {@link #buffer}, decompressing it when the block is flagged as compressed.
     *
     * @return {@link State#END} if the wrapped stream ended at a block boundary or the block declares
     *         an empty payload, {@link State#BLOCK} otherwise
     * @throws IOException if the block header or payload is truncated, or the declared payload size
     *                     is out of range
     */
    private State readBlock() throws IOException {
        byte[] blockHeader = new byte[BLOCK_HEADER_SIZE];
        int headerSize = IOUtils.read(in, blockHeader);
        if (headerSize == 0) {
            // Wrapped stream ended cleanly between blocks, without an explicit terminating block.
            return State.END;
        }
        if (headerSize != BLOCK_HEADER_SIZE) {
            throw new IOException("Corrupted stream, unable to read block header");
        }
        boolean compressed = ByteStreamUtil.readUInt8(blockHeader, BLOCK_COMPRESSED_OFFSET) == 1;
        int dataSize = ByteStreamUtil.readUInt16LE(blockHeader, BLOCK_DATA_SIZE_OFFSET);
        if (dataSize < 0 || dataSize >= UINT16_LIMIT) {
            throw new IOException("Corrupted stream, illegal block size: " + dataSize);
        }
        if (dataSize == 0) {
            // An empty block terminates the stream.
            return State.END;
        }

        bufferPosition = 0;
        // Nothing is readable until the payload has been fully loaded.
        bufferSize = 0;
        if (compressed) {
            byte[] compressedBuffer = new byte[dataSize];
            int actualSize = IOUtils.read(in, compressedBuffer);
            if (actualSize != dataSize) {
                throw new IOException(
                        "Corrupted stream, unable to read block data: expected=" + dataSize + " actual=" + actualSize);
            }
            bufferSize = decompressor.decompress(compressedBuffer, 0, dataSize, buffer, 0, buffer.length);
        } else {
            if (dataSize > buffer.length) {
                // A raw payload is read straight into the buffer, so it cannot exceed the declared block size.
                throw new IOException(
                        "Corrupted stream, block data size " + dataSize + " exceeds block size " + buffer.length);
            }
            int actualSize = IOUtils.read(in, buffer, 0, dataSize);
            if (actualSize != dataSize) {
                throw new IOException(
                        "Corrupted stream, unable to read block data: expected=" + dataSize + " actual=" + actualSize);
            }
            bufferSize = actualSize;
        }
        return State.BLOCK;
    }

    /** Decoder position within the container. */
    private enum State {
        /** Stream header has not been read yet. */
        HEADER,
        /** Header parsed; blocks are being read. */
        BLOCK,
        /** No more data will be produced. */
        END,
    }
}
