#pragma once

#include "columns.h"
#include "count_codec.h"
#include "frame_codec.h"
#include "points.h"
#include "time_codec.h"

#include <iterator>

namespace NSolomon::NTs {

/// Width, in bits, of the column index written after each "command present" bit.
/// Part of the wire format: encoder and decoder must agree on it.
/// `inline` (C++17) gives the constant one definition shared by all translation
/// units instead of a separate internal-linkage copy per TU.
inline constexpr size_t ColumnBits = 5;

// The decoder reads the column index into a ui8, so the width must fit in 8 bits.
static_assert(ColumnBits <= 8, "column index must fit into ui8");

/**
 * Base class for timeseries encoders.
 */
template <typename TDerived, typename TValue>
class TBaseTsEncoder {
public:
    TBaseTsEncoder(TColumnSet columns, TBitWriter* writer)
        : Columns_{columns}
        , Writer_{writer}
    {
        FrameEncoder_.InitFrame(Writer_);
    }

    size_t Flush() noexcept {
        Writer_->Flush();
        return Writer_->Pos();
    }

    template <typename TPoint>
    void EncodePoint(const TPoint& point) {
        // (1) time
        TimeEncoder_.Encode(Writer_, point.Time);

        // (2) commands
        EncodeCommands(point);

        // (3) value
        static_cast<TDerived*>(this)->EncodeValue(Writer_, TValueGetter<TPoint, TValue>::Get(point));

        // (4) count
        if (Columns_.IsSet(EColumn::COUNT)) {
            CountEncoder_.Encode(Writer_, point.Count);
        }

        ++FramePointCount_;
    }

    bool CloseFrame() {
        return FrameEncoder_.CloseFrame(Writer_, [this]() {
            WriteState();
        });
    }

    size_t FrameSize() const {
        return FramePointCount_;
    }

private:
    void WriteState() {
        auto prevTime = TimeEncoder_.Prev();
        Writer_->WriteInt64(prevTime.MilliSeconds());

        // XXX: weired format design: this point count splits time encoder state
        Writer_->WriteVarInt32Mode(FramePointCount_);
        TotalPointCount_ += FramePointCount_;
        FramePointCount_ = 0;

        Writer_->WriteBit(TimeEncoder_.Millis());
        Writer_->WriteVarInt64Mode(TimeEncoder_.DeltaMillis());
        TimeEncoder_ = {};

        Writer_->WriteVarInt32Mode(StepMillis_);
        StepMillis_ = 0;

        Writer_->WriteVarInt64Mode(CountEncoder_.Prev());
        CountEncoder_ = {};

        Writer_->WriteBit(Merge_);
        Merge_ = false;

        static_cast<TDerived*>(this)->WriteState(Writer_);
    }

    template <typename TPoint>
    void EncodeCommands(const TPoint& point) {
        static_cast<TDerived*>(this)->EncodeCommand(Writer_, TValueGetter<TPoint, TValue>::Get(point));

        if (Y_UNLIKELY(Columns_.IsSet(EColumn::MERGE) && Merge_ != point.Merge)) {
            Merge_ = point.Merge;
            WriteColumn(EColumn::MERGE);
        }

        if (Y_UNLIKELY(Columns_.IsSet(EColumn::STEP) && StepMillis_ != point.Step.MilliSeconds())) {
            StepMillis_ = point.Step.MilliSeconds();
            WriteColumn(EColumn::STEP);
            Writer_->WriteVarInt32(StepMillis_);
        }

        Writer_->WriteBit(false); // end of commands flag
    }

    void WriteColumn(EColumn c) {
        Writer_->WriteBit(true);
        Writer_->WriteInt32(static_cast<ui32>(c), ColumnBits);
    }

private:
    TColumnSet Columns_;
    TBitWriter* Writer_;

    TFrameEncoder FrameEncoder_;
    TTimeEncoder TimeEncoder_;
    TCountEncoder CountEncoder_;

    ui32 StepMillis_{0};
    bool Merge_{false};

    ui32 FramePointCount_{0};
    ui32 TotalPointCount_{0};
};

/**
 * Base class for time series decoders.
 */
template <typename TDerived, typename TValue>
class TBaseTsDecoder {
public:
    TBaseTsDecoder(TColumnSet columns, TBitSpan data)
        : Columns_{columns}
        , Reader_{data}
    {
        NextFrame();
    }

    template <typename TPoint>
    bool NextPoint(TPoint* point) {
        if (!HasNext()) {
            return false;
        }

        // (1) time
        point->Time = TimeDecoder_.Decode(&Reader_);

        // (2) commands
        DecodeCommands();

        // (3) value
        static_cast<TDerived*>(this)->DecodeValue(&Reader_, TValueGetter<TPoint, TValue>::Mut(point));

        if (Columns_.IsSet(EColumn::MERGE)) {
            point->Merge = Merge_;
        } else {
            point->Merge = false;
        }

        if (Columns_.IsSet(EColumn::STEP)) {
            point->Step = TDuration::MilliSeconds(StepMillis_);
        } else {
            point->Step = TDuration::Zero();
        }

        // (4) count
        if (Columns_.IsSet(EColumn::COUNT)) {
            point->Count = CountDecoder_.Decode(&Reader_);
        } else {
            point->Count = 0;
        }

        return true;
    }

    bool HasNext() const noexcept {
        return Reader_.Left() > 0;
    }

private:
    void DecodeCommands() {
        while (Reader_.ReadBit()) {
            ui8 columnIdx = Reader_.ReadInt8(ColumnBits);
            EColumn column = TColumnSet::AllColumns[columnIdx];

            if (column == EColumn::MERGE) {
                Merge_ = !Merge_;
            } else if (column == EColumn::STEP) {
                auto stepMillis = Reader_.ReadVarInt32();
                Y_ENSURE(stepMillis, "cannot read step millis");
                StepMillis_ = *stepMillis;
            } else {
                static_cast<TDerived*>(this)->DecodeCommand(&Reader_);
            }
        }
    }

    void NextFrame() {
        if (Y_LIKELY(FrameDecoder_.Next(&Reader_))) {
            Reader_.SetPos(FrameDecoder_.PayloadIndex());
        } else {
            Reader_.SetPos(Reader_.Size());
        }
    }

    void EnsureFrameRead() {
        size_t frameEnd = FrameDecoder_.PayloadIndex() + FrameDecoder_.PayloadBits();
        if (Reader_.Pos() >= frameEnd) {
            NextFrame();
            ResetState();
        }
    }

    void ResetState() {
        TimeDecoder_ = {};
        CountDecoder_ = {};
        Merge_ = false;
        StepMillis_ = 0;
        static_cast<TDerived*>(this)->Reset();
    }

private:
    const TColumnSet Columns_;
    TBitReader Reader_;

    TFrameDecoder FrameDecoder_;
    TTimeDecoder TimeDecoder_;
    TCountDecoder CountDecoder_;

    ui32 StepMillis_{0};
    bool Merge_{false};
};

} // namespace NSolomon::NTs
