#include <solomon/libs/cpp/multi_shard/internal/message.h>
#include <solomon/libs/cpp/multi_shard/multi_shard.h>
#include <solomon/libs/cpp/multi_shard/proto/multi_shard.pb.h>
#include <solomon/libs/cpp/testing/matchers.h>

#include <library/cpp/monlib/consumers/collecting_consumer.h>
#include <library/cpp/monlib/encode/spack/spack_v1.h>
#include <library/cpp/monlib/metrics/histogram_snapshot.h>
#include <library/cpp/monlib/metrics/labels.h>
#include <library/cpp/testing/gtest/gtest.h>

#include <util/generic/algorithm.h>
#include <util/random/random.h>
#include <util/stream/mem.h>
#include <util/stream/str.h>

#include <google/protobuf/message.h>

// TODO: split these test suites into two files

using namespace testing;
using namespace NMonitoring;
using namespace NSolomon;
using namespace NSolomon::NMultiShard;

// Emits `labels` into `encoder` as a single labels section: OnLabelsBegin(),
// one OnLabel(name, value) call per label in iteration order, OnLabelsEnd().
void WriteLabels(IMetricEncoder& encoder, const TLabels& labels) {
    encoder.OnLabelsBegin();
    for (const auto& label: labels) {
        encoder.OnLabel(label.Name(), label.Value());
    }
    encoder.OnLabelsEnd();
}

// Fixture for the encoder tests: every test gets an empty in-memory stream and
// a fresh SPACK multi-shard encoder that writes into it.
class TEncoderTest: public testing::Test {
protected:
    void SetUp() override {
        Stream_.Clear();
        Encoder_ = CreateMultiShardEncoder(EFormat::SPACK, Stream_);
    }

    // Decodes the data payload of `message` as SPACK v1 and returns the
    // metrics collected while decoding.
    TVector<TMetricData> Decode(const TMessage& message) {
        const auto& payload = message.Message().GetData();
        TMemoryInput input{payload};

        TCollectingConsumer consumer;
        DecodeSpackV1(&input, &consumer);

        // Metrics is a member of the consumer, so it has to be moved out explicitly.
        return std::move(consumer.Metrics);
    }

    // Destination of everything the encoder writes.
    TStringStream Stream_;
    // Encoder under test; recreated in SetUp().
    IMultiShardEncoderPtr Encoder_;
};

// A shard that is opened and closed without any metrics: the stream is expected
// to contain the header followed by one shard message whose payload decodes to
// an empty metric list.
TEST_F(TEncoderTest, EmptyData) {
    Encoder_->SetHeader({"token"});
    Encoder_->OnShardBegin("my_project", "my_service", "my_cluster");
    Encoder_->OnShardEnd();
    Encoder_->Close();

    // The header is read first.
    THeaderMessage headerMessage;
    Stream_ >> headerMessage;
    const auto& header = headerMessage.Message();
    ASSERT_THAT(header.GetContinuationToken(), StrEq("token"));
    ASSERT_THAT(header.GetFormatVersion(), Eq(1u));

    // Then the single (empty) shard.
    TMessage shardMessage;
    Stream_ >> shardMessage;
    const auto& shard = shardMessage.Message();
    ASSERT_THAT(shard.GetProject(), StrEq("my_project"));
    ASSERT_THAT(shard.GetService(), StrEq("my_service"));

    ASSERT_THAT(Decode(shardMessage), IsEmpty());
}

// Writes several shards, each with one COUNTER metric, and checks that they are
// read back in the same order with the expected project, service, metric kind
// and metric labels.
TEST_F(TEncoderTest, MultipleMessages) {
    constexpr int COUNT = 4;

    // Project name of the shard with the given index.
    const auto projectName = [](int idx) -> TString {
        return TStringBuilder() << "my_project" << idx;
    };

    Encoder_->SetHeader({"token"});
    for (int shardIdx = 0; shardIdx < COUNT; ++shardIdx) {
        Encoder_->OnShardBegin(projectName(shardIdx), "my_service");
        WriteLabels(*Encoder_, {{"my", "label"}});

        Encoder_->OnMetricBegin(EMetricType::COUNTER);
        WriteLabels(*Encoder_, {{"sensor", "something"}});
        Encoder_->OnUint64(TInstant::Zero(), 0);
        Encoder_->OnMetricEnd();

        Encoder_->OnShardEnd();
    }
    Encoder_->Close();

    THeaderMessage headerMessage;
    Stream_ >> headerMessage;
    const auto& header = headerMessage.Message();
    ASSERT_THAT(header.GetContinuationToken(), StrEq("token"));
    ASSERT_THAT(header.GetFormatVersion(), Eq(1u));

    // Only the per-metric labels are expected on the decoded metric.
    const TLabels expectedLabels{{"sensor", "something"}};
    for (int shardIdx = 0; shardIdx < COUNT; ++shardIdx) {
        TMessage shardMessage;
        Stream_ >> shardMessage;
        const auto& shard = shardMessage.Message();

        const TString expectedProject = projectName(shardIdx);
        ASSERT_THAT(shard.GetProject(), StrEq(expectedProject));
        ASSERT_THAT(shard.GetService(), StrEq("my_service"));

        const auto metrics = Decode(shardMessage);
        ASSERT_THAT(metrics.size(), Eq(1u));
        const auto& metric = metrics[0];
        ASSERT_THAT(metric.Kind, Eq(EMetricType::COUNTER));
        ASSERT_THAT(metric.Labels, LabelsEq(expectedLabels));
    }
}

// Without a preceding OnShardBegin() the encoding callbacks below are expected
// to throw yexception.
TEST_F(TEncoderTest, ThrowsIfShardIsNotOpen) {
    ASSERT_THROW(Encoder_->OnStreamBegin(), yexception);
    ASSERT_THROW(Encoder_->OnLabelsBegin(), yexception);
    ASSERT_THROW(Encoder_->OnMetricBegin(EMetricType::HIST), yexception);
}

// After Close() any further write into the encoder is expected to throw
// yexception.
TEST_F(TEncoderTest, ThrowsIfEncoderIsClosed) {
    const TLabels commonLabels{{"common", "labels"}};

    // A complete, valid session: header, one shard with common labels, close.
    Encoder_->SetHeader({"token"});
    Encoder_->OnShardBegin("foo", "bar");
    Encoder_->OnStreamBegin();
    WriteLabels(*Encoder_, commonLabels);
    Encoder_->OnStreamEnd();
    Encoder_->OnShardEnd();
    Encoder_->Close();

    // Writing into the closed encoder must fail.
    ASSERT_THROW(WriteLabels(*Encoder_, commonLabels), yexception);
}

// Opening a second shard while the first one has not been ended is expected to
// throw yexception.
TEST_F(TEncoderTest, ThrowsIfPreviousIsNotClosed) {
    Encoder_->SetHeader({"token"});

    // First shard is opened and intentionally left without OnShardEnd().
    Encoder_->OnShardBegin("foo", "bar");

    ASSERT_THROW(
            Encoder_->OnShardBegin("boo", "far"),
            yexception
    );
}

// The continuation token passed to SetHeader() must be readable from the header
// message at the start of the stream, together with format version 1.
TEST_F(TEncoderTest, TokenIsSet) {
    const TLabels commonLabels{{"common", "labels"}};

    Encoder_->SetHeader({"foo"});
    Encoder_->OnShardBegin("foo", "bar");
    Encoder_->OnStreamBegin();
    WriteLabels(*Encoder_, commonLabels);
    Encoder_->OnStreamEnd();
    Encoder_->OnShardEnd();
    Encoder_->Close();

    THeaderMessage headerMessage;
    Stream_ >> headerMessage;
    const auto& header = headerMessage.Message();
    ASSERT_THAT(header.GetContinuationToken(), Eq("foo"));
    ASSERT_THAT(header.GetFormatVersion(), Eq(1u));
}

// Opening a shard before SetHeader() has been called is expected to throw
// yexception.
TEST_F(TEncoderTest, HeaderIsRequired) {
    ASSERT_THROW(Encoder_->OnShardBegin("foo", "bar"), yexception);
}

namespace {
    // Plain record of one decoded shard, as captured by the mock handler.
    // The member order matters: designated initializers elsewhere in this file
    // rely on it.
    struct TData {
        TString Project;
        TString Cluster;
        TString Service;
        TString Data;

        // Two records are equal when all four fields are equal.
        bool operator==(const TData& other) const {
            if (Project != other.Project) {
                return false;
            }
            if (Cluster != other.Cluster) {
                return false;
            }
            if (Service != other.Service) {
                return false;
            }
            return Data == other.Data;
        }
    };

    // Prints a TData as "{project/cluster/service} data"; used for readable
    // output when a matcher on TData fails.
    std::ostream& operator<<(std::ostream& os, const TData& d) {
        os << '{' << d.Project;
        os << '/' << d.Cluster;
        os << '/' << d.Service;
        os << "} " << d.Data;
        return os;
    }

    // Test double for IMessageHandler: records everything reported through the
    // callbacks so that tests can inspect it afterwards.
    struct TMockHandler: public IMessageHandler {
        TMockHandler() = default;

        // Checks via Y_VERIFY that OnStreamEnd() was called before the handler
        // is destroyed.
        ~TMockHandler() {
            Y_VERIFY(IsClosed);
        }

        // Appends the reported shard to Shards. The arguments are taken by value,
        // so they are moved into the stored record instead of being copied again.
        // Always returns true.
        bool OnShardData(TString project, TString cluster, TString service, TString data) override {
            Shards.push_back({
                .Project = std::move(project),
                .Cluster = std::move(cluster),
                .Service = std::move(service),
                .Data = std::move(data),
            });

            return true;
        }

        // Stores the reported header (overwriting any previous one). Always returns true.
        bool OnHeader(THeader header) override {
            Header = std::move(header);
            return true;
        }

        // Remembers the reported error; a later error replaces an earlier one.
        void OnError(TString err) override {
            Error = std::move(err);
        }

        // Marks the stream as finished; required before destruction (see the destructor).
        void OnStreamEnd() override {
            IsClosed = true;
        }

        // Set only if OnError() was called.
        std::optional<TString> Error;
        // Last header passed to OnHeader().
        THeader Header;
        // Shards in the order they were reported.
        TVector<TData> Shards;
        // True once OnStreamEnd() has been called.
        bool IsClosed = false;
    };

} // namespace

// Length generator that always yields the compile-time constant `Len`.
template <ui16 Len>
struct TConstGenerator {
    static ui64 Get() {
        return static_cast<ui64>(Len);
    }
};
// Length generator that yields the result of RandomNumber<ui32>(Mod) on every call.
// NOTE(review): presumably the value lies in [0, Mod), i.e. it may be 0 - confirm
// against RandomNumber's contract before using it as a chunk length.
template <ui16 Mod>
struct TRandGenerator {
    static ui64 Get() {
        const ui32 value = RandomNumber<ui32>(Mod);
        return value;
    }
};

// Fixture for the decoder tests: helpers that build valid and deliberately
// broken serialized streams, plus the shared StreamRead scenario.
class TDecoderTest: public testing::Test {
protected:
    // Serializes the default set of messages, decodes the whole buffer with the
    // continuous chunk decoder and checks that no error was reported and that
    // the handler received every shard in order.
    //
    // NOTE(review): TLengthGenerator is not referenced in the body, so every
    // instantiation runs exactly the same single-shot decode - confirm whether
    // feeding the decoder in generator-sized chunks was intended.
    template <typename TLengthGenerator>
    void StreamRead() {
        const auto expectedMessages = GenerateMessages();
        const auto rawStream = GenerateRawStream(expectedMessages);

        TMockHandler handler;
        auto decoder = CreateMultiShardContinuousChunkDecoder(handler);
        decoder->Decode(TStringBuf{rawStream});

        ASSERT_FALSE(handler.Error);
        ASSERT_THAT(handler.Shards, ElementsAreArray(UnwrapData(expectedMessages)));
    }

    // Builds `sz` messages with project "project<i>", service "service<i>" and
    // the same fixed data payload.
    TVector<TMessage> GenerateMessages(int sz = 100) {
        TVector<TMessage> result;
        for (int idx = 0; idx < sz; ++idx) {
            TShardData shard;
            shard.SetProject(TStringBuilder() << "project" << idx);
            shard.SetService(TStringBuilder() << "service" << idx);
            shard.SetData(TStringBuilder() << "thisismydata");
            result.emplace_back(std::move(shard));
        }

        return result;
    }

    // Returns three malformed fragments:
    //   1. the plain text "hello";
    //   2. a raw ui32 with value 1 followed by the text "something";
    //   3. a valid one-message stream with its last two bytes cut off.
    TVector<TString> GenerateBrokenMessagesStream() {
        TVector<TString> brokenMessages;
        brokenMessages.push_back("hello");

        {
            TStringStream lengthPrefixed;
            ui32 len{1};
            lengthPrefixed.Write(&len, sizeof(len));
            lengthPrefixed << "something";
            brokenMessages.push_back(lengthPrefixed.Str());
        }

        {
            const auto singleMessage = GenerateMessages(1);
            const auto raw = GenerateRawStream(singleMessage);
            TStringStream truncated;
            truncated.Write(raw.data(), raw.size() - 2);
            brokenMessages.push_back(truncated.Str());
        }

        return brokenMessages;
    }

    // Serializes a header (continuation token "token", format version 1)
    // followed by all `messages`, and returns the resulting bytes.
    TString GenerateRawStream(const TVector<TMessage>& messages) {
        TMultiShardData::THeader header;
        header.SetContinuationToken("token");
        header.SetFormatVersion(1);
        THeaderMessage headerMessage{std::move(header)};

        TStringStream out;
        out << headerMessage;
        for (const auto& msg: messages) {
            out << msg;
        }

        return out.Str();
    }

    // Converts messages into the TData records the mock handler is expected to
    // collect. Cluster is left default-constructed.
    TVector<TData> UnwrapData(const TVector<TMessage>& messages) {
        TVector<TData> unwrapped;
        for (const auto& msg: messages) {
            const auto& shard = msg.Message();
            unwrapped.push_back(TData{
                .Project = shard.GetProject(),
                .Service = shard.GetService(),
                .Data = shard.GetData(),
            });
        }

        return unwrapped;
    }
};

// Decoding an empty buffer must report neither shards nor an error.
TEST_F(TDecoderTest, EmptyDataChunkDecoder) {
    TMockHandler handler;
    auto chunkDecoder = CreateMultiShardContinuousChunkDecoder(handler);

    const TStringBuf emptyInput{};
    chunkDecoder->Decode(emptyInput);

    ASSERT_THAT(handler.Shards, IsEmpty());
    ASSERT_FALSE(handler.Error);
}

// Runs the shared StreamRead scenario instantiated with TConstGenerator<1>.
TEST_F(TDecoderTest, StreamRead1) {
    StreamRead<TConstGenerator<1>>();
}

// Runs the shared StreamRead scenario instantiated with TConstGenerator<10>.
TEST_F(TDecoderTest, StreamRead10) {
    StreamRead<TConstGenerator<10>>();
}

// Runs the shared StreamRead scenario instantiated with TConstGenerator<100>.
TEST_F(TDecoderTest, StreamRead100) {
    StreamRead<TConstGenerator<100>>();
}

// Runs the shared StreamRead scenario instantiated with TRandGenerator<100>.
TEST_F(TDecoderTest, StreamReadRand100) {
    StreamRead<TRandGenerator<100>>();
}

// Runs the shared StreamRead scenario instantiated with TRandGenerator<1000>.
TEST_F(TDecoderTest, StreamReadRand1000) {
    StreamRead<TRandGenerator<1000>>();
}

// Feeds the decoder a buffer glued together from malformed fragments: an error
// must be reported and no shard may be delivered.
TEST_F(TDecoderTest, BrokenMessagesChunkDecoder) {
    // Concatenate all broken fragments into one continuous buffer.
    TString continuousChunk;
    for (const auto& fragment: GenerateBrokenMessagesStream()) {
        continuousChunk += fragment;
    }

    TMockHandler handler;
    auto chunkDecoder = CreateMultiShardContinuousChunkDecoder(handler);
    chunkDecoder->Decode(TStringBuf{continuousChunk});

    ASSERT_TRUE(handler.Error);
    ASSERT_THAT(handler.Shards, IsEmpty());
}

// A valid stream whose tail is cut off: every message except the last one must
// be delivered, and an error must be reported for the truncated one.
TEST_F(TDecoderTest, BrokenTailChunkDecoder) {
    TMockHandler handler;
    auto chunkDecoder = CreateMultiShardContinuousChunkDecoder(handler);

    auto expectedMessages = GenerateMessages();
    const auto rawStream = GenerateRawStream(expectedMessages);

    // Drop the last 5 bytes so that the final message cannot be parsed.
    const TStringBuf truncated{rawStream.data(), rawStream.size() - 5};
    chunkDecoder->Decode(truncated);

    ASSERT_TRUE(handler.Error);

    // The truncated last message must not have been delivered.
    expectedMessages.pop_back();
    ASSERT_THAT(handler.Shards, ElementsAreArray(UnwrapData(expectedMessages)));
}
