#pragma once

#include "bits.h"
#include "bit_buffer.h"

#include <util/generic/buffer.h>
#include <util/generic/cast.h>

#ifndef _little_endian_
#  error "unsupported platform"
#endif

namespace NSolomon::NTs {

/**
 * Bit-level stream writer. Values written as LE-ordered to a stream.
 */
/**
 * Bit-level stream writer. Values are written to a stream in little-endian order: the lower
 * bits of a value are placed at lower bit positions of the stream.
 *
 * Multi-bit writes may clear bits located after the new position inside the last byte(s) they
 * touch. Those bits lie beyond Pos(), so this is harmless for sequential writing. However,
 * rewinding with SetPos() and writing again can clobber data that followed the rewound position
 * within the touched bytes.
 */
class TBitWriter {
public:
    /**
     * Initialize bit writer.
     *
     * The writer uses, but does not own, the given buffer. It is the responsibility of the caller
     * to make sure the buffer outlives the writer. Writing starts at bit position 0.
     *
     * @param storage   pointer to a buffer
     */
    explicit TBitWriter(TBitBuffer* storage) noexcept
        : Storage_{storage}
        , Pos_{0}
    {
    }

    /**
     * Calls Flush() so the buffer's size reflects everything written through this writer.
     */
    ~TBitWriter() {
        Flush();
    }

    /**
     * Update buffer's size so that it matches the current bit position.
     */
    void Flush() noexcept {
        Storage_->Resize(Pos_);
    }

    /**
     * Write single bit to a stream.
     *
     * @param value  bit value, {@code true} to write {@code 1} and {@code false} to write {@code 0}.
     */
    void WriteBit(bool value) {
        auto [index, used] = SplitBitIndex(Pos_);
        Advance(1);

        // Data() is fetched only after Advance(), which may have resized the storage.
        if (value) {
            Data()[index] |= ui8(1) << used;
        } else {
            Data()[index] &= ~(ui8(1) << used);
        }
    }

    /**
     * Write ui8 value to a stream.
     *
     * @param value  value to write.
     */
    void WriteInt8(ui8 value) {
        auto [index, used] = SplitBitIndex(Pos_);
        Advance(BitsSize<ui8>());

        if (Y_UNLIKELY(used == 0)) {
            // byte-aligned: the value occupies exactly one byte
            Data()[index] = value;
        } else {
            // the value straddles two bytes: keep the `used` bits already written in the first
            // byte and put the value right above them; bits past the value are cleared
            ui16 prev = ReadUnaligned<ui16>(Data() + index);
            ui16 updated = LowerBits(prev, used) | (ui16(value) << used);
            WriteUnaligned<ui16>(Data() + index, updated);
        }
    }

    /**
     * Write ui32 value to a stream.
     *
     * @param value  value to write.
     */
    void WriteInt32(ui32 value) {
        WriteFixed(value);
    }

    /**
     * Write ui64 value to a stream.
     *
     * @param value  value to write.
     */
    void WriteInt64(ui64 value) {
        WriteFixed(value);
    }

    /**
     * Write double value to a stream (its raw 64-bit representation).
     *
     * @param value  value to write.
     */
    void WriteDouble(double value) {
        WriteFixed(BitCast<ui64>(value));
    }

    /**
     * Write maximum {@code bits} lower bits of ui8 value to a stream.
     *
     * @param value  value to write.
     * @param bits   number of lower bits (must be <= 8)
     */
    void WriteInt8(ui8 value, size_t bits) noexcept {
        Y_VERIFY_DEBUG(bits > 0 && bits <= BitsSize<ui8>());

        auto [index, used] = SplitBitIndex(Pos_);
        Advance(bits);

        // drop any bits of the value above the requested width
        value = LowerBits(value, bits);

        if (used + bits <= BitsSize<ui8>()) {
            // fits into the current byte
            Data()[index] = LowerBits(Data()[index], used) | (value << used);
        } else {
            // spills into the next byte
            ui16 prev = ReadUnaligned<ui16>(Data() + index);
            ui16 updated = LowerBits(prev, used) | (ui16(value) << used);
            WriteUnaligned<ui16>(Data() + index, updated);
        }
    }

    /**
     * Write maximum {@code bits} lower bits of ui32 value to a stream.
     *
     * @param value  value to write.
     * @param bits   number of lower bits (must be <= 32)
     */
    void WriteInt32(ui32 value, size_t bits) noexcept {
        WriteBits(value, bits);
    }

    /**
     * Write maximum {@code bits} lower bits of ui64 value to a stream.
     *
     * @param value  value to write.
     * @param bits   number of lower bits (must be <= 64)
     */
    void WriteInt64(ui64 value, size_t bits) noexcept {
        WriteBits(value, bits);
    }

    /**
     * Write ui32 value in LEB128 (@see https://en.wikipedia.org/wiki/LEB128) format to a stream.
     *
     * @param value  value to write.
     */
    void WriteVarInt32(ui32 value) noexcept;

    /**
     * Write ui64 value in LEB128 (@see https://en.wikipedia.org/wiki/LEB128) format to a stream.
     *
     * @param value  value to write.
     */
    void WriteVarInt64(ui64 value) noexcept;

    /**
     * Write ui32 value as variable length number in prefixed format.
     *
     * @param value  value to write.
     */
    void WriteVarInt32Mode(ui32 value);

    /**
     * Write ui64 value as variable length number in prefixed format.
     *
     * @param value  value to write.
     */
    void WriteVarInt64Mode(ui64 value);

    /**
     * Write continuous sequence of {@code n} 1 bits followed by 0 bit.
     * If {@code n} == {@code max} then 0 bit is not written.
     *
     * @param n    number of 1 bits to write (must be <= max)
     * @param max  maximum number of bits to write (must be <= 8).
     */
    void WriteOnes(ui8 n, ui8 max) noexcept {
        Y_VERIFY_DEBUG(n <= max);
        Y_VERIFY_DEBUG(max <= BitsSize<ui8>());

        auto [index, used] = SplitBitIndex(Pos_);

        // n ones plus the terminating zero bit (only when n < max)
        const size_t total = n + static_cast<size_t>(n != max);
        Advance(total);

        // n low bits set; every bit above them (including the terminator) is zero
        ui8 mask = ~(Max<ui8>() << n);

        // The terminator must be part of the range check: when used + n == 8 it belongs to
        // the next byte, which only the 16-bit branch writes.
        if (used + total <= BitsSize<ui8>()) {
            Data()[index] = LowerBits(Data()[index], used) | (mask << used);
        } else {
            ui16 prev = ReadUnaligned<ui16>(Data() + index);
            ui16 updated = LowerBits(prev, used) | (ui16(mask) << used);
            WriteUnaligned<ui16>(Data() + index, updated);
        }
    }

    /**
     * @return current bit position.
     */
    size_t Pos() const noexcept {
        return Pos_;
    }

    /**
     * Update current bit position.
     *
     * @param pos  new position (must be <= underlying storage size).
     */
    void SetPos(size_t pos) noexcept {
        Y_VERIFY_DEBUG(pos <= Storage_->Size());
        Pos_ = pos;
    }

    /**
     * Align current position up to a byte boundary and resize the buffer to it.
     */
    void AlignToByte() noexcept {
        Pos_ = NTs::AlignToByte(Pos_);
        Storage_->Resize(Pos_);
    }

private:
    /**
     * Write all bits of a fixed-width unsigned value at the current position.
     */
    template <typename T>
    void WriteFixed(T value) {
        auto [index, used] = SplitBitIndex(Pos_);
        Advance(BitsSize<T>());

        if (Y_UNLIKELY(used == 0)) {
            WriteUnaligned<T>(Data() + index, value);
        } else {
            // low part: sizeof(T) bytes, preserving the `used` bits already written
            T prev = ReadUnaligned<T>(Data() + index);
            T updated = LowerBits(prev, used) | (value << used);
            WriteUnaligned<T>(Data() + index, updated);

            // high part: the top `used` bits of the value go into one extra byte; the bits of
            // that byte past the new position are cleared, which is fine as they lie beyond Pos()
            Data()[index + sizeof(value)] = value >> (BitsSize<T>() - used);
        }
    }

    /**
     * Write the lower {@code bits} bits of {@code value} at the current position.
     *
     * Every byte spanned by the written bits is overwritten, so the result does not depend on
     * what the buffer held before (e.g. after rewinding with SetPos()).
     */
    template <typename T>
    void WriteBits(T value, size_t bits) {
        Y_VERIFY_DEBUG(bits > 0 && bits <= BitsSize<T>());

        auto [index, used] = SplitBitIndex(Pos_);
        Advance(bits);

        // LowerBits is skipped for a full-width write, where masking is a no-op anyway
        if (Y_LIKELY(bits < BitsSize<T>())) {
            value = LowerBits(value, bits);
        }

        // index of the byte holding the last written bit
        const size_t last = index + (used + bits - 1) / 8;

        // first byte: keep the `used` bits already written, clear everything above the value
        ui8 prev = Data()[index];
        Data()[index] = LowerBits(prev, used) | (ui8(value) << used);

        // remaining bytes: written even when the rest of the value is zero, so stale bits in
        // the spanned range cannot leak into the stream
        T rest = value >> (8 - used);
        while (index < last) {
            Data()[++index] = ui8(rest);
            rest >>= 8;
        }
    }

    ui8* Data() noexcept {
        return Storage_->Data();
    }

    /**
     * Move the position forward, growing the storage when the new position reaches its size.
     */
    void Advance(size_t bitsToAdd) {
        Pos_ += bitsToAdd;
        if (Pos_ >= Storage_->Size()) {
            Storage_->Resize(Pos_);
        }
    }

private:
    TBitBuffer* Storage_;
    size_t Pos_;
};

} // namespace NSolomon::NTs
