#include <infra/netmon/statistics/chunked_stream.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/xrange.h>

using namespace NNetmon;

class TChunkedStreamTest: public TTestBase {
    UNIT_TEST_SUITE(TChunkedStreamTest);
    UNIT_TEST(TestStream)
    UNIT_TEST(TestStableSequence)
    UNIT_TEST(TestChangingSequence)
    UNIT_TEST(TestSerialization)
    UNIT_TEST(TestEmpty)
    UNIT_TEST(TestCopy)
    UNIT_TEST_SUITE_END();

private:
    inline void TestEmpty() {
        UNIT_ASSERT_STRINGS_EQUAL(TChunkedInputStream(nullptr).ReadAll(), "");

        TChunkedOutputStream output;
        UNIT_ASSERT_STRINGS_EQUAL(TChunkedInputStream(&output).ReadAll(), "");
    }

    inline void TestStream() {
        AssertEqualStreams(0);
        AssertEqualStreams(1);

        AssertEqualStreams(256 - 1);
        AssertEqualStreams(256);
        AssertEqualStreams(256 + 1);

        AssertEqualStreams(256 * 2 - 1);
        AssertEqualStreams(256 * 2);
        AssertEqualStreams(256 * 2 + 1);
    }

    inline void AssertEqualStreams(size_t length) {
        TString data;
        for (const auto idx : xrange(length)) {
            data.append(static_cast<const char>(idx));
        }

        TChunkedOutputStream output;
        output.Write(data.data(), data.size());

        TChunkedInputStream input(&output);
        UNIT_ASSERT_STRINGS_EQUAL(input.ReadAll(), data);
    }

    inline void TestStableSequence() {
        TStringBuf data("123456");

        TChunkedOutputStream output;

        const size_t steps = 256 / data.size() + 1;
        for (const auto idx : xrange(steps)) {
            Y_UNUSED(idx);
            output.Write(data.data(), data.size());
        }

        TChunkedInputStream input(&output);
        for (const auto idx : xrange(steps)) {
            Y_UNUSED(idx);
            char buf[data.size()];
            UNIT_ASSERT_VALUES_EQUAL(input.Read(buf, data.size()), data.size());
            UNIT_ASSERT_VALUES_EQUAL(TStringBuf(buf, data.size()), data);
        }
    }

    inline void TestChangingSequence() {
        TStringBuf data("12345678");
        TVector<size_t> sizes{1, 8, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1, 7, 7, 1};

        TChunkedOutputStream output;
        for (const auto size : sizes) {
            UNIT_ASSERT(data.size() >= size);
            output.Write(data.data(), size);
        }

        TChunkedInputStream input(&output);
        for (const auto size : sizes) {
            char buf[size];
            UNIT_ASSERT_VALUES_EQUAL(input.Read(buf, size), size);
            UNIT_ASSERT_VALUES_EQUAL(TStringBuf(buf, size), TStringBuf(data.data(), size));
        }
    }

    inline void TestSerialization() {
        TString data;
        for (const auto idx : xrange(256 * 64 + 1)) {
            data.append(static_cast<const char>(idx));
        }

        TChunkedOutputStream source;
        source.Write(data.data(), data.size());

        flatbuffers::FlatBufferBuilder builder;
        builder.Finish(source.ToProto(builder));

        const NCommon::TChunkedStream& dump(
            *flatbuffers::GetRoot<NCommon::TChunkedStream>(builder.GetBufferPointer())
        );
        TChunkedOutputStream target;
        target.FromProto(dump);

        UNIT_ASSERT_VALUES_EQUAL(source.Size(), target.Size());

        TChunkedInputStream stream(&target);
        UNIT_ASSERT_VALUES_EQUAL(stream.ReadAll(), data);
    }

    inline void TestCopy() {
        TChunkedOutputStream source;
        for (const auto idx : xrange(256 * 2 + 1)) {
            source.Write(static_cast<const char>(idx));
        }

        TChunkedOutputStream target(source);
        UNIT_ASSERT_VALUES_EQUAL(source.Size(), target.Size());
        UNIT_ASSERT_VALUES_EQUAL(
            TChunkedInputStream(&source).ReadAll(),
            TChunkedInputStream(&target).ReadAll()
        );

        target.Write('x');
        UNIT_ASSERT(source.Size() != target.Size());
    }
};

// Register the suite with the unittest framework so its test cases get run.
UNIT_TEST_SUITE_REGISTRATION(TChunkedStreamTest);

