package ru.yandex.solomon.codec.compress;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.protobuf.CodedOutputStream;

import ru.yandex.solomon.codec.bits.BitArray;
import ru.yandex.solomon.codec.bits.BitBuf;
import ru.yandex.solomon.model.point.AggrPointData;
import ru.yandex.solomon.model.point.column.CountColumn;
import ru.yandex.solomon.model.point.column.HasColumnSet;
import ru.yandex.solomon.model.point.column.StepColumn;
import ru.yandex.solomon.model.point.column.StockpileColumn;
import ru.yandex.solomon.model.point.column.TsColumn;

/**
 * Base implementation of {@link TimeSeriesOutputStream} that writes points into a {@link BitBuf}
 * as a sequence of frames.
 *
 * <p>This class encodes the parts of a point that do not depend on the value kind: the timestamp
 * (delta-of-delta), the optional merge/step commands and the optional count column. Encoding of
 * the value itself, and of any extra per-frame state, is delegated to subclasses through the
 * abstract methods.
 *
 * <p>When a frame is finished, the encoder state is dumped into the frame footer and reset, see
 * {@link #dumpAndResetState(BitBuf)}; {@link #restoreState(BitBuf)} reads it back in the same order.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public abstract class AbstractTimeSeriesOutputStream implements TimeSeriesOutputStream {
    /** {@link #closeFrame()} refuses to close a frame that occupies fewer bytes than this. */
    private static final int MIN_FRAME_BYTES_SIZE = 1 << 10; // 1 KiB

    /** Destination buffer; replaced by a copy in {@link #continueLastClosedFrame(long)}. */
    private BitBuf out;

    /** Timestamp of the last point written into the current frame, {@code 0} after a frame reset. */
    private long prevTsMillis;
    /** Difference between the last two timestamps written into the current frame. */
    private long prevTsMillisDelta;
    /**
     * {@code true} once the current frame stores delta-of-delta in milliseconds; while
     * {@code false} the delta-of-delta is stored divided by 1000.
     */
    private boolean millis;

    // optional columns

    private long prevStepMillis;
    private long prevCount;
    private boolean prevMerge;

    // meta data
    /** Index returned by {@link FrameEncoder#initFrame(BitBuf)} for the currently open frame. */
    private long frameIdx;
    /** Number of points written into the currently open frame. */
    private int frameRecordCount = 0;
    /** Number of points in already closed frames (the open frame is not included). */
    private int recordCount = 0;

    /**
     * Copy constructor: takes {@code copy.out.copy()} as the buffer and copies every piece of
     * encoder state of this base class from {@code copy}.
     *
     * @param copy stream to copy buffer and state from
     */
    public AbstractTimeSeriesOutputStream(AbstractTimeSeriesOutputStream copy) {
        this.out = copy.out.copy();
        this.prevTsMillis = copy.prevTsMillis;
        this.prevTsMillisDelta = copy.prevTsMillisDelta;
        this.millis = copy.millis;
        this.prevStepMillis = copy.prevStepMillis;
        this.prevCount = copy.prevCount;
        this.prevMerge = copy.prevMerge;
        this.frameIdx = copy.frameIdx;
        this.frameRecordCount = copy.frameRecordCount;
        this.recordCount = copy.recordCount;
    }

    /**
     * Creates a stream that writes into {@code out} and immediately opens a new frame in it.
     *
     * @param out     buffer to write into
     * @param records initial value of the closed-frames record counter reported through
     *                {@link #recordCount()}
     */
    public AbstractTimeSeriesOutputStream(BitBuf out, int records) {
        this.out = out;
        this.recordCount = records;
        initFrame();
    }

    /**
     * Writes the command section of a point: the value command of the subclass, then the merge
     * and step commands for the columns present in {@code columnSet}, then the end-of-commands
     * marker.
     */
    private void encodeCommands(int columnSet, AggrPointData point) {
        writeValueCommand(out, columnSet, point);

        if (HasColumnSet.hasColumn(columnSet, StockpileColumn.MERGE)) {
            writeMergeCommand(point.merge);
        }

        if (HasColumnSet.hasColumn(columnSet, StockpileColumn.STEP)) {
            writeStepMillisCommand(point.stepMillis);
        }

        CommandEncoder.encodeEndOfCommands(out);
    }

    /**
     * Called from {@link #writePoint(int, AggrPointData)} at the start of the command section,
     * before the merge/step commands and the end-of-commands marker are written.
     */
    protected abstract void writeValueCommand(BitBuf stream, int columnSet, AggrPointData point);

    /**
     * Called from {@link #writePoint(int, AggrPointData)} after the command section and before
     * the optional count column.
     */
    protected abstract void writeValue(BitBuf stream, int columnSet, AggrPointData point);

    /**
     * Emits a merge command only when the flag differs from the previous point; the command
     * carries no payload, so it acts as a toggle of the remembered value.
     */
    private void writeMergeCommand(boolean merge) {
        if (merge == prevMerge) {
            return;
        }

        prevMerge = merge;
        CommandEncoder.encodeCommandPrefix(out, StockpileColumn.MERGE);
    }

    /**
     * Emits a step command followed by the new step only when the step differs from the
     * previous point.
     *
     * @throws ArithmeticException if {@code stepMillis} does not fit into an {@code int}
     */
    private void writeStepMillisCommand(long stepMillis) {
        if (prevStepMillis == stepMillis) {
            return;
        }

        CommandEncoder.encodeCommandPrefix(out, StockpileColumn.STEP);
        out.writeIntVarint8(Math.toIntExact(stepMillis));
        prevStepMillis = stepMillis;
    }

    /**
     * Writes the timestamp of a point.
     *
     * <p>The first two timestamps of a frame are written as raw 64-bit values. Every following
     * timestamp is written as a delta-of-delta: {@code 0} when the delta did not change,
     * otherwise the zig-zag encoded difference between the new and the previous delta. Until
     * a delta-of-delta that is not a multiple of 1000 is seen, the value is stored divided by
     * 1000; the first such value writes a {@code 0xff} marker byte and switches the rest of the
     * frame to millisecond resolution.
     */
    private void writeTsMillis(long tsMillis) {
        TsColumn.validateOrThrow(tsMillis);
        if (frameRecordCount < 2) {
            // first two point encode as is, because we can't calculate delta of delta right now
            out.write64Bits(tsMillis);

            // prevTsMillis is 0 right after a frame reset, i.e. there is no previous point yet
            if (prevTsMillis != 0) {
                prevTsMillisDelta = tsMillis - prevTsMillis;
            }
            prevTsMillis = tsMillis;
            return;
        }

        long delta = tsMillis - prevTsMillis;
        // fast path
        if (delta == prevTsMillisDelta) {
            VarintEncoder.writeVarintMode64(out, 0);
            prevTsMillis = tsMillis;
            return;
        }

        // slow path
        long deltaOfDelta = delta - prevTsMillisDelta;
        if (!millis) {
            if (deltaOfDelta % 1000 != 0) {
                // marker byte: from here on delta-of-delta is stored in milliseconds
                out.write8Bits(0xff);
                millis = true;
                // could probably store more information here
            }
        }

        long ddw = millis ? deltaOfDelta : deltaOfDelta / 1000;
        long ddz = CodedOutputStream.encodeZigZag64(ddw);
        VarintEncoder.writeVarintMode64(out, ddz);

        prevTsMillisDelta = delta;
        prevTsMillis = tsMillis;
    }

    /**
     * Writes the count column: the first point of a frame stores the count itself, every
     * following point stores the zig-zag encoded difference to the previous count.
     */
    private void writeCountInt64(long count) {
        if (frameRecordCount() == 0) {
            out.writeLongVarint8(count);
        } else {
            long delta = CodedOutputStream.encodeZigZag64(count - prevCount);
            VarintEncoder.writeVarintMode64(out, delta);
        }

        prevCount = count;
    }

    /**
     * Appends one point to the current frame in the order: timestamp, commands, value, and the
     * count column when it is present in {@code columnSet}.
     *
     * @param columnSet mask of columns present in {@code point}
     * @param point     point to append
     */
    @Override
    public final void writePoint(int columnSet, AggrPointData point) {
        writeTsMillis(point.tsMillis);
        encodeCommands(columnSet, point);
        writeValue(out, columnSet, point);

        if (HasColumnSet.hasColumn(columnSet, StockpileColumn.COUNT)) {
            writeCountInt64(point.count);
        }

        // incremented last: the writers above use the counter to detect the first points of a frame
        frameRecordCount++;
    }

    /** @return number of points in closed frames plus points in the currently open frame */
    @Override
    public int recordCount() {
        return recordCount + frameRecordCount;
    }

    /** Opens a new frame at the current position of {@link #out} and remembers its index. */
    private void initFrame() {
        frameIdx = FrameEncoder.initFrame(out);
    }

    /**
     * Closes the current frame and opens a new one, but only when the current frame has reached
     * {@link #MIN_FRAME_BYTES_SIZE} bytes.
     *
     * @return {@code true} if the frame was closed, {@code false} if it is still too small
     */
    @Override
    public boolean closeFrame() {
        if (frameBytesCount() < MIN_FRAME_BYTES_SIZE) {
            return false;
        }

        FrameEncoder.finishFrame(out, frameIdx, this::dumpAndResetState);
        frameIdx = FrameEncoder.initFrame(out);
        return true;
    }

    /**
     * Closes the current frame regardless of its size and opens a new one; does nothing when
     * the current frame contains no points.
     */
    @Override
    public void forceCloseFrame() {
        if (frameRecordCount == 0) {
            return;
        }

        FrameEncoder.finishFrame(out, frameIdx, this::dumpAndResetState);
        frameIdx = FrameEncoder.initFrame(out);
    }

    /** @return index of the currently open frame */
    @Override
    public long getLatestFrameIdx() {
        return frameIdx;
    }

    /**
     * @return timestamp of the last point written into the current frame, or {@code 0} when
     *         no point was written since the last frame reset
     */
    @Override
    public long getLastTsMillis() {
        return prevTsMillis;
    }

    /** @return index of the currently open frame, same value as {@link #getLatestFrameIdx()} */
    @Override
    public long getLastFrameIdx() {
        return frameIdx;
    }

    /** @return number of points written into the currently open frame */
    @Override
    public int frameRecordCount() {
        return frameRecordCount;
    }

    /**
     * @return size of the currently open frame in bytes; the frame header is not counted while
     *         the frame has no points
     */
    @Override
    public int frameBytesCount() {
        long bits = out.writerIndex() - frameIdx;
        if (frameRecordCount == 0) {
            bits -= FrameEncoder.FRAME_HEADER_SIZE;
        }

        return BitArray.arrayLengthForBits(bits);
    }

    /**
     * @return read-only view of the written data; when the open frame has no points the view
     *         ends at the start of that frame, so the empty frame is not exposed
     */
    @Override
    public BitBuf getCompressedData() {
        if (frameRecordCount == 0) {
            return out.slice(0, frameIdx).asReadOnly();
        }
        return out.asReadOnly();
    }

    /** Currently a no-op: pre-sizing the buffer from an estimated point size is disabled. */
    @Override
    public void ensureCapacity(int columnSet, int capacity) {
        // temp disable init stream with pre estimate capacity
//        int requiredSizeInBytes = CompressStreamFactory.estimatePointSize(columnSet) * capacity;
//        out.reserveAdditionalBytes(requiredSizeInBytes);
    }

    /** Forwards the requested byte capacity to the underlying buffer. */
    public void ensureCapacity(int bytes) {
        out.ensureBytesCapacity(bytes);
    }

    /** @return memory size of the stream object itself, without the underlying buffer */
    protected abstract long memorySelfSize();

    /** @return {@link #memorySelfSize()} plus the memory size reported by the underlying buffer */
    @Override
    public long memorySizeIncludingSelf() {
        return memorySelfSize() + out.memorySizeIncludingSelf();
    }

    /** Releases the underlying buffer. */
    @Override
    public void close() {
        out.release();
    }

    /**
     * @return size of the written data in bytes; when the open frame has no points only the
     *         bytes before that frame are counted
     */
    @Override
    public int bytesCount() {
        if (frameRecordCount == 0) {
            return BitArray.arrayLengthForBits(frameIdx);
        }

        return out.bytesSize();
    }

    /**
     * Writes the encoder state into the frame footer and resets it for the next frame.
     *
     * <p>The fields are written in exactly the order {@link #restoreState(BitBuf)} reads them:
     * last timestamp, frame record count, millis flag, last timestamp delta, last step, last
     * count, last merge flag, followed by the subclass state. The frame record count is added
     * to the closed-frames counter before it is reset.
     *
     * @param buffer buffer that receives the footer
     */
    private void dumpAndResetState(BitBuf buffer) {
        buffer.write64Bits(prevTsMillis);
        prevTsMillis = 0;

        // Goes to the supplied buffer like every other footer field: restoreState reads the
        // record count from the same buffer right after the timestamp.
        VarintEncoder.writeVarintMode32(buffer, frameRecordCount);
        recordCount += frameRecordCount;
        frameRecordCount = 0;

        buffer.writeBit(millis);
        millis = false;

        VarintEncoder.writeVarintMode64(buffer, prevTsMillisDelta);
        prevTsMillisDelta = 0;

        VarintEncoder.writeVarintMode32(buffer, Math.toIntExact(prevStepMillis));
        prevStepMillis = 0;

        VarintEncoder.writeVarintMode64(buffer, prevCount);
        prevCount = 0;

        buffer.writeBit(prevMerge);
        prevMerge = false;

        dumpAndResetAdditionalState(buffer);
    }

    /**
     * Writes subclass specific state into the frame footer after the state of this class and
     * resets it; must mirror {@link #restoreAdditionalState(BitBuf)}.
     */
    protected abstract void dumpAndResetAdditionalState(BitBuf buffer);

    /**
     * Reopens the closed frame at {@code frameIdx} so that new points are appended to it.
     *
     * <p>The encoder state is restored from the footer of that frame, the record count of the
     * frame is moved back from the closed-frames counter to the open frame, and the buffer is
     * replaced by a copy that ends at the end of the frame payload.
     *
     * @param frameIdx index of the frame to continue; must be the last frame in the data
     * @throws IllegalStateException if another frame follows the one at {@code frameIdx}
     */
    void continueLastClosedFrame(long frameIdx) {
        var source = getCompressedData();
        var buffer = source.asReadOnly();
        var frameIterator = new FrameIterator(buffer);
        frameIterator.moveTo(frameIdx);
        buffer.readerIndex(frameIterator.footerIndex());
        restoreState(buffer);
        // restoreState set frameRecordCount to the size of the reopened frame, which was added
        // to recordCount when the frame was closed
        this.recordCount -= this.frameRecordCount;
        // keep everything up to the end of the payload, dropping the footer of the frame
        this.out = source.slice(source.readerIndex(), frameIterator.payloadIndex() + frameIterator.payloadBits()).copy();
        FrameEncoder.continueFrame(this.out, frameIdx);
        this.frameIdx = frameIdx;

        if (frameIterator.next()) {
            throw new IllegalStateException("Frame idx " + frameIdx + " not valid, has more at " + frameIterator.headerIndex());
        }
        // NOTE(review): source is not released when the exception above is thrown - confirm
        // whether callers are expected to close the stream in that case
        source.release();
    }

    /**
     * Reads the encoder state from a frame footer in the order written by
     * {@link #dumpAndResetState(BitBuf)}, validating timestamp, step and count on the way.
     *
     * @param buffer buffer positioned at the start of the frame footer
     */
    private void restoreState(BitBuf buffer) {
        prevTsMillis = buffer.read64Bits();
        TsColumn.validateOrThrow(prevTsMillis);
        frameRecordCount = VarintEncoder.readVarintMode32(buffer);
        millis = buffer.readBit();
        prevTsMillisDelta = VarintEncoder.readVarintMode64(buffer);
        prevStepMillis = VarintEncoder.readVarintMode32(buffer);
        StepColumn.validateOrThrow(prevStepMillis);
        prevCount = VarintEncoder.readVarintMode64(buffer);
        CountColumn.validateOrThrow(prevCount);
        prevMerge = buffer.readBit();

        restoreAdditionalState(buffer);
    }

    /**
     * Reads subclass specific state from the frame footer after the state of this class; must
     * mirror {@link #dumpAndResetAdditionalState(BitBuf)}.
     */
    protected abstract void restoreAdditionalState(BitBuf buffer);
}
