#pragma once

#include "bits.h"
#include "bit_span.h"

#include <util/generic/buffer.h>
#include <util/generic/cast.h>

#include <cstddef>
#include <cstring>
#include <optional>

#ifndef _little_endian_
#  error "unsupported platform"
#endif

namespace NSolomon::NTs {

/**
 * Bit-level stream reader. Values read from the stream are treated as LE-ordered.
 *
 * NOTE: For performance reasons none of the inline read methods check for out-of-range reads
 *       in release builds (only debug-build assertions are present). Before performing a read
 *       the caller must make sure that enough bits are left, e.g. via {@link Left()}.
 */
class TBitReader {
public:
    /**
     * Initialize bit reader. The reader does not own the memory block, it only keeps a pointer to it.
     * IMPORTANT: if sizeBits is not a multiple of 8, then the msb bits in the last byte must be cleared
     *
     * @param data      pointer to memory block start
     * @param sizeBits  block size in bits
     */
    TBitReader(const void* data, size_t sizeBits) noexcept
        : Data_{static_cast<const ui8*>(data)}  // void* -> object pointer needs only static_cast
        , Size_{sizeBits}
        , Pos_{0}
    {
    }

    /**
     * Initialize bit reader from a bit span (uses span.Data() and span.Size()).
     */
    TBitReader(TBitSpan span) noexcept
        : TBitReader(span.Data(), span.Size())
    {
    }

    /**
     * Read single bit and advance the current position forward by 1 bit.
     *
     * @return {@code true} iff bit was {@code 1}, false otherwise
     */
    bool ReadBit() noexcept {
        Y_VERIFY_DEBUG(Left() >= 1);
        // index - byte offset, used - bit offset inside that byte
        auto [index, used] = SplitBitIndex(Pos_++);
        return (Data_[index] & (1u << used)) != 0;
    }

    /**
     * Read ui8 value and advance the current position forward by 8 bits.
     *
     * @return ui8 value
     */
    ui8 ReadInt8() noexcept {
        Y_VERIFY_DEBUG(Left() >= BitsSize<ui8>());

        auto [index, used] = SplitBitIndex(Pos_);
        Pos_ += BitsSize<ui8>();

        // byte-aligned position: the value is exactly one byte of the buffer
        if (Y_UNLIKELY(used == 0)) {
            return Data_[index];
        }

        // unaligned position: the value spans two adjacent bytes
        ui16 x = ReadUnaligned<ui16>(Data_ + index);
        return static_cast<ui8>(x >> used);
    }

    /**
     * Read ui32 value and advance the current position forward by 32 bits.
     *
     * @return ui32 value
     */
    ui32 ReadInt32() noexcept {
        Y_VERIFY_DEBUG(Left() >= BitsSize<ui32>());

        auto [index, used] = SplitBitIndex(Pos_);
        Pos_ += BitsSize<ui32>();

        if (Y_UNLIKELY(used == 0)) {
            return ReadUnaligned<ui32>(Data_ + index);
        }

        // unaligned position: the value spans 5 bytes, load them into a wider integer and shift
        ui64 x = 0;
        std::memcpy(&x, Data_ + index, sizeof(ui32) + 1);
        return static_cast<ui32>(x >> used);
    }

    /**
     * Read ui64 value and advance the current position forward by 64 bits.
     *
     * @return ui64 value
     */
    ui64 ReadInt64() noexcept {
        Y_VERIFY_DEBUG(Left() >= BitsSize<ui64>());

        auto [index, used] = SplitBitIndex(Pos_);
        Pos_ += BitsSize<ui64>();

        if (Y_UNLIKELY(used == 0)) {
            return ReadUnaligned<ui64>(Data_ + index);
        }

        // unaligned position: the value spans 9 bytes, so two words are needed;
        // used is in [1, 7] here, hence the left shift below is always < 64
        ui64 x[2] = {0, 0};
        std::memcpy(&x, Data_ + index, sizeof(ui64) + 1);
        return (x[0] >> used) | (x[1] << (BitsSize<ui64>() - used));
    }

    /**
     * Read next 8 bytes as double and advance the current position forward by 64 bits.
     * The value is the bitwise reinterpretation of the result of {@link ReadInt64()}.
     *
     * @return double value
     */
    double ReadDouble() noexcept {
        return BitCast<double>(ReadInt64());
    }

    /**
     * Read {@code bits} bits into ui8 value and advance the current position forward by {@code bits}.
     *
     * @param bits  number of bits to read (must be <= 8)
     * @return ui8 value
     */
    ui8 ReadInt8(size_t bits) noexcept {
        Y_VERIFY_DEBUG(Left() >= bits);
        Y_VERIFY_DEBUG(bits <= BitsSize<ui8>());

        auto [beginIndex, used] = SplitBitIndex(Pos_);  // inclusive
        size_t endIndex = ByteCount(Pos_ + bits);       // exclusive
        Pos_ += bits;

        Y_VERIFY_DEBUG(endIndex <= ByteCount(Size_), "endIndex(%zu) > size(%zu)", endIndex, ByteCount(Size_));

        // copy only the bytes that hold requested bits (at most 2), then cut off
        // bits above the requested range and drop already consumed low bits
        ui16 x = 0;
        std::memcpy(&x, Data_ + beginIndex, endIndex - beginIndex);
        return static_cast<ui8>(LowerBits(x, bits + used) >> used);
    }

    /**
     * Read {@code bits} bits into ui32 value and advance the current position forward by {@code bits}.
     *
     * @param bits  number of bits to read (must be <= 32)
     * @return ui32 value
     */
    ui32 ReadInt32(size_t bits) noexcept {
        Y_VERIFY_DEBUG(Left() >= bits);
        Y_VERIFY_DEBUG(bits <= BitsSize<ui32>());

        auto [beginIndex, used] = SplitBitIndex(Pos_);  // inclusive
        size_t endIndex = ByteCount(Pos_ + bits);       // exclusive
        Pos_ += bits;

        Y_VERIFY_DEBUG(endIndex <= ByteCount(Size_), "endIndex(%zu) > size(%zu)", endIndex, ByteCount(Size_));

        // copy only the bytes that hold requested bits (at most 5), then cut off
        // bits above the requested range and drop already consumed low bits
        ui64 x = 0;
        std::memcpy(&x, Data_ + beginIndex, endIndex - beginIndex);
        return static_cast<ui32>(LowerBits(x, bits + used) >> used);
    }

    /**
     * Read {@code bits} bits into ui64 value and advance the current position forward by {@code bits}.
     *
     * @param bits  number of bits to read (must be <= 64)
     * @return ui64 value
     */
    ui64 ReadInt64(size_t bits) noexcept {
        Y_VERIFY_DEBUG(Left() >= bits);
        Y_VERIFY_DEBUG(bits <= BitsSize<ui64>());

        auto [beginIndex, used] = SplitBitIndex(Pos_);  // inclusive
        size_t endIndex = ByteCount(Pos_ + bits);       // exclusive
        Pos_ += bits;

        Y_VERIFY_DEBUG(endIndex <= ByteCount(Size_), "endIndex(%zu) > size(%zu)", endIndex, ByteCount(Size_));

        // byte-aligned position: requested bits fit into a single 64-bit word
        if (Y_UNLIKELY(used == 0)) {
            ui64 x = 0;
            std::memcpy(&x, Data_ + beginIndex, endIndex - beginIndex);
            // full-width read is returned as is, without going through LowerBits
            return Y_UNLIKELY(bits == BitsSize<ui64>()) ? x : LowerBits(x, bits);
        }

        // unaligned position: requested bits may span up to 9 bytes, so two words are needed;
        // used is in [1, 7] here, hence the left shift below is always < 64
        ui64 x[2] = {0, 0};
        std::memcpy(&x, Data_ + beginIndex, endIndex - beginIndex);
        ui64 combined = (x[0] >> used) | (x[1] << (BitsSize<ui64>() - used));
        return Y_UNLIKELY(bits == BitsSize<ui64>()) ? combined : LowerBits(combined, bits);
    }

    /**
     * Read ui32 value stored in LEB128 (@see https://en.wikipedia.org/wiki/LEB128) format and advance
     * the current position forward. A 32-bit value occupies at most 5 bytes in LEB128.
     * NOTE(review): this comment previously said "up to 10 bytes", which is the ui64 limit;
     *               confirm the actual upper bound against the out-of-line implementation.
     *
     * @return empty optional if there is not enough data to read from stream or data is corrupted,
     *         ui32 value otherwise
     */
    std::optional<ui32> ReadVarInt32() noexcept;

    /**
     * Read ui64 value stored in LEB128 (@see https://en.wikipedia.org/wiki/LEB128) format and advance
     * the current position up to 10 bytes forward.
     *
     * @return empty optional if there is not enough data to read from stream or data is corrupted,
     *         ui64 value otherwise
     */
    std::optional<ui64> ReadVarInt64() noexcept;

    /**
     * Read ui32 value stored as variable length number in prefixed format and advance the current position
     * up to 5 bytes.
     * NOTE(review): the value is documented as ui32, but the declared return type is std::optional<ui64>;
     *               the declaration is kept as is to stay in sync with the out-of-line definition.
     *
     * @return empty optional if there is not enough data to read from stream or data is corrupted,
     *         ui32 value otherwise
     */
    std::optional<ui64> ReadVarInt32Mode() noexcept;

    /**
     * Read ui64 value stored as variable length number in prefixed format and advance the current position
     * up to 7 bytes.
     *
     * @return empty optional if there is not enough data to read from stream or data is corrupted,
     *         ui64 value otherwise
     */
    std::optional<ui64> ReadVarInt64Mode() noexcept;

    /**
     * Read continuous sequence of 1 bits up to {@code max} bits.
     *
     * @param max  maximum number of bits to read (must be <= 8).
     * @return how many bits were read.
     */
    ui8 ReadOnes(ui8 max) noexcept;

    /**
     * @return current bit position.
     */
    size_t Pos() const noexcept {
        return Pos_;
    }

    /**
     * Update current bit position. The bound is verified in debug builds only.
     *
     * @param pos  new position (must be <= bitSize provided in constructor).
     */
    void SetPos(size_t pos) noexcept {
        Y_VERIFY_DEBUG(pos <= Size_, "pos(%zu) > size(%zu)", pos, Size_);
        Pos_ = pos;
    }

    /**
     * @return how many bits are left to read, or 0 if the current position is already
     *         at or beyond the end of the block.
     */
    size_t Left() const noexcept {
        // Reads and SetPos() are not range checked in release builds, so Pos_ may get past Size_.
        // Saturate at zero instead of letting the unsigned subtraction wrap around to a huge value,
        // which would make every subsequent "Left() >= n" check pass.
        return Pos_ < Size_ ? Size_ - Pos_ : 0;
    }

    /**
     * @return underlying buffer size in bits.
     */
    size_t Size() const noexcept {
        return Size_;
    }

private:
    const ui8* Data_;    // start of the memory block (not owned)
    const size_t Size_;  // block size in bits
    size_t Pos_;         // current read position in bits
};

} // namespace NSolomon::NTs
