#include "helpers.h"

#include <solomon/libs/cpp/shard_metrics/repository.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <latch>
#include <thread>

using namespace NSolomon;
using namespace NMonitoring;

/// Test context that registers two metrics of its own — an int gauge "calls"
/// and a rate "rps" — in addition to whatever TResourceUsageContextBase provides.
class TMyCoolMetrics: public TResourceUsageContextBase {
public:
    using TResourceUsageContextBase::TResourceUsageContextBase;
    // Re-exported as public so tests can read the base-class member directly.
    using TResourceUsageContextBase::MemBytes_;

    /// Creates this context's own metrics through the supplied factory.
    void Init(NMonitoring::IMetricFactory& f) override {
        Calls_ = f.IntGauge(MakeLabels({{"sensor", "calls"}}));
        Rps_ = f.Rate(MakeLabels({{"sensor", "rps"}}));
    }

    /// Sets the "calls" gauge to @p val. Requires Init() to have run.
    void SetCalls(i64 val) {
        Calls_->Set(val);
    }

    /// Adds @p val to the "rps" rate. Requires Init() to have run.
    void AddRps(i64 val) {
        Rps_->Add(val);
    }

private:
    // Non-owning handles returned by the factory in Init(); ownership presumably
    // stays with the registry behind the factory — TODO confirm. Initialized to
    // nullptr so use before Init() is a deterministic null dereference rather
    // than a read of an indeterminate pointer.
    IIntGauge* Calls_ = nullptr;
    IRate* Rps_ = nullptr;
};

/// Repository specialization whose per-shard contexts are TMyCoolMetrics
/// (struct inheritance is public by default).
struct TMyCoolRepository: TRepository<TMyCoolMetrics> {};

class TRepositoryTest: public ::testing::Test {
protected:
    auto LabelFactory(TString project, TString shard) {
        return [=] (auto name) {
            return TLabels{{"sensor", name}, {"projectId", project}, {"shardId", shard}};
        };
    }

    void SetUp() override {
        Repo_.Reset(new TMyCoolRepository);
        Consumer_ = {};
    }

    THolder<TMyCoolRepository> Repo_;
    TCollectingConsumer Consumer_;
};

TEST_F(TRepositoryTest, ContextIsCreated) {
    auto shardCtx = Repo_->GetContext("foo", "bar");
    shardCtx->SetMemoryBytes(41);
    shardCtx->SetCalls(42);

    Repo_->Accept(TInstant::Zero(), &Consumer_);

    // Every level exports 4 base metrics (memBytes, network.rxBytes,
    // network.txBytes, cpu.userTimeMillis) plus 2 derived ones (calls, rps).
    // Three levels are expected: the shard itself, its project aggregate and
    // the grand total.
    constexpr size_t metricsPerLevel = 4 + 2;
    ASSERT_EQ(Consumer_.Metrics.size(), 3 * metricsPerLevel);
}

TEST_F(TRepositoryTest, ProducesCorrectAggregates) {
    // Three shards spread over two projects; project1 owns two of them.
    auto p1s1 = Repo_->GetContext("project1", "shard1");
    p1s1->SetCalls(2);
    p1s1->SetMemoryBytes(1);

    auto p1s2 = Repo_->GetContext("project1", "shard2");
    p1s2->SetCalls(4);
    p1s2->SetMemoryBytes(2);

    auto p2s2 = Repo_->GetContext("project2", "shard2");
    p2s2->SetCalls(8);
    p2s2->SetMemoryBytes(4);

    Repo_->Accept(TInstant::Zero(), &Consumer_);

    // Expected gauge sums per project and across everything. No rate was
    // incremented in this test, so every aggregated rate is expected to be 0.
    struct TExpectedTotals {
        TString ProjectId;
        TString ShardId;
        ui64 Calls;
        ui64 MemBytes;
    };

    const TExpectedTotals rows[] = {
        {"project1", "total", 6, 3},
        {"project2", "total", 8, 4},
        {"total", "total", 14, 7},
    };

    for (const auto& row: rows) {
        auto labels = LabelFactory(row.ProjectId, row.ShardId);

        TVector<TMetricData> expected;
        expected.emplace_back(labels("calls"), EMetricType::IGAUGE, SinglePointSeries<ui64>(row.Calls));
        expected.emplace_back(labels("rps"), EMetricType::RATE, SinglePointSeries<ui64>(0));
        expected.emplace_back(labels("memBytes"), EMetricType::IGAUGE, SinglePointSeries<ui64>(row.MemBytes));
        expected.emplace_back(labels("network.rxBytes"), EMetricType::RATE, SinglePointSeries<ui64>(0));
        expected.emplace_back(labels("network.txBytes"), EMetricType::RATE, SinglePointSeries<ui64>(0));
        expected.emplace_back(labels("cpu.userTimeMillis"), EMetricType::RATE, SinglePointSeries<ui64>(0));

        ASSERT_THAT(Consumer_.Metrics, MetricsIsSuperset(expected));
    }
}

TEST_F(TRepositoryTest, AcceptIsThreadSafe) {
    // Hold a reference to every context for the whole test: the repository
    // appears to drop a context's shard metrics once its last reference is gone.
    TVector<std::shared_ptr<TMyCoolMetrics>> ctxs;
    std::shared_ptr<TMyCoolMetrics> ctx;

    ctx = ctxs.emplace_back(Repo_->GetContext("project1", "shard1"));
    ctx->SetCalls(2);
    ctx->SetMemoryBytes(1);
    ctx->AddNetworkRxBytes(100);
    ctx->AddNetworkTxBytes(99);

    ctx = ctxs.emplace_back(Repo_->GetContext("project1", "shard2"));
    ctx->SetCalls(4);
    ctx->SetMemoryBytes(2);
    ctx->AddCpuTime(TDuration::MilliSeconds(4));

    ctx = ctxs.emplace_back(Repo_->GetContext("project2", "shard2"));
    ctx->AddRps(15);
    ctx->SetMemoryBytes(4);

    constexpr auto workerCount = 10;
    constexpr auto acceptsPerWorker = 5;

    // One slot per worker, each holding that worker's Accept() snapshots.
    using TWorkerResult = TVector<TVector<TMetricData>>;
    TVector<TWorkerResult> results(workerCount);

    TVector<std::thread> workers;
    workers.reserve(workerCount);
    // Release all workers at once to maximize overlap between Accept() calls.
    std::latch startBarrier{workerCount};

    for (auto i = 0; i < workerCount; ++i) {
        workers.emplace_back([&, i] {
            startBarrier.arrive_and_wait();

            TWorkerResult runs;
            runs.reserve(acceptsPerWorker);
            for (auto j = 0; j < acceptsPerWorker; ++j) {
                TCollectingConsumer c;
                Repo_->Accept(TInstant::Zero(), &c);
                runs.push_back(std::move(c.Metrics));
            }

            // Each worker writes only its own slot, so no extra locking is needed.
            results[i] = std::move(runs);
        });
    }

    for (auto& w: workers) {
        w.join();
    }

    // Nothing is mutated during collection, so every snapshot must be identical.
    // Guard against a vacuous pass if Accept() emitted nothing at all.
    const auto& expected = results[0][0];
    ASSERT_FALSE(expected.empty());
    for (const auto& r: results) {
        for (const auto& val: r) {
            ASSERT_THAT(val, MetricsEq(expected));
        }
    }
}

TEST_F(TRepositoryTest, GetContextIsThreadSafe) {
    constexpr auto workerCount = 10;
    constexpr auto iterCount = 50;

    std::latch startBarrier{workerCount};

    // Each worker requests a mix of keys on every iteration: one shared by all
    // workers, one unique to the (iteration, worker) pair and one shared by all
    // workers within the same iteration. Every tenth iteration it also runs a
    // full Accept() while those contexts are still held.
    auto worker = [&](int workerIdx) {
        startBarrier.arrive_and_wait();

        for (auto iter = 0; iter < iterCount; ++iter) {
            auto sharedCtx = Repo_->GetContext("project", "shard");

            auto prjId = TStringBuilder() << "project" << iter;
            auto shardId = TStringBuilder() << "shard" << workerIdx;

            auto uniqueCtx = Repo_->GetContext(prjId, shardId);
            auto perIterCtx = Repo_->GetContext(TStringBuilder() << "project" << iter, "const");

            Y_UNUSED(sharedCtx, uniqueCtx, perIterCtx);
            if (iter % 10 != 0) {
                continue;
            }

            TCollectingConsumer sink;
            Repo_->Accept(TInstant::Zero(), &sink);
        }
    };

    TVector<std::thread> workers;
    for (auto i = 0; i < workerCount; ++i) {
        workers.emplace_back(worker, i);
    }

    for (auto& t: workers) {
        t.join();
    }
}

TEST_F(TRepositoryTest, ContextGetsRemovedWithLastRef) {
    auto owner = Repo_->GetContext("foo", "bar");

    {
        // Requesting the same (project, shard) pair again yields the same object.
        auto sameCtx = Repo_->GetContext("foo", "bar");
        ASSERT_EQ(owner.get(), sameCtx.get());
    }

    // `owner` still holds the context, so its shard metrics are exported.
    Repo_->Accept(TInstant::Zero(), &Consumer_);
    ASSERT_EQ(Consumer_.Metrics.size(), 18u);

    // After the last reference is dropped, the shard's own 6 metrics are expected
    // to disappear; the remaining 12 are presumably the project/total aggregates.
    Consumer_ = {};
    owner.reset();
    Repo_->Accept(TInstant::Zero(), &Consumer_);
    ASSERT_EQ(Consumer_.Metrics.size(), 12u);
}

TEST_F(TRepositoryTest, ContextIsRecreated) {
    // Create a context and release the only reference to it straight away.
    Repo_->GetContext("foo", "bar").reset();

    // Asking for the same key again must yield a live, fully exported context.
    auto recreated = Repo_->GetContext("foo", "bar");
    Repo_->Accept(TInstant::Zero(), &Consumer_);
    ASSERT_EQ(Consumer_.Metrics.size(), 18u);
}

TEST_F(TRepositoryTest, ContextCanOutliveRepository) {
    // Take a context, then destroy the repository that created it.
    auto orphan = Repo_->GetContext("foo", "bar");
    Repo_.Reset();

    // Updating a metric on the orphaned context must not crash; the test passes
    // as long as this call completes.
    orphan->SetCalls(1000);
}

TEST_F(TRepositoryTest, VisitorHappyPath) {
    // Snapshot of one visited shard: its identity plus the value read from the
    // context's MemBytes_ member.
    struct TShardInfo {
        TShardInfo(TStringBuf prjId, TStringBuf shardId, i64 value = 0)
            : ProjectId{prjId}
            , ShardId{shardId}
            , Value{value}
        {
        }

        // Compares all fields, so the matcher below also verifies the value read
        // from each shard, not only which shards were visited.
        bool operator==(const TShardInfo& other) const {
            return std::tie(ProjectId, ShardId, Value) == std::tie(other.ProjectId, other.ShardId, other.Value);
        }

        TString ProjectId;
        TString ShardId;
        i64 Value;
    };

    // Records every shard passed to OnShard; project/total callbacks are no-ops.
    class TVisitor: public IRepositoryVisitor {
    public:
        explicit TVisitor(TVector<TShardInfo>& visited)
            : Visited_{visited}
        {
        }

    private:
        void OnShard(TResourceUsageContextHolder& shardCtx) override {
            const auto* impl = dynamic_cast<const TMyCoolRepository::TCtx*>(shardCtx.Impl());
            // A failed downcast would otherwise be a null dereference; record a
            // fatal failure instead (the size check below will then fail too).
            ASSERT_NE(impl, nullptr);
            Visited_.emplace_back(shardCtx.ProjectId, shardCtx.ShardId, impl->MemBytes_->Get());
        }

        void OnProject(TImmutableMetricRegistry&) override {
            // Already tested in ProducesCorrectAggregates
        }

        void OnTotal(TImmutableMetricRegistry&) override {
            // Already tested in ProducesCorrectAggregates
        }

    private:
        TVector<TShardInfo>& Visited_;
    };

    TVector<TShardInfo> visited;

    auto ctx1Ptr = Repo_->GetContext("project1", "shard1");
    ctx1Ptr->SetMemoryBytes(1);
    auto ctx2Ptr = Repo_->GetContext("project1", "shard2");
    ctx2Ptr->SetMemoryBytes(2);
    auto ctx3Ptr = Repo_->GetContext("project2", "shard3");
    ctx3Ptr->SetMemoryBytes(3);

    TVisitor visitor{visited};
    Repo_->Visit(visitor);

    ASSERT_EQ(visited.size(), 3u);
    // Each shard must be visited exactly once and must report the memory value set above.
    ASSERT_THAT(visited, testing::UnorderedElementsAre(
            TShardInfo{"project1", "shard1", 1},
            TShardInfo{"project1", "shard2", 2},
            TShardInfo{"project2", "shard3", 3}
    ));
}
