#include <solomon/libs/cpp/ts_codec/bit_reader.h>
#include <solomon/libs/cpp/ts_codec/bit_writer.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <util/random/random.h>
#include <util/generic/vector.h>

using namespace NSolomon::NTs;

/**
 * Writes to stream variable length offset, 100 values and then read those values from stream.
 */
template <typename TValue, typename TWriter, typename TReader>
void DoWriteReadCheck(TWriter writer, TReader reader) {
    SetRandomSeed(17);

    for (size_t offset = 0; offset < BitsSize<TValue>(); ++offset) {
        TBitBuffer buf;
        TBitWriter w{&buf};
        TVector<TValue> expectedValues;

        // fill zero bits of offset
        for (size_t i = 0; i < offset; ++i) {
            w.WriteBit(false);
        }

        // write and remember 100 random value
        for (size_t i = 0; i < 100; ++i) {
            TValue value = RandomNumber<TValue>();
            expectedValues.push_back(value);
            writer(w, value);
        }

        w.Flush();
        TBitReader r{buf};

        // skip zero bits of offset
        for (size_t i = 0; i < offset; ++i) {
            ASSERT_FALSE(r.ReadBit()) << "offset: " << offset;
        }

        // read expected 100 values
        for (TValue expected: expectedValues) {
            TValue value = reader(r);
            ASSERT_EQ(expected, value) << "offset: " << offset;
        }

        ASSERT_EQ(0u, r.Left()) << "offset: " << offset;
    }
}

/**
 * Same as above, but also iterates over value bit size.
 */
template <typename TValue, typename TWriter, typename TReader>
void DoWriteReadBitsCheck(TWriter writer, TReader reader) {
    SetRandomSeed(17);

    for (size_t bits = 1; bits <= BitsSize<TValue>(); ++bits) {
        for (size_t offset = 0; offset < BitsSize<TValue>(); ++offset) {
            TBitBuffer buf;
            TBitWriter w{&buf};
            TVector<TValue> expectedValues;

            // fill zero bits of offset
            for (size_t i = 0; i < offset; ++i) {
                w.WriteBit(false);
            }

            // write and remember 100 random value
            for (size_t i = 0; i < 100; ++i) {
                TValue value = RandomNumber<TValue>();
                expectedValues.push_back(bits == BitsSize<TValue>() ? value : LowerBits(value, bits));
                writer(w, value, bits);
            }

            w.Flush();
            TBitReader r{buf};

            // skip zero bits of offset
            for (size_t i = 0; i < offset; ++i) {
                ASSERT_FALSE(r.ReadBit()) << "bits: " << bits << ", offset: " << offset;
            }

            // read expected 100 values
            for (size_t i = 0; i < expectedValues.size(); i++) {
                TValue expected = expectedValues[i];
                TValue value = reader(r, bits);
                ASSERT_EQ(expected, value) << "bits: " << bits << ", offset: " << offset << ", i: " << i;
            }

            ASSERT_EQ(0u, r.Left()) << "bits: " << bits << ", offset: " << offset;
        }
    }
}

// Writes 100 random single bits, then reads them back in order and checks the stream is fully consumed.
TEST(TBitWriterReaderTest, Bit) {
    SetRandomSeed(17);

    TBitBuffer buf;
    TBitWriter w{&buf};
    TVector<bool> expected;

    for (size_t i = 0; i < 100; ++i) {
        bool bit = (RandomNumber<ui8>(2) == 1);
        expected.push_back(bit);
        w.WriteBit(bit);
    }

    // Commit everything the writer holds before reading, as the other round-trip tests do.
    w.Flush();

    TBitReader r{buf};
    size_t i = 0;
    for (bool bit: expected) {
        ASSERT_EQ(bit, r.ReadBit()) << "i: " << i;
        ++i;
    }

    // every written bit must have been read back, nothing more
    ASSERT_EQ(0u, r.Left());
}

// Full-width 8-bit integers round-trip at every bit offset.
TEST(TBitWriterReaderTest, Int8) {
    auto write = [](TBitWriter& w, ui8 value) { w.WriteInt8(value); };
    auto read = [](TBitReader& r) { return r.ReadInt8(); };
    DoWriteReadCheck<ui8>(write, read);
}

// Full-width 32-bit integers round-trip at every bit offset.
TEST(TBitWriterReaderTest, Int32) {
    auto write = [](TBitWriter& w, ui32 value) { w.WriteInt32(value); };
    auto read = [](TBitReader& r) { return r.ReadInt32(); };
    DoWriteReadCheck<ui32>(write, read);
}

// Full-width 64-bit integers round-trip at every bit offset.
TEST(TBitWriterReaderTest, Int64) {
    auto write = [](TBitWriter& w, ui64 value) { w.WriteInt64(value); };
    auto read = [](TBitReader& r) { return r.ReadInt64(); };
    DoWriteReadCheck<ui64>(write, read);
}

// Doubles round-trip exactly at every bit offset.
TEST(TBitWriterReaderTest, Double) {
    auto write = [](TBitWriter& w, double value) { w.WriteDouble(value); };
    auto read = [](TBitReader& r) { return r.ReadDouble(); };
    DoWriteReadCheck<double>(write, read);
}

// 8-bit integers written with an explicit width round-trip for every width and offset.
TEST(TBitWriterReaderTest, Int8Bits) {
    auto write = [](TBitWriter& w, ui8 value, size_t bits) { w.WriteInt8(value, bits); };
    auto read = [](TBitReader& r, size_t bits) { return r.ReadInt8(bits); };
    DoWriteReadBitsCheck<ui8>(write, read);
}

// 32-bit integers written with an explicit width round-trip for every width and offset.
TEST(TBitWriterReaderTest, Int32Bits) {
    auto write = [](TBitWriter& w, ui32 value, size_t bits) { w.WriteInt32(value, bits); };
    auto read = [](TBitReader& r, size_t bits) { return r.ReadInt32(bits); };
    DoWriteReadBitsCheck<ui32>(write, read);
}

// 64-bit integers written with an explicit width round-trip for every width and offset.
TEST(TBitWriterReaderTest, Int64Bits) {
    auto write = [](TBitWriter& w, ui64 value, size_t bits) { w.WriteInt64(value, bits); };
    auto read = [](TBitReader& r, size_t bits) { return r.ReadInt64(bits); };
    DoWriteReadBitsCheck<ui64>(write, read);
}

// Variable-length 32-bit integers round-trip at every bit offset; a missing decode result throws via Y_ENSURE.
TEST(TBitWriterReaderTest, VarInt32) {
    auto encode = [](TBitWriter& w, ui32 value) { w.WriteVarInt32(value); };
    auto decode = [](TBitReader& r) -> ui32 {
        auto value = r.ReadVarInt32();
        Y_ENSURE(value.has_value());
        return *value;
    };
    DoWriteReadCheck<ui32>(encode, decode);
}

// Variable-length 64-bit integers round-trip at every bit offset; a missing decode result throws via Y_ENSURE.
TEST(TBitWriterReaderTest, VarInt64) {
    auto encode = [](TBitWriter& w, ui64 value) { w.WriteVarInt64(value); };
    auto decode = [](TBitReader& r) -> ui64 {
        auto value = r.ReadVarInt64();
        Y_ENSURE(value.has_value());
        return *value;
    };
    DoWriteReadCheck<ui64>(encode, decode);
}

// Runs of ones: at every offset below BitsSize<ui8>(), write 100 random run lengths n drawn by
// RandomNumber<ui8>(7) with WriteOnes(n, n + 1), then read them back with ReadOnes(8).
TEST(TBitWriterReaderTest, Ones) {
    SetRandomSeed(17);

    for (size_t offset = 0; offset < BitsSize<ui8>(); ++offset) {
        TBitBuffer buf;
        TBitWriter w{&buf};
        TVector<ui8> expected;

        // fill zero bits of offset
        for (size_t i = 0; i < offset; ++i) {
            w.WriteBit(false);
        }

        // write and remember 100 random run lengths
        for (size_t i = 0; i < 100; ++i) {
            ui8 n = RandomNumber<ui8>(7);
            expected.push_back(n);
            w.WriteOnes(n, n + 1);
        }

        // commit everything the writer holds before reading, as the other round-trip tests do
        w.Flush();
        TBitReader r{buf};

        // skip zero bits of offset
        for (size_t i = 0; i < offset; ++i) {
            ASSERT_FALSE(r.ReadBit()) << "offset: " << offset;
        }

        for (ui8 value: expected) {
            ASSERT_EQ(value, r.ReadOnes(8)) << "offset: " << offset;
        }

        // every written bit must have been consumed, nothing more
        ASSERT_EQ(0u, r.Left()) << "offset: " << offset;
    }
}

// WriteVarInt32Mode / ReadVarInt32Mode round-trip at every bit offset; a missing decode result throws via Y_ENSURE.
TEST(TBitWriterReaderTest, VarInt32Mode) {
    auto encode = [](TBitWriter& w, ui32 value) { w.WriteVarInt32Mode(value); };
    auto decode = [](TBitReader& r) -> ui32 {
        auto value = r.ReadVarInt32Mode();
        Y_ENSURE(value.has_value());
        return *value;
    };
    DoWriteReadCheck<ui32>(encode, decode);
}

// WriteVarInt64Mode / ReadVarInt64Mode round-trip at every bit offset; a missing decode result throws via Y_ENSURE.
TEST(TBitWriterReaderTest, VarInt64Mode) {
    auto encode = [](TBitWriter& w, ui64 value) { w.WriteVarInt64Mode(value); };
    auto decode = [](TBitReader& r) -> ui64 {
        auto value = r.ReadVarInt64Mode();
        Y_ENSURE(value.has_value());
        return *value;
    };
    DoWriteReadCheck<ui64>(encode, decode);
}
