#include <solomon/services/slicer/lib/balancer/slice_load_balancer.h>
#include <solomon/services/slicer/lib/db/service_config.h>

#include <library/cpp/json/writer/json.h>
#include <library/cpp/json/json_value.h>
#include <library/cpp/testing/gtest/gtest.h>

using namespace NSolomon;
using namespace NSolomon::NSlicer;
using namespace NSolomon::NSlicer::NApi;
using namespace NSolomon::NSlicer::NDb;
using namespace NSolomon::NSlicer::NBalancer;

class TSliceBalancerTest: public ::testing::Test {
public:
    void SetUp() override {
    }

    void TearDown() override {
    }

protected:
    TServiceConfig ServiceSettings_{"service-unit-test", "cluster-unit-test", "dc-unit-test"};
};

/**
 * ---------------------------------------------------------------------------------------------------------------------
 */

constexpr ui32 MAX_NUM_ID = Max<TNumId>();
constexpr ui64 MAX_CPU_LOAD = 16ull * 1'000'000'000; // 16 CPU (1 CPU == 10^9 nanoseconds)
constexpr ui64 MAX_MEMORY_LOAD = 64ull * 1024 * 1024 * 1024; // 64 GB

// One shard of a test layout. Initialised positionally as {Id, Load} in the config tables below.
struct TTestShard {
    // Test-local shard identifier; TTestDataProvider maps the generated numId back to this value.
    ui32 Id;
    // Load as a fraction of the host budget; TTestDataProvider multiplies it by
    // MAX_CPU_LOAD or MAX_MEMORY_LOAD depending on the requested load type.
    double Load;
};

// A slice of a test layout: the list of shards that live in it.
using TTestSlice = std::vector<TTestShard>;

// One host of a test layout. Initialised positionally as {Host, Slices} in the config tables below.
struct TTestConfigRecord {
    // Host name, used as the key of every per-host map built by TTestDataProvider.
    TString Host;
    // Slices assigned to the host; may be empty.
    TVector<TTestSlice> Slices;
};

// A whole test layout: one record per host.
using TTestConfig = std::vector<TTestConfigRecord>;

class TTestDataProvider {
public:

    enum ELoadType {
        Cpu,
        Memory
    };

    TTestDataProvider(ELoadType loadType, const TTestConfig& testConfig)
        : LoadType_(loadType)
    {
        const size_t slicesCount = GetSlicesCount(testConfig);
        TVector<TSlice> allSlices = DistributeSlicesEvenly(slicesCount);
        BuildHostSlices(testConfig, allSlices);
        BuildHostSlicesWithShards();
    }

    const TStringMap<TLoadInfo>& HostsInfo() const {
        return HostsInfo_;
    }

    const TStringMap<NApi::TSlices>& HostSlices() const {
        return HostSlices_;
    }

    const TStringMap<TSlicesWithShards>& HostSlicesWithShards() const {
        return HostSlicesWithShards_;
    }

    const absl::flat_hash_map<NApi::TNumId, TLoadInfo>& ShardsInfo() const {
        return ShardsInfo_;
    }

    const absl::flat_hash_map<NApi::TNumId, ui32>& NumidToInternalId() const {
        return NumIdToInternalId_;
    }

private:
    static size_t GetSlicesCount(const TTestConfig& testConfig) {
        size_t slicesCount = 0;
        for (const auto& rec: testConfig) {
            slicesCount += rec.Slices.size();
        }
        return slicesCount;
    }

    static TVector<TSlice> DistributeSlicesEvenly(size_t slicesCount) {
        TVector<TSlice> allSlices;
        if (slicesCount == 0) {
            return allSlices;
        }

        const ui32 sliceSize = MAX_NUM_ID / slicesCount;
        allSlices.reserve(slicesCount);
        for (size_t i = 0; i < slicesCount; i++) {
            allSlices.emplace_back(i * sliceSize, (i + 1) * sliceSize - 1);
        }
        allSlices.back().End = MAX_NUM_ID;
        return allSlices;
    }

    ui64 DistributeShardsEvenly(const TSlice& slice, const TTestSlice& shardLoads) {
        if (shardLoads.empty()) {
            return 0;
        }

        Y_VERIFY(slice.Size() >= shardLoads.size());
        const ui32 shardsSpace = slice.Size() / shardLoads.size();
        ui64 sliceLoad = 0;
        for (size_t i = 0; i < shardLoads.size(); i++) {
            ui32 offset = i * shardsSpace;
            TNumId numId = slice.Start + offset;
            Y_VERIFY(numId <= slice.End);
            Y_VERIFY(!NumIdToInternalId_.contains(numId));
            NumIdToInternalId_[numId] = shardLoads[i].Id;
            auto& loadInfo = ShardsInfo_[numId];
            switch (LoadType_) {
                case Cpu:
                    loadInfo.CpuTimeNanos = static_cast<ui64>(shardLoads[i].Load * MAX_CPU_LOAD);
                    sliceLoad += loadInfo.CpuTimeNanos;
                    break;
                case Memory:
                    loadInfo.MemoryBytes = static_cast<ui64>(shardLoads[i].Load * MAX_MEMORY_LOAD);
                    sliceLoad += loadInfo.MemoryBytes;
                    break;
            }
        }
        return sliceLoad;
    }

    void BuildHostSlices(const TTestConfig& testConfig, const TVector<TSlice>& allSlices) {
        size_t nSlice = 0;
        for (const auto& rec: testConfig) {
            TSlices slices;
            ui64 hostLoad = 0;
            for (const auto& slice: rec.Slices) {
                hostLoad += DistributeShardsEvenly(allSlices[nSlice], slice);
                Y_VERIFY(!slices.contains(allSlices[nSlice]));
                slices.insert(allSlices[nSlice]);
                nSlice++;
            }
            SetHostInfo(rec.Host);
            SetHostSlices(rec.Host, std::move(slices));
            CheckHostLoad(rec.Host, hostLoad);
        }
    }

    void BuildHostSlicesWithShards() {
        Y_VERIFY(SortedNumIds_.empty());
        Y_VERIFY(HostSlicesWithShards_.empty());

        SortedNumIds_.reserve(ShardsInfo_.size());
        for (const auto& [numId, _]: ShardsInfo_) {
            SortedNumIds_.push_back(numId);
        }
        Sort(SortedNumIds_);

        for (const auto& [host, slices]: HostSlices()) {
            auto& slicesWithShards = HostSlicesWithShards_[host];
            for (const auto& slice: slices) {
                auto& sliceShards = slicesWithShards[slice];
                const auto numIds = GetSliceShards(slice);
                for (auto numId: numIds) {
                    sliceShards[numId] = ShardsInfo_[numId];
                }
            }
        }
    }

    void SetHostInfo(const TString& host) {
        Y_VERIFY(!HostsInfo_.contains(host));
        auto& loadInfo = HostsInfo_[host];
        // it is a host budget
        loadInfo.CpuTimeNanos = MAX_CPU_LOAD;
        loadInfo.MemoryBytes = MAX_MEMORY_LOAD;
    }

    void SetHostSlices(const TString& host, TSlices hostSlices) {
        Y_VERIFY(!HostSlices_.contains(host));
        HostSlices_[host] = std::move(hostSlices);
    }

    TVector<NApi::TNumId> GetSliceShards(const NApi::TSlice& slice) const {
        TVector<NApi::TNumId> sliceNumIds;
        auto nIt = std::lower_bound(SortedNumIds_.begin(), SortedNumIds_.end(), slice.Start);
        while (nIt != SortedNumIds_.end() && *nIt <= slice.End) {
            sliceNumIds.emplace_back(*nIt);
            ++nIt;
        }
        return sliceNumIds;
    }

    void CheckHostLoad(const TString& host, ui64 hostLoad) {
        const auto& hostBudget = HostsInfo_[host];
        switch (LoadType_) {
            case Cpu:
                Y_ENSURE(
                        hostLoad <= hostBudget.CpuTimeNanos,
                        TStringBuilder{} << "host " << host << " CPU load " << hostLoad << " is over the budget " << hostBudget.CpuTimeNanos);
                break;
            case Memory:
                Y_ENSURE(
                        hostLoad <= hostBudget.MemoryBytes,
                        TStringBuilder{} << "host " << host << " memory load " << hostLoad << " is over the budget " << hostBudget.MemoryBytes);
                break;
        }
    }

private:
    const ELoadType LoadType_;
    TStringMap<TLoadInfo> HostsInfo_;
    TStringMap<TSlices> HostSlices_;
    TStringMap<TSlicesWithShards> HostSlicesWithShards_;
    absl::flat_hash_map<TNumId, TLoadInfo> ShardsInfo_;
    TVector<TNumId> SortedNumIds_;
    absl::flat_hash_map<TNumId, ui32> NumIdToInternalId_;
};

/**
 * Checks whether two balancing layouts are equivalent.
 *
 * Both layouts must contain the same set of hosts with the same number of slices per host.
 * Slices of a host are compared pairwise in container iteration order, and so are the shards
 * inside each pair of slices. Two shards match when their numIds translate to the same internal
 * id (through numIds1 and numIds2 respectively) and their load infos are equal.
 *
 * @param hostToSlices1      first layout: host -> slices with shards
 * @param numIds1            numId -> internal shard id for the first layout
 * @param hostToSlices2      second layout
 * @param numIds2            numId -> internal shard id for the second layout
 * @param ignoreSliceBounds  when false, paired slices must additionally have equal keys
 * @return true when no difference is found
 *
 * Internal ids are looked up with at(), so every shard numId of a layout is expected
 * to be present in the corresponding map.
 */
bool IsEqualBalancing(
        const TStringMap<TSlicesWithShards>& hostToSlices1,
        const absl::flat_hash_map<TNumId, ui32>& numIds1,
        const TStringMap<TSlicesWithShards>& hostToSlices2,
        const absl::flat_hash_map<TNumId, ui32>& numIds2,
        bool ignoreSliceBounds = true)
{
    if (hostToSlices1.size() != hostToSlices2.size()) {
        return false;
    }

    for (const auto& [host, slices1]: hostToSlices1) {
        auto hostIt2 = hostToSlices2.find(host);
        if (hostIt2 == hostToSlices2.end()) {
            return false;
        }

        const auto& slices2 = hostIt2->second;
        if (slices1.size() != slices2.size()) {
            return false;
        }

        auto sliceIt1 = slices1.begin();
        auto sliceIt2 = slices2.begin();
        while (sliceIt1 != slices1.end() && sliceIt2 != slices2.end()) {
            if (!ignoreSliceBounds && sliceIt1->first != sliceIt2->first) {
                return false;
            }
            const auto& shards1 = sliceIt1->second;
            const auto& shards2 = sliceIt2->second;
            if (shards1.size() != shards2.size()) {
                return false;
            }
            auto shardIt1 = shards1.begin();
            auto shardIt2 = shards2.begin();
            while (shardIt1 != shards1.end() && shardIt2 != shards2.end()) {
                // numIds differ between layouts, so shards are identified by their internal ids
                if (numIds1.at(shardIt1->first) != numIds2.at(shardIt2->first)) {
                    return false;
                }
                // the load info must be compared across the two layouts
                // (it used to be compared with itself, which could never fail)
                if (shardIt1->second != shardIt2->second) {
                    return false;
                }
                ++shardIt1, ++shardIt2;
            }
            ++sliceIt1, ++sliceIt2;
        }
    }
    return true;
}

using namespace NJson;

/**
 * Renders a host -> slices -> shards layout as indented JSON for assertion messages.
 *
 * The output is a list of host objects ordered by host name. Each host object has "Host" and
 * "Slices"; each slice has "Start", "End" and "Shards"; each shard has "NumId", "CpuTimeNanos"
 * and "MemoryBytes".
 */
TString ToJson(const TStringMap<TSlicesWithShards>& hostSlices) {
    // sort host names so that the output does not depend on the map iteration order
    TVector<TStringBuf> hosts;
    hosts.reserve(hostSlices.size());
    for (const auto& entry: hostSlices) {
        hosts.emplace_back(entry.first);
    }
    Sort(hosts);

    NJsonWriter::TBuf json(NJsonWriter::HEM_DONT_ESCAPE_HTML);
    json.SetIndentSpaces(2);

    auto topList = json.BeginList();
    for (const auto& hostName: hosts) {
        auto hostNode = topList.BeginObject();
        hostNode.WriteKey("Host").WriteString(hostName);

        auto sliceList = hostNode.WriteKey("Slices").BeginList();
        for (const auto& sliceEntry: hostSlices.at(hostName)) {
            const auto& bounds = sliceEntry.first;

            auto sliceNode = sliceList.BeginObject();
            sliceNode.WriteKey("Start").WriteULongLong(bounds.Start);
            sliceNode.WriteKey("End").WriteULongLong(bounds.End);

            auto shardList = sliceNode.WriteKey("Shards").BeginList();
            for (const auto& shardEntry: sliceEntry.second) {
                const auto& load = shardEntry.second;

                auto shardNode = shardList.BeginObject();
                shardNode.WriteKey("NumId").WriteULongLong(shardEntry.first);
                shardNode.WriteKey("CpuTimeNanos").WriteULongLong(load.CpuTimeNanos);
                shardNode.WriteKey("MemoryBytes").WriteULongLong(load.MemoryBytes);
                shardNode.EndObject();
            }
            shardList.EndList();

            sliceNode.EndObject();
        }
        sliceList.EndList();

        hostNode.EndObject();
    }
    topList.EndList();

    return json.Str();
}

// Layout for the smoke tests: two hosts, two slices per host, two shards per slice.
// Record format: {host, {slice, ...}} where a slice is a list of {shard id, load};
// the load is a fraction of the host budget (see TTestDataProvider).
const TTestConfig TwoAliveHosts({
        //  Host                Slice1                       Slice2
        {"alive_1",     {  {  {0, 0.01},  {1, 0.02}  },  {  {2, 0.05},  {3, 0.06}  }  }  },
        {"alive_2",     {  {  {4, 0.03},  {5, 0.04}  },  {  {6, 0.07},  {7, 0.08}  }  }  }
});

// Smoke test: runs CPU balancing over the two-host layout; there are no assertions on the result.
TEST_F(TSliceBalancerTest, CpuBalance) {
    const TTestDataProvider input{TTestDataProvider::Cpu, TwoAliveHosts};

    TSliceLoadBalancer balancer(
            ServiceSettings_,
            input.HostsInfo(),
            input.HostSlices(),
            input.ShardsInfo());
    balancer.BalanceByCpu();
}

// Smoke test: runs memory balancing over the two-host layout; there are no assertions on the result.
TEST_F(TSliceBalancerTest, MemoryBalance) {
    const TTestDataProvider input{TTestDataProvider::Memory, TwoAliveHosts};

    TSliceLoadBalancer balancer(
            ServiceSettings_,
            input.HostsInfo(),
            input.HostSlices(),
            input.ShardsInfo());
    balancer.BalanceByMemory();
}

// Input layout of the merge tests: alive_1 holds two single-shard slices with load 0.01 each,
// every other host holds one slice with two shards.
// Record format: {host, {slice, ...}} where a slice is a list of {shard id, load}.
const TTestConfig MergeTestSource({
        //  Host            Slice1                      Slice2
        {"alive_1",     {   { {0, 0.01}          },  { {1, 0.01} }  }  },
        {"alive_2",     {   { {2, 0.1}, {3, 0.2} }                  }  },
        {"alive_3",     {   { {4, 0.3}, {5, 0.4} }                  }  },
        {"alive_4",     {   { {6, 0.5}, {7, 0.4} }                  }  }
});

// Layout expected by the merge tests for MergeTestSource:
// alive_1 is left without slices and its shards 0 and 1 share a single slice with shards 2 and 3 on alive_2.
const TTestConfig MergeTestResult({
        // slices from host alive_1 merged into slice at host alive_2
        //  Host            Slice1
        {"alive_1",     {                                                 }  },
        {"alive_2",     {   { {0, 0.01}, {1, 0.01}, {2, 0.1}, {3, 0.2} }  }  },
        {"alive_3",     {   { {4, 0.3},   {5, 0.4}  }  }  },
        {"alive_4",     {   { {6, 0.5},   {7, 0.4}  }  }  }
});

// Merging by CPU must turn MergeTestSource into MergeTestResult
// and report TooFewSlices with 2 merged slices after 3 iterations.
TEST_F(TSliceBalancerTest, MergeCpuSlices) {
    // arrange: source layout, reference layout and settings with the merge threshold set to 0
    TTestDataProvider testData(TTestDataProvider::Cpu, MergeTestSource);
    TTestDataProvider sample(TTestDataProvider::Cpu, MergeTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    auto mergeSettings = ServiceSettings_;
    mergeSettings.MergeWhenMoreThanNumSlicesPerTask = 0;

    // act
    TSliceLoadBalancer balancer(mergeSettings, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.MergeSlices(EReassignmentType::ByCpu);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";

    ASSERT_EQ(balancer.GetStatistics().MergeStatus, EMergeStatus::TooFewSlices);
    ASSERT_EQ(balancer.GetStatistics().NumOfMergedSlices, 2ul);
    ASSERT_EQ(balancer.GetStatistics().MergeIterations, 3ul);
}

// Merging by memory must turn MergeTestSource into MergeTestResult
// and report TooFewSlices with 2 merged slices after 3 iterations.
TEST_F(TSliceBalancerTest, MergeMemorySlices) {
    // arrange: source layout, reference layout and settings with the merge threshold set to 0
    TTestDataProvider testData(TTestDataProvider::Memory, MergeTestSource);
    TTestDataProvider sample(TTestDataProvider::Memory, MergeTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    auto mergeSettings = ServiceSettings_;
    mergeSettings.MergeWhenMoreThanNumSlicesPerTask = 0;

    // act
    TSliceLoadBalancer balancer(mergeSettings, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.MergeSlices(EReassignmentType::ByMemory);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";

    ASSERT_EQ(balancer.GetStatistics().MergeStatus, EMergeStatus::TooFewSlices);
    ASSERT_EQ(balancer.GetStatistics().NumOfMergedSlices, 2ul);
    ASSERT_EQ(balancer.GetStatistics().MergeIterations, 3ul);
}

// Input layout of the move tests: alive_1..alive_3 hold one lightly loaded slice each,
// alive_4 holds two slices with a summed load of 0.6 of the host budget.
// Record format: {host, {slice, ...}} where a slice is a list of {shard id, load}.
const TTestConfig MoveTestSource({
        //  Host            Slice1                      Slice2
        {"alive_1",     {   { {0, 0.01},  {1, 0.01} }                           }   },
        {"alive_2",     {   { {2, 0.02},  {3, 0.02} }                           }   },
        {"alive_3",     {   { {4, 0.03},  {5, 0.03} },                          }   },
        {"alive_4",     {   { {6, 0.1},   {7, 0.1}  },  { {8, 0.2},  {9, 0.2} } }   }
});

// Layout expected by the move tests for MoveTestSource:
// the slice with shards 6 and 7 is moved from alive_4 to alive_1, the other hosts are unchanged.
const TTestConfig MoveTestResult({
        //  Host            Slice1                      Slice2
        {"alive_1",     {   { {0, 0.01},  {1, 0.01} }, { {6, 0.1},   {7, 0.1} } }   },
        {"alive_2",     {   { {2, 0.02},  {3, 0.02} }                           }   },
        {"alive_3",     {   { {4, 0.03},  {5, 0.03} },                          }   },
        {"alive_4",     {   { {8, 0.2},   {9, 0.2}  }                           }   }
});

// Moving slices by CPU must turn MoveTestSource into MoveTestResult.
TEST_F(TSliceBalancerTest, MoveCpuSlices) {
    // arrange: source layout and reference layout
    TTestDataProvider testData(TTestDataProvider::Cpu, MoveTestSource);
    TTestDataProvider sample(TTestDataProvider::Cpu, MoveTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    // act
    TSliceLoadBalancer balancer(ServiceSettings_, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.MoveSlices(EReassignmentType::ByCpu);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";
}

// Moving slices by memory must turn MoveTestSource into MoveTestResult.
TEST_F(TSliceBalancerTest, MoveMemorySlices) {
    // arrange: source layout and reference layout
    TTestDataProvider testData(TTestDataProvider::Memory, MoveTestSource);
    TTestDataProvider sample(TTestDataProvider::Memory, MoveTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    // act
    TSliceLoadBalancer balancer(ServiceSettings_, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.MoveSlices(EReassignmentType::ByMemory);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";
}

// Input layout of the split tests: two hosts with two slices each; the second slice of alive_1
// (shards 2 and 3, loads 0.04 and 0.06) carries the largest load of all slices.
// Record format: {host, {slice, ...}} where a slice is a list of {shard id, load}.
const TTestConfig SplitTestSource({
    //  Host            Slice1                    Slice2
    {"alive_1",     {  { {0, 0.01}, {1, 0.01} },  { {2, 0.04}, {3, 0.06} }  }  },
    {"alive_2",     {  { {4, 0.01}, {5, 0.01} },  { {6, 0.01}, {7, 0.01} }  }  }
});

// Layout expected by the split tests for SplitTestSource:
// shards 2 and 3 of alive_1 end up in separate slices, alive_2 is unchanged.
const TTestConfig SplitTestResult({
    // slice 2 of host alive_1 is split
    //  Host            Slice1                     Slice2                    Slice3
    {"alive_1",     {  { {0, 0.01}, {1, 0.01} },  { {2, 0.04}             }, { {3, 0.06} }  }  },
    {"alive_2",     {  { {4, 0.01}, {5, 0.01} },  { {6, 0.01},  {7, 0.01} }                 }  }
});

// Splitting hot slices by CPU must turn SplitTestSource into SplitTestResult.
TEST_F(TSliceBalancerTest, SplitCpuHotSlices) {
    // arrange: source layout and reference layout
    TTestDataProvider testData(TTestDataProvider::Cpu, SplitTestSource);
    TTestDataProvider sample(TTestDataProvider::Cpu, SplitTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    // act
    TSliceLoadBalancer balancer(ServiceSettings_, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.SplitHotSlices(EReassignmentType::ByCpu);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";
}

// Splitting hot slices by memory must turn SplitTestSource into SplitTestResult.
TEST_F(TSliceBalancerTest, SplitMemoryHotSlices) {
    // arrange: source layout and reference layout
    TTestDataProvider testData(TTestDataProvider::Memory, SplitTestSource);
    TTestDataProvider sample(TTestDataProvider::Memory, SplitTestResult);
    const auto& expected = sample.HostSlicesWithShards();

    // act
    TSliceLoadBalancer balancer(ServiceSettings_, testData.HostsInfo(), testData.HostSlices(), testData.ShardsInfo());
    balancer.SplitHotSlices(EReassignmentType::ByMemory);
    const auto hostSlices = balancer.BuildHostToSlicesWithShardsMapping();

    // assert: layouts are compared by shard ids and loads, slice bounds are not significant
    constexpr bool ignoreSliceBounds = true;
    ASSERT_TRUE(IsEqualBalancing(hostSlices, testData.NumidToInternalId(), expected, sample.NumidToInternalId(), ignoreSliceBounds))
        << "Balancing result differs from expected.\n"
        << "Result:\n" << ToJson(hostSlices)
        << "\nExpected:\n" << ToJson(expected);
    ASSERT_TRUE(balancer.CheckSliceLoadValues()) << "slice load values contain wrong values";
}
