#include "slice_load_balancer.h"

#include <solomon/libs/cpp/clients/slicer/api_types.h>
#include <solomon/libs/cpp/logging/logging.h>

#include <library/cpp/json/writer/json.h>

// In this file invariant checks throw (Y_ENSURE) instead of aborting the process (Y_VERIFY).
#undef Y_VERIFY
#define Y_VERIFY Y_ENSURE

// Logging helpers; they do nothing when the balancer has no actor context.
// The do { ... } while (false) wrapper makes every expansion a single statement, so the macros compose
// safely with an unbraced if/else at the call site (a bare `if` would capture a following `else`).
#define BALANCER_TRACE(...) \
    do { if (ActorContext_) { MON_TRACE_C(*ActorContext_, Balancer, __VA_ARGS__); } } while (false)
#define BALANCER_WARN(...) \
    do { if (ActorContext_) { MON_WARN_C(*ActorContext_, Balancer, __VA_ARGS__); } } while (false)

namespace NSolomon::NSlicer {

using namespace NApi;

/**
 * ---------------------------------------------------------------------------------------------------------------------
 * TLoadInfo extensions
 */

/// Component-wise sum of the CPU, memory and network parts of two load values.
TLoadInfo operator+(const TLoadInfo& lhs, const TLoadInfo& rhs) {
    return TLoadInfo{
        .CpuTimeNanos = lhs.CpuTimeNanos + rhs.CpuTimeNanos,
        .MemoryBytes = lhs.MemoryBytes + rhs.MemoryBytes,
        .NetworkBytes = lhs.NetworkBytes + rhs.NetworkBytes,
    };
}

/// Component-wise difference of two load values.
/// Every component of rhs must not exceed the matching component of lhs (checked with Y_VERIFY before subtracting).
TLoadInfo operator-(const TLoadInfo& lhs, const TLoadInfo& rhs) {
    // refuse to produce a negative (or wrapped-around) component
    Y_VERIFY(lhs.CpuTimeNanos >= rhs.CpuTimeNanos);
    Y_VERIFY(lhs.MemoryBytes >= rhs.MemoryBytes);
    Y_VERIFY(lhs.NetworkBytes >= rhs.NetworkBytes);

    return TLoadInfo{
        .CpuTimeNanos = lhs.CpuTimeNanos - rhs.CpuTimeNanos,
        .MemoryBytes = lhs.MemoryBytes - rhs.MemoryBytes,
        .NetworkBytes = lhs.NetworkBytes - rhs.NetworkBytes,
    };
}

/// Adds rhs to lhs component-wise and returns a copy of the updated lhs.
/// NOTE(review): returns by value rather than TLoadInfo&; presumably the header declares it this way, so a change
/// would have to touch the declaration as well.
TLoadInfo operator+=(TLoadInfo& lhs, const TLoadInfo& rhs) {
    lhs.CpuTimeNanos += rhs.CpuTimeNanos;
    lhs.MemoryBytes += rhs.MemoryBytes;
    lhs.NetworkBytes += rhs.NetworkBytes;
    return lhs;
}

/// Subtracts rhs from lhs component-wise and returns a copy of the updated lhs.
/// Every component of rhs must not exceed the matching component of lhs (checked with Y_VERIFY before modifying lhs).
TLoadInfo operator-=(TLoadInfo& lhs, const TLoadInfo& rhs) {
    // validate all components first, so lhs is left untouched when a check fails
    Y_VERIFY(lhs.CpuTimeNanos >= rhs.CpuTimeNanos);
    Y_VERIFY(lhs.MemoryBytes >= rhs.MemoryBytes);
    Y_VERIFY(lhs.NetworkBytes >= rhs.NetworkBytes);

    lhs.CpuTimeNanos -= rhs.CpuTimeNanos;
    lhs.MemoryBytes -= rhs.MemoryBytes;
    lhs.NetworkBytes -= rhs.NetworkBytes;
    return lhs;
}

} // namespace NSolomon::NSlicer

/**
 * ---------------------------------------------------------------------------------------------------------------------
 */

namespace NSolomon::NSlicer::NBalancer {

// Denominator of the key churn ratios computed by the merge and move phases (keyChurn / KEY_SPACE_SIZE).
// NOTE(review): equals Max<TNumId>(), i.e. one less than the number of distinct TNumId values; negligible for a ratio.
constexpr ui32 KEY_SPACE_SIZE = Max<TNumId>();
// Relative tolerance used by AlmostEqualLoad() when comparing normalized node loads.
constexpr double LOAD_GRANULARITY = 0.0001;

// Bookkeeping for a slice relocated by MoveSlices(): remembers where it was first taken from,
// so that moving it back to that node can be recognized.
struct TMoveInfo {
    // Node the slice was taken from on its first move; empty while the slice has not been moved yet.
    // NOTE(review): non-owning view; it must not outlive the string it was assigned from.
    TStringBuf FromNode;
};

/**
 * Serialize a host -> slices mapping as a JSON list of {"Host": ..., "Slices": [{"Start": ..., "End": ...}, ...]}.
 * Hosts are emitted in lexicographic order; `indented` enables 2-space pretty printing.
 */
TString ToJson(const TStringMap<TSlices>& hostToSlices, bool indented) {
    // the map's own iteration order is not used: hosts are sorted to make the output stable
    TVector<TStringBuf> sortedHosts(::Reserve(hostToSlices.size()));
    for (const auto& entry: hostToSlices) {
        sortedHosts.push_back(entry.first);
    }
    Sort(sortedHosts);

    NJsonWriter::TBuf out(NJsonWriter::HEM_DONT_ESCAPE_HTML);
    if (indented) {
        out.SetIndentSpaces(2);
    }

    auto hostList = out.BeginList();
    for (TStringBuf hostName: sortedHosts) {
        auto hostEntry = hostList.BeginObject();
        hostEntry.WriteKey("Host").WriteString(hostName);

        auto sliceList = hostEntry.WriteKey("Slices").BeginList();
        for (const auto& slice: hostToSlices.at(hostName)) {
            sliceList.BeginObject()
                    .WriteKey("Start").WriteULongLong(slice.Start)
                    .WriteKey("End").WriteULongLong(slice.End)
                    .EndObject();
        }
        sliceList.EndList();
        hostEntry.EndObject();
    }
    hostList.EndList();

    return out.Str();
}

/// True if both mappings contain the same hosts, each with an equal set of slices.
bool AreEqual(const TStringMap<TSlices>& hostToSlices1, const TStringMap<TSlices>& hostToSlices2) {
    if (hostToSlices1.size() != hostToSlices2.size()) {
        return false;
    }

    // With equal sizes and unique keys, "every entry of the first map is found unchanged in the second one"
    // is enough for equality.
    for (const auto& entry: hostToSlices1) {
        const auto other = hostToSlices2.find(entry.first);
        if (other == hostToSlices2.end() || other->second != entry.second) {
            return false;
        }
    }
    return true;
}

/**
 * True if both mappings contain the same hosts with pairwise equal slices (compared in iteration order).
 * When ignoreSliceBounds is set, only the per-slice shard contents are compared, not the slice bounds.
 * NOTE(review): the positional comparison is meaningful only if TSlicesWithShards iterates in a deterministic
 * order (e.g. is ordered by slice) — confirm.
 */
bool AreEqual(
        const TStringMap<TSlicesWithShards>& hostToSlices1,
        const TStringMap<TSlicesWithShards>& hostToSlices2,
        bool ignoreSliceBounds)
{
    if (hostToSlices1.size() != hostToSlices2.size()) {
        return false;
    }

    for (const auto& [host, lhsSlices]: hostToSlices1) {
        const auto rhsIt = hostToSlices2.find(host);
        if (rhsIt == hostToSlices2.end() || lhsSlices.size() != rhsIt->second.size()) {
            return false;
        }

        // sizes are equal, so bounding the walk by the first container bounds both iterators
        auto rhs = rhsIt->second.begin();
        for (auto lhs = lhsSlices.begin(); lhs != lhsSlices.end(); ++lhs, ++rhs) {
            const bool boundsDiffer = !ignoreSliceBounds && lhs->first != rhs->first;
            if (boundsDiffer || lhs->second != rhs->second) {
                return false;
            }
        }
    }
    return true;
}

/**
 * ---------------------------------------------------------------------------------------------------------------------
 */

namespace {

/**
 * for logging
 */
template <EReassignmentType RType>
TStringBuf GetLoadName() {
    // human-readable name of the load kind, used only in log messages
    switch (RType) {
        case EReassignmentType::ByCpu:
            return "CPU";
        case EReassignmentType::ByMemory:
            return "memory";
        case EReassignmentType::ByCount:
            return "by count";
        default:
            return "unknown";
    }
}

/**
 * True if |v1 - v2| / max(|v1|, |v2|) is strictly below maxRelativeError.
 * Two exact zeros are considered equal (this also avoids dividing by zero).
 */
bool AlmostEqual(double v1, double v2, double maxRelativeError) {
    const double magnitude = std::max(std::abs(v1), std::abs(v2));
    return magnitude == 0 || std::abs(v1 - v2) / magnitude < maxRelativeError;
}

/// Compares two (normalized) loads with the relative tolerance LOAD_GRANULARITY.
bool AlmostEqualLoad(double load1, double load2) {
    const double tolerance = LOAD_GRANULARITY;
    return AlmostEqual(load1, load2, tolerance);
}

} // namespace

/** --------------------------------------------------------------------------------------------------------------------
 * TSliceLoadBalancer
 */

/**
 * Capture the balancing input and precompute shard order, total load and total budget,
 * then attach the current load to every slice of every host.
 */
TSliceLoadBalancer::TSliceLoadBalancer(
        const NDb::TServiceConfig& serviceSettings,
        const TStringMap<TLoadInfo>& hostsInfo,
        const TStringMap<NApi::TSlices>& hostSlices,
        const absl::flat_hash_map<NApi::TNumId, TLoadInfo>& shardsInfo,
        const NActors::TActorContext* actorContext,
        const TString& logPrefix)
    : ActorContext_(actorContext)
    , LogPrefix_(logPrefix)
    , ServiceSettings_(serviceSettings)
    , HostsInfo_(hostsInfo)
    , ShardsInfo_(shardsInfo)
    // the call result is a prvalue and initializes the member directly; std::move() would only prevent elision
    , SortedNumIds_(BuildSortedNumIds(shardsInfo))
    , TotalLoad_(CalcTotalLoad(shardsInfo))
    , TotalBudget_(CalcTotalBudget(hostsInfo))
{
    // slice loads are derived from SortedNumIds_ and ShardsInfo_, so this runs after all members are initialized
    BuildHostToSlices(hostSlices);
}

void TSliceLoadBalancer::BuildHostToSlices(const TStringMap<TSlices>& hostSlices)
{
    HostToSlices_.clear();
    auto slicesLoad = CalcSlicesLoad(hostSlices);

    for (const auto& [host, slices]: hostSlices) {
        auto& s = HostToSlices_[host];
        for (const auto& slice: slices) {
            auto load = slicesLoad[slice];
            s.emplace(slice, load);
        }
    }
}

/**
 * Compute the load of every slice in hostSlices by summing the loads (ShardsInfo_) of the numIds it contains.
 *
 * The slices of all hosts are collected and sorted, then walked in lockstep with SortedNumIds_ (a merge of two
 * sorted sequences). A numId smaller than the current slice's Start is skipped without being attributed to any
 * slice. After the walk every numId must have been consumed; numIds left over after the last slice make the
 * final Y_VERIFY fail.
 * NOTE(review): numIds falling into a gap between slices are skipped silently, while numIds past the last slice
 * make the check fail — confirm this asymmetry is intended.
 *
 * The result is keyed by slice, so a slice listed under several hosts is accounted once. Slices that contain no
 * numId get no entry at all (BuildHostToSlices() reads them through operator[], i.e. as an empty load).
 * NOTE(review): assumes slices do not partially overlap each other — TODO confirm.
 */
absl::flat_hash_map<TSlice, TLoadInfo, THash<TSlice>> TSliceLoadBalancer::CalcSlicesLoad(
        const TStringMap<NApi::TSlices>& hostSlices) const
{
    // total number of slices over all hosts, used to size the containers below
    size_t sliceCount = 0;
    for (const auto& [_, slices]: hostSlices) {
        sliceCount += slices.size();
    }

    TVector<TSlice> allSlices;
    allSlices.reserve(sliceCount);
    for (const auto& [_, slices]: hostSlices) {
        for (const auto& slice: slices) {
            allSlices.push_back(slice);
        }
    }
    Sort(allSlices);

    auto curNumId = SortedNumIds_.begin();
    auto curSlice = allSlices.begin();
    absl::flat_hash_map<TSlice, TLoadInfo, THash<TSlice>> slicesLoadInfo;
    slicesLoadInfo.reserve(sliceCount);
    while (curNumId != SortedNumIds_.end() && curSlice != allSlices.end()) {
        // the numId lies before the current slice: no slice covers it, skip the numId
        if (*curNumId < curSlice->Start) {
            ++curNumId;
            continue;
        }
        // the current slice ends before the numId: advance to the next slice
        if (curSlice->End < *curNumId) {
            ++curSlice;
            continue;
        }

        Y_VERIFY_DEBUG(curSlice->Start <= *curNumId && *curNumId <= curSlice->End);

        // the numId belongs to the current slice: add its load and move to the next numId
        const auto& shardLoad = ShardsInfo_.at(*curNumId);
        auto& sliceLoad = slicesLoadInfo[*curSlice];
        sliceLoad += shardLoad;
        ++curNumId;
    }
    // every known numId must have been attributed (or skipped) before the slices ran out
    Y_VERIFY(curNumId == SortedNumIds_.end());

    return slicesLoadInfo;
}

/// All numIds present in shardsInfo, in ascending order.
TVector<TNumId> TSliceLoadBalancer::BuildSortedNumIds(const absl::flat_hash_map<TNumId, TLoadInfo>& shardsInfo) {
    TVector<TNumId> numIds(::Reserve(shardsInfo.size()));
    for (const auto& entry: shardsInfo) {
        numIds.emplace_back(entry.first);
    }
    Sort(numIds);
    return numIds;
}

/// Sum of the loads of all known shards.
TLoadInfo TSliceLoadBalancer::CalcTotalLoad(const absl::flat_hash_map<NApi::TNumId, TLoadInfo>& shardsInfo) {
    return std::accumulate(shardsInfo.begin(), shardsInfo.end(), TLoadInfo{}, [](TLoadInfo sum, const auto& entry) {
        sum += entry.second;
        return sum;
    });
}

/// Sum of the load budgets of all hosts.
TLoadInfo TSliceLoadBalancer::CalcTotalBudget(const TStringMap<TLoadInfo>& hostsInfo) {
    return std::accumulate(hostsInfo.begin(), hostsInfo.end(), TLoadInfo{}, [](TLoadInfo sum, const auto& entry) {
        sum += entry.second;
        return sum;
    });
}

template <EReassignmentType RType>
ui64 TSliceLoadBalancer::GetLoad(const TLoadInfo& loadInfo) {
    // Pick the component balanced by RType; kinds without a dedicated component (e.g. ByCount) yield 0.
    switch (RType) {
        case EReassignmentType::ByCpu:
            return loadInfo.CpuTimeNanos;
        case EReassignmentType::ByMemory:
            return loadInfo.MemoryBytes;
        default:
            return 0;
    }
}

/// Invert HostToSlices_: for every slice, the list of hosts that hold it.
TAssignments TSliceLoadBalancer::BuildAssignments() const {
    TAssignments result;
    for (const auto& [host, slices]: HostToSlices_) {
        for (const auto& entry: slices) {
            auto& owners = result.try_emplace(entry.Slice).first->second;
            owners.emplace_back(host);
        }
    }
    return result;
}

/// Runtime dispatch of the merge phase; balancing by count is not supported (Y_FAIL).
void TSliceLoadBalancer::MergeSlices(EReassignmentType type) {
    switch (type) {
        case EReassignmentType::ByCpu:
            return MergeAdjacentColdSlices<EReassignmentType::ByCpu>();
        case EReassignmentType::ByMemory:
            return MergeAdjacentColdSlices<EReassignmentType::ByMemory>();
        case EReassignmentType::ByCount:
            Y_FAIL("Merge slices is not supported in balancing by count");
    }
}

/// Runtime dispatch of the move phase; balancing by count is not supported (Y_FAIL).
void TSliceLoadBalancer::MoveSlices(EReassignmentType type) {
    switch (type) {
        case EReassignmentType::ByCpu:
            return MoveSlices<EReassignmentType::ByCpu>();
        case EReassignmentType::ByMemory:
            return MoveSlices<EReassignmentType::ByMemory>();
        case EReassignmentType::ByCount:
            Y_FAIL("Move slices is not supported in balancing by count");
    }
}

/// Runtime dispatch of the split phase; balancing by count is not supported (Y_FAIL).
void TSliceLoadBalancer::SplitHotSlices(EReassignmentType type) {
    switch (type) {
        case EReassignmentType::ByCpu:
            return SplitHotSlices<EReassignmentType::ByCpu>();
        case EReassignmentType::ByMemory:
            return SplitHotSlices<EReassignmentType::ByMemory>();
        case EReassignmentType::ByCount:
            Y_FAIL("Split hot slices is not supported in balancing by count");
    }
}

/// Current host -> slices mapping with the per-slice loads stripped; hosts without slices keep an empty entry.
TStringMap<NApi::TSlices> TSliceLoadBalancer::BuildHostToSlicesMapping() const {
    TStringMap<NApi::TSlices> mapping;
    mapping.reserve(HostToSlices_.size());
    for (const auto& [host, slicesWithLoad]: HostToSlices_) {
        auto& plainSlices = mapping[host];
        for (const auto& item: slicesWithLoad) {
            plainSlices.emplace(item.Slice);
        }
    }
    return mapping;
}

/// Current host -> slice -> (numId -> shard load) mapping, listing every known shard inside each slice.
TStringMap<TSlicesWithShards> TSliceLoadBalancer::BuildHostToSlicesWithShardsMapping() const {
    TStringMap<TSlicesWithShards> result;
    for (const auto& [host, slices]: HostToSlices_) {
        auto& perSlice = result[host];
        for (const auto& entry: slices) {
            auto& shardLoads = perSlice[entry.Slice];
            for (auto numId: GetSliceShards(entry.Slice)) {
                shardLoads[numId] = ShardsInfo_.at(numId);
            }
        }
    }
    return result;
}

void TSliceLoadBalancer::BalanceByCpu() {
    DoBalance<EReassignmentType::ByCpu>();
}

void TSliceLoadBalancer::BalanceByMemory() {
    DoBalance<EReassignmentType::ByMemory>();
}

template <EReassignmentType RType>
void TSliceLoadBalancer::DoBalance() {
    MergeAdjacentColdSlices<RType>();
    MoveSlices<RType>();
    SplitHotSlices<RType>();
    DescribeDistribution<RType>();
}

template <EReassignmentType RType>
void TSliceLoadBalancer::MergeAdjacentColdSlices() {
    if (HostToSlices_.size() == 1 && HostToSlices_.begin()->second.size() == 1) {
        Stats_.MergeStatus = EMergeStatus::TooFewSlices;
        return;
    }

    if (CorrectSliceLoadValues()) {
        BALANCER_WARN(LogPrefix_ << "slice load values corrected on MergeAdjacentColdSlices");
    }

    auto slices = BuildSlicesWithLoad();
    auto hostToLoad = BuildHostToLoad();
    const double maxTaskLoad = CalcMaxHostLoad<RType>(hostToLoad);

    ui32 keyChurn = 0;
    double keyChurnRatio = 0;
    const double keyChurnRatioBudget = ServiceSettings_.MergeKeyChurn; // in the paper -- 1%

    double meanSliceLoad = static_cast<double>(GetLoad<RType>(TotalLoad_)) / slices.size();
    auto it = slices.begin();
    decltype(it) next;
    EMergeStatus mergeStatus = EMergeStatus::HappyPath;

    while (true) {
        ++Stats_.MergeIterations;

        // 3.(a) in the paper
        if ((slices.size() / HostToSlices_.size()) <= ServiceSettings_.MergeWhenMoreThanNumSlicesPerTask) {
            mergeStatus = EMergeStatus::TooFewSlices;
            break;
        }

        if (keyChurnRatio >= keyChurnRatioBudget) { // 3.(d) in the paper
            mergeStatus = EMergeStatus::KeyChurnExhausted;
            break;
        }

        next = std::next(it);
        if (it == slices.end() || next == slices.end()) {
            break;
        }

        if ((GetLoad<RType>(it->Load) + GetLoad<RType>(next->Load)) >= meanSliceLoad) { // 3.(b) in the paper
            // XXX(ivanzhukov): ">= meanSliceLoad" is as it is in the paper, but wouldn't it make more sense
            // to check for "> meanSliceLoad", since the desirable state for each slice is to have a load == meanSliceLoad?
            ++it;
            continue;
        }

        double nextBudget = GetLoad<RType>(HostsInfo_.at(next->Host));
        double itBudget = GetLoad<RType>(HostsInfo_.at(it->Host));
        double nextPotentialLoad = (GetLoad<RType>(hostToLoad[next->Host]) + GetLoad<RType>(it->Load)) / nextBudget;
        double itPotentialLoad = (GetLoad<RType>(hostToLoad[it->Host]) + GetLoad<RType>(next->Load)) / itBudget;

        decltype(slices)::iterator srcIt;
        decltype(slices)::iterator dstIt;

        if (it->Host == next->Host || nextPotentialLoad < maxTaskLoad) {
            srcIt = it;
            dstIt = next;
        } else if (itPotentialLoad < maxTaskLoad) {
            srcIt = next;
            dstIt = it;
        } else {
            // 3.(c) in the paper
            ++it;
            continue;
        }

        // performing merge
        ++Stats_.NumOfMergedSlices;
        HostToSlices_[srcIt->Host].erase(srcIt->Slice);
        HostToSlices_[dstIt->Host].erase(dstIt->Slice);

        dstIt->Slice.Start = Min(srcIt->Slice.Start, dstIt->Slice.Start);
        dstIt->Slice.End = Max(srcIt->Slice.End, dstIt->Slice.End);
        dstIt->Load += srcIt->Load;

        HostToSlices_[dstIt->Host].emplace(dstIt->Slice, dstIt->Load);
        hostToLoad[srcIt->Host] -= srcIt->Load;
        hostToLoad[dstIt->Host] += srcIt->Load;

        keyChurn += srcIt->Host == dstIt->Host ? 0 : srcIt->Slice.Size();
        keyChurnRatio = static_cast<double>(keyChurn) / KEY_SPACE_SIZE;

        if (srcIt == it) {
            // [a(it), b, ...] -> [ab(it), ...]
            // [a, b(it), c, ...] -> [a(it), bc, ...]
            if (it == slices.begin()) {
                it = slices.erase(it);
            } else {
                it = std::prev(slices.erase(it));
            }
        } else {
            it = std::prev(slices.erase(next));
        }

        meanSliceLoad = static_cast<double>(GetLoad<RType>(TotalLoad_)) / slices.size();
    }

    Stats_.MergeKeyChurn = keyChurnRatio;
    Stats_.MergeStatus = mergeStatus;

    Y_VERIFY(!slices.empty(), "slices cannot be empty after a merge");
}

/// Flatten HostToSlices_ into a single list of (slice, load, host) entries sorted by slice.
/// NOTE(review): Sort is not stable, so entries with equal slices (if any) end up in unspecified relative order.
TVector<TSliceWithLoadAndHost> TSliceLoadBalancer::BuildSlicesWithLoad() const {
    const size_t totalSlices = std::accumulate(HostToSlices_.begin(), HostToSlices_.end(), size_t{0},
            [](size_t acc, const auto& entry) { return acc + entry.second.size(); });

    TVector<TSliceWithLoadAndHost> result(::Reserve(totalSlices));
    for (const auto& [host, slices]: HostToSlices_) {
        for (const auto& item: slices) {
            result.emplace_back(item.Slice, item.Load, host);
        }
    }

    Sort(result, [](const auto& lhs, const auto& rhs) { return lhs.Slice < rhs.Slice; });
    return result;
}

/// Total load per host, computed from the slice loads. Hosts without slices get no entry.
TStringMap<TLoadInfo> TSliceLoadBalancer::BuildHostToLoad() const {
    TStringMap<TLoadInfo> result;
    result.reserve(HostToSlices_.size());
    for (const auto& [host, slices]: HostToSlices_) {
        if (slices.empty()) {
            continue;
        }
        auto& total = result[host];
        for (const auto& slice: slices) {
            total += slice.Load;
        }
    }
    return result;
}

template <EReassignmentType RType>
double TSliceLoadBalancer::CalcMaxHostLoad(const TStringMap<TLoadInfo>& hostToLoad) const {
    double maxLoad = 0;
    for (const auto& [host, load]: hostToLoad) {
        const i64 maxHostLoad = GetLoad<RType>(HostsInfo_.at(host));
        Y_VERIFY(maxHostLoad != 0);
        maxLoad = Max(maxLoad, static_cast<double>(GetLoad<RType>(load)) / maxHostLoad);
    }
    return maxLoad;
}

template <EReassignmentType RType>
void TSliceLoadBalancer::MoveSlices() {
    ui64 totalLoad = GetLoad<RType>(TotalLoad_);
    ui64 totalBudget = GetLoad<RType>(TotalBudget_);
    if (totalBudget == 0) {
        BALANCER_WARN(LogPrefix_ << "zero total budget of" << GetLoadName<RType>() << " load");
        return;
    }
    double meanNodeLoad = static_cast<double>(totalLoad) / totalBudget;

    if (CorrectSliceLoadValues()) {
        BALANCER_WARN(LogPrefix_ << "slice load values corrected on MoveSlices");
    }

    ui32 keyChurn = 0;
    double keyChurnRatio = 0;
    double keyChurnRatioBudget = ServiceSettings_.MoveKeyChurn; // in the paper -- 9%
    absl::flat_hash_map<decltype(TSlice::Start), TMoveInfo> movedSlices;
    absl::flat_hash_set<TString> filteredNodes;
    TVector<TNodeWithLoad> nodes(::Reserve(HostsInfo_.size()));

    while (keyChurnRatio < keyChurnRatioBudget && filteredNodes.size() < HostsInfo_.size()) {
        ++Stats_.MoveIterations;

        // TODO: do not sort on every iteration. Just sort once and change elements in-place
        FillNodesWithLoads<RType>(filteredNodes, nodes);
        auto [hottestNode, coldestNode] = DetectHottestAndColdestNodes(
                totalLoad,
                totalBudget,
                meanNodeLoad,
                filteredNodes,
                nodes);
        if (hottestNode.Node == coldestNode.Node || AlmostEqualLoad(hottestNode.LoadNorm, coldestNode.LoadNorm)) {
            break;
        }

        std::optional<TMove> bestMove = FindBestMove<RType>(hottestNode, coldestNode, meanNodeLoad);
        if (!bestMove || bestMove->Weight <= 0) {
            // nothing will become better, so stop for now
            break;
        }

        // making the best move
        HostToSlices_[hottestNode.Node].erase(bestMove->Slice);
        HostToSlices_[coldestNode.Node].emplace(bestMove->Slice);

        auto& moveInfo = movedSlices[bestMove->Slice.Slice.Start];
        if (moveInfo.FromNode.empty()) {
            // we move this slice for the first time
            moveInfo.FromNode = hottestNode.Node;

            keyChurn += bestMove->Slice.Slice.Size();
            keyChurnRatio = static_cast<double>(keyChurn) / KEY_SPACE_SIZE;
        } else if (moveInfo.FromNode == coldestNode.Node) {
            // we move this slice right back, so keyChurn is the same as before the move
            keyChurn -= bestMove->Slice.Slice.Size();
            keyChurnRatio = static_cast<double>(keyChurn) / KEY_SPACE_SIZE;

            movedSlices.erase(bestMove->Slice.Slice.Start);
        } // else: we moved this slice again, so key churn is the same as after the first move
    }

    Stats_.MoveKeyChurn = keyChurnRatio;
    Stats_.NumOfMovedSlices = movedSlices.size();
}

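/**
 * Search the hottest node's slices for the best one to move to the coldest node, i.e. the slice with the highest
 * positive weight (the reduction of both nodes' distance from the mean load, divided by the slice size).
 * Returns std::nullopt if no slice has a positive weight.
 */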
template <EReassignmentType RType>
std::optional<TMove> TSliceLoadBalancer::FindBestMove(
        const TNodeWithLoad& hottestNode,
        const TNodeWithLoad& coldestNode,
        double meanNodeLoad) const
{
    std::optional<TMove> bestMove;
    double maxWeight = 0;
    for (const auto& sliceWithLoad: HostToSlices_.at(hottestNode.Node)) {
        TMove move{
                .Weight = 0,
                .Slice = sliceWithLoad,
        };
        const ui64 moveLoad = GetLoad<RType>(sliceWithLoad.Load);

        double hottestNodeDistanceFromMeanBefore = std::abs(hottestNode.LoadNorm - meanNodeLoad);
        double hottestNodeLoadNormAfter = static_cast<double>(hottestNode.Load - moveLoad) / hottestNode.Budget;
        double hottestNodeDistanceFromMeanAfter = std::abs(hottestNodeLoadNormAfter - meanNodeLoad);

        double coldestNodeDistanceFromMeanBefore = std::abs(coldestNode.LoadNorm - meanNodeLoad);
        double coldestNodeLoadNormAfter = static_cast<double>(coldestNode.Load + moveLoad) / coldestNode.Budget;
        double coldestNodeDistanceFromMeanAfter = std::abs(coldestNodeLoadNormAfter - meanNodeLoad);

        double hottestNodeReduction = hottestNodeDistanceFromMeanBefore - hottestNodeDistanceFromMeanAfter;
        double coldestNodeReduction = coldestNodeDistanceFromMeanBefore - coldestNodeDistanceFromMeanAfter;

        move.Weight = (hottestNodeReduction + coldestNodeReduction) / move.Slice.Slice.Size();

        if (move.Weight > maxWeight) {
            bestMove = move;
            maxWeight = move.Weight;
        } else if (move.Weight == maxWeight) {
            // for a deterministic result
            if (bestMove && move.Slice.Slice.Start < bestMove.value().Slice.Slice.Start) {
                bestMove = move;
            }
        }
    }
    return bestMove;
}

/**
 * Refill `nodes` with every host of HostsInfo_ that is not in `filtered`, together with its budget, its current
 * load (sum of its slice loads) and its normalized load. Hosts with a zero budget of the selected kind are logged,
 * added to `filtered` and skipped. Fails (Y_VERIFY) if some host's load exceeds its budget.
 */
template <EReassignmentType RType>
void TSliceLoadBalancer::FillNodesWithLoads(
        absl::flat_hash_set<TString>& filtered,
        TVector<TNodeWithLoad>& nodes) const
{
    nodes.clear();
    for (const auto& [host, hostInfo]: HostsInfo_) {
        if (filtered.contains(host)) {
            continue;
        }

        const auto budget = GetLoad<RType>(hostInfo);
        if (budget == 0) {
            // no normalized load can be computed for this host: exclude it from now on
            auto hostName = TStringBuf{host};
            BALANCER_WARN(LogPrefix_ << "node \"" << hostName << "\" has no " << GetLoadName<RType>() << " load information");
            filtered.insert(host);
            continue;
        }

        ui64 load = 0;
        const auto slicesIt = HostToSlices_.find(host);
        if (slicesIt != HostToSlices_.end()) {
            for (const auto& slice: slicesIt->second) {
                load += GetLoad<RType>(slice.Load);
            }
        }

        const double loadNorm = static_cast<double>(load) / budget;
        Y_VERIFY(loadNorm <= 1.0);
        nodes.emplace_back(host, budget, load, loadNorm);
    }
}

/**
 * Relative load (load / budget of the selected kind) of every host in HostToSlices_.
 * Hosts with an unknown or zero budget are measured against the mean budget of the hosts that have one.
 * Returns an empty vector when no host has a budget.
 */
template <EReassignmentType RType>
TVector<double> TSliceLoadBalancer::GetNodeLoads() const
{
    const ui64 totalBudget = GetLoad<RType>(TotalBudget_);
    if (totalBudget == 0) {
        return {};
    }

    const ui64 hostsWithBudget = static_cast<ui64>(std::count_if(HostsInfo_.begin(), HostsInfo_.end(), [](const auto& entry) {
        return GetLoad<RType>(entry.second) > 0;
    }));
    if (hostsWithBudget == 0) {
        return {};
    }
    const double meanHostBudget = static_cast<double>(totalBudget) / hostsWithBudget;

    TVector<double> nodeLoads(::Reserve(HostToSlices_.size()));
    for (const auto& [host, slices]: HostToSlices_) {
        ui64 load = 0;
        for (const auto& slice: slices) {
            load += GetLoad<RType>(slice.Load);
        }

        const auto infoIt = HostsInfo_.find(host);
        const ui64 hostBudget = infoIt == HostsInfo_.end() ? 0 : GetLoad<RType>(infoIt->second);
        const double budget = hostBudget == 0 ? meanHostBudget : static_cast<double>(hostBudget);
        nodeLoads.push_back(static_cast<double>(load) / budget);
    }
    return nodeLoads;
}

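/**
 * Filter out nodes that own exactly one slice and whose normalized load is at least the mean load, updating
 * totalLoad, totalBudget and meanLoad accordingly (the coldest node is never filtered out).
 * Return the hottest and the coldest of the remaining nodes.
 */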
THottestAndColdest TSliceLoadBalancer::DetectHottestAndColdestNodes(
        ui64& totalLoad,
        ui64& totalBudget,
        double& meanLoad,
        absl::flat_hash_set<TString>& filtered,
        TVector<TNodeWithLoad>& nodes) const
{
    // the result refers to the first and the last node, and `nodes.end() - 1` below needs at least one element
    Y_VERIFY(!nodes.empty(), "no nodes to detect the hottest and the coldest ones");

    Sort(nodes, [](const TNodeWithLoad& left, const TNodeWithLoad& right) {
        if (left.LoadNorm < right.LoadNorm) {
            return true;
        } else if (left.LoadNorm > right.LoadNorm) {
            return false;
        } else {
            // for a deterministic result
            return left.Node < right.Node;
        }
    });

    // walk from the hottest node down; the coldest node (nodes.begin()) is never filtered out
    auto it = nodes.end() - 1;
    while (it != nodes.begin()) {
        auto slicesIt = HostToSlices_.find(it->Node);
        if (slicesIt == HostToSlices_.end()) {
            Y_VERIFY(it->LoadNorm == 0); // guaranteed by FillNodesWithLoads()
            break;
        }

        if (slicesIt->second.size() == 1 && it->LoadNorm >= meanLoad) {
            // the totals are unsigned: check before subtracting, an underflow would otherwise wrap silently
            Y_VERIFY(totalLoad >= it->Load);
            totalLoad -= it->Load;
            Y_VERIFY(totalBudget > it->Budget);
            totalBudget -= it->Budget;
            Y_VERIFY(totalLoad <= totalBudget);
            meanLoad = static_cast<double>(totalLoad) / totalBudget;
            Y_VERIFY(meanLoad <= 1.0);
            filtered.emplace(it->Node);
            // erasing only invalidates iterators at or after the erased element, so the decremented `it` stays valid
            nodes.erase(it--);
        } else {
            --it;
        }
    }

    BALANCER_TRACE(LogPrefix_ << "hottest node: " << nodes.back().Node << '(' << nodes.back().LoadNorm << ')'
            << "; coldest node: " << nodes[0].Node << '(' << nodes[0].LoadNorm << ')');

    return { .Hottest = nodes.back(), .Coldest = nodes[0] };
}

template <EReassignmentType RType>
void TSliceLoadBalancer::SplitHotSlices() {
    if (GetLoad<RType>(TotalLoad_) == 0 || (HostToSlices_.size() == 1 && HostToSlices_.begin()->second.size() == 1)) {
        return;
    }

    if (CorrectSliceLoadValues()) {
        BALANCER_WARN(LogPrefix_ << "slice load values corrected on SplitHotSlices");
    }

    auto slices = BuildSlicesWithLoad();
    size_t slicesBudget = std::count_if(slices.begin(), slices.end(), [](const auto& s){
        return GetLoad<RType>(s.Load) > 0;
    });
    Y_VERIFY(slicesBudget > 0);

    absl::flat_hash_set<TSlice, THash<TSlice>> unsplittable;
    Stats_.SplitIterations = 0;
    Stats_.SplitCount = 0;
    int counter = static_cast<int>(SortedNumIds_.size());
    ui64 totalLoad = GetLoad<RType>(TotalLoad_);

    while(true) {
        ++Stats_.SplitIterations;
        if (!SplitHotSlicesStep<RType>(slices, unsplittable, totalLoad, slicesBudget)) {
            break;
        }
        Y_VERIFY(--counter >= 0, "too many split slices steps");
    }

    for (auto& [_, slices]: HostToSlices_) {
        slices.clear();
    }
    for (const auto& slice: slices) {
        HostToSlices_[slice.Host].emplace(slice.Slice, slice.Load);
    }
}

/**
 * One pass of the hot-slice splitting loop.
 *
 * Sorts `slices` by load of the selected kind (ascending, ties broken by slice) and walks them from the hottest
 * down. The walk stops at the first slice that is not hotter than SplitSliceNTimesAsHotAsMean times the mean slice
 * load, or as soon as the mean number of slices per host reaches SplitWhenFewerThanNumSlicesPerTask. For each hot
 * slice visited:
 *  - slices already in `unsplittable` are skipped;
 *  - a single-key slice (Start == End) or a slice with exactly one loaded numId is added to `unsplittable` and
 *    removed from the mean computation (slicesBudget is decremented, its load is subtracted from totalLoad);
 *  - any other slice is split with SplitHotSlice(), and the new part is appended to `slices` (past the current
 *    index, so it is not visited again in this pass).
 *
 * @param slices        all slices; re-sorted in place and extended with the newly created parts
 * @param unsplittable  slices that must not be split again; extended in place
 * @param totalLoad     load of the slices taking part in the mean; updated in place
 * @param slicesBudget  number of slices taking part in the mean; updated in place
 * @return true if at least one slice was split in this pass; false if nothing was split, or immediately
 *         when slicesBudget drops to zero
 *
 * NOTE(review): after a split, slicesBudget is incremented but meanSliceLoad is not recomputed until the next
 * unsplittable slice or the next pass — confirm this is intended.
 */
template <EReassignmentType RType>
bool TSliceLoadBalancer::SplitHotSlicesStep(
        TVector<TSliceWithLoadAndHost>& slices,
        absl::flat_hash_set<TSlice, THash<TSlice>>& unsplittable,
        ui64& totalLoad,
        size_t& slicesBudget)
{
    if (slicesBudget == 0) {
        return false;
    }
    double meanSliceLoad = static_cast<double>(totalLoad) / slicesBudget;

    // 5.(a) in the paper
    auto isHotEnough = [&meanSliceLoad, this](TSliceWithLoadAndHost& slice) {
        auto nTimes = ServiceSettings_.SplitSliceNTimesAsHotAsMean;
        return static_cast<double>(GetLoad<RType>(slice.Load)) > (meanSliceLoad * nTimes);
    };
    // 5.(b) in the paper
    auto areThereTooManySlices = [&slices, this]() {
        double meanSlicesPerHost = static_cast<double>(slices.size()) / HostToSlices_.size();
        return meanSlicesPerHost >= ServiceSettings_.SplitWhenFewerThanNumSlicesPerTask;
    };

    Sort(slices, [](const TSliceWithLoadAndHost& left, const TSliceWithLoadAndHost& right) {
        const auto leftLoad = GetLoad<RType>(left.Load);
        const auto rightLoad = GetLoad<RType>(right.Load);
        if (leftLoad < rightLoad) {
            return true;
        } else if (leftLoad > rightLoad) {
            return false;
        } else {
            // for a deterministic result
            return left.Slice < right.Slice;
        }
    });
    // reused buffer for the numIds of the slice being examined
    TVector<TNumId> sliceNumIds;
    size_t numOfSplit = 0;
    // hottest first; slices appended below land past idx and are not visited in this pass
    for (int idx = static_cast<int>(slices.size()) - 1; idx >= 0; idx--) {
        if (!isHotEnough(slices[idx])) {
            break; // because slices are sorted
        }
        if (areThereTooManySlices()) {
            break;
        }
        if (unsplittable.contains(slices[idx].Slice)) {
            continue;
        }

        // a single key, or a single loaded numId, cannot be divided any further
        if (slices[idx].Slice.Start == slices[idx].Slice.End || GetSliceShards<RType>(slices[idx].Slice, sliceNumIds) == 1) {
            unsplittable.emplace(slices[idx].Slice);
            if (--slicesBudget == 0) {
                return false;
            }

            // exclude the unsplittable slice from the mean used for the remaining slices
            totalLoad -= GetLoad<RType>(slices[idx].Load);
            meanSliceLoad = static_cast<double>(totalLoad) / slicesBudget;
            continue;
        }

        // sliceNumIds was filled by the GetSliceShards() call in the condition above
        auto newSlice = SplitHotSlice<RType>(slices[idx], sliceNumIds, static_cast<ui64>(meanSliceLoad));
        ++numOfSplit;
        ++Stats_.SplitCount;
        ++slicesBudget;
        slices.emplace_back(newSlice);
    }

    return numOfSplit > 0;
}

template <EReassignmentType RType>
size_t TSliceLoadBalancer::GetSliceShards(const TSlice& slice, TVector<TNumId>& sliceNumIds) const {
    sliceNumIds.clear();
    size_t numOfLoadedNumIds = 0;
    auto nIt = std::lower_bound(SortedNumIds_.begin(), SortedNumIds_.end(), slice.Start);

    while (nIt != SortedNumIds_.end() && *nIt <= slice.End) {
        sliceNumIds.emplace_back(*nIt);
        if (GetLoad<RType>(ShardsInfo_.at(*nIt)) > 0) {
            ++numOfLoadedNumIds;
        }
        ++nIt;
    }
    Y_VERIFY(!sliceNumIds.empty(), "found no numId for a slice");
    return numOfLoadedNumIds;
}

/// Known numIds inside [slice.Start, slice.End], in ascending order (possibly empty).
TVector<NApi::TNumId> TSliceLoadBalancer::GetSliceShards(const NApi::TSlice& slice) const {
    // SortedNumIds_ is sorted, so the numIds of the slice form one contiguous range
    const auto rangeBegin = std::lower_bound(SortedNumIds_.begin(), SortedNumIds_.end(), slice.Start);
    const auto rangeEnd = std::upper_bound(rangeBegin, SortedNumIds_.end(), slice.End);
    return TVector<NApi::TNumId>(rangeBegin, rangeEnd);
}

/**
 * Cut a prefix off a hot slice.
 *
 * Walks `sliceNumIds` (the numIds inside `slice`, ascending) and accumulates their loads into a new slice while the
 * accumulated load of the selected kind stays below `meanSliceLoad`. The first numId that would reach the mean ends
 * the prefix, and the new slice becomes [slice.Start, numId - 1]. If that numId is the very first one, it forms the
 * prefix on its own: [slice.Start, numId].
 *
 * `slice` is modified in place to keep the remainder: its Start moves past the split point, and its load is
 * recomputed from ShardsInfo_. The new prefix slice is assigned to the same host and returned. A warning is logged
 * if the two recomputed loads do not add up to the stored slice load.
 *
 * @pre sliceNumIds.size() > 1 and the slice load is at least meanSliceLoad (checked with Y_VERIFY); Y_VERIFY also
 *      fails if no numId reaches the mean.
 */
template <EReassignmentType RType>
TSliceWithLoadAndHost TSliceLoadBalancer::SplitHotSlice(
        TSliceWithLoadAndHost& slice,
        const TVector<TNumId>& sliceNumIds,
        ui64 meanSliceLoad) const
{
    Y_VERIFY(sliceNumIds.size() > 1);
    Y_VERIFY(GetLoad<RType>(slice.Load) >= meanSliceLoad);

    TLoadInfo newSliceLoad{};
    size_t numIdsInNewSlice = 0;
    // last key of the new (prefix) slice
    std::optional<TNumId> splitOn;

    for (auto numId: sliceNumIds) {
        const auto& load = ShardsInfo_.at(numId);

        if (GetLoad<RType>(newSliceLoad) + GetLoad<RType>(load) < meanSliceLoad) {
            newSliceLoad += load;
            ++numIdsInNewSlice;
        } else {
            if (numIdsInNewSlice == 0) {
                // the very first numId is already as hot as the mean: it forms the prefix on its own
                numIdsInNewSlice = 1;
                newSliceLoad = load;
                splitOn = numId;
            } else {
                // stop right before the numId that would make the prefix reach the mean
                splitOn = numId - 1;
            }
            break;
        }
    }
    Y_VERIFY(splitOn, "trying to split cold slice");

    // the remainder keeps the numIds that were not taken into the prefix
    TLoadInfo oldSliceLoad{};
    for (size_t i = numIdsInNewSlice; i < sliceNumIds.size(); i++) {
        oldSliceLoad += ShardsInfo_.at(sliceNumIds[i]);
    }
    if (oldSliceLoad + newSliceLoad != slice.Load) {
        BALANCER_WARN(LogPrefix_ << "invalid slice load value");
    }

    TSliceWithLoadAndHost newSlice{TSlice{slice.Slice.Start, splitOn.value()}, newSliceLoad, slice.Host};
    slice.Slice.Start = splitOn.value() + 1;
    slice.Load = oldSliceLoad;
    return newSlice;
}

bool TSliceLoadBalancer::CheckSliceLoadValues() const {
    for (const auto& [host, slices]: HostToSlices_) {
        for (const auto& slice: slices) {
            auto shards = GetSliceShards(slice.Slice);
            TLoadInfo sliceLoad{};
            for (auto numId: shards) {
                sliceLoad += ShardsInfo_.at(numId);
            }
            if (sliceLoad != slice.Load) {
                return false;
            }
        }
    }
    return true;
}

bool TSliceLoadBalancer::CorrectSliceLoadValues() {
    bool hasWrongValues = false;
    for (auto& [host, slices]: HostToSlices_) {
        TVector<TSliceWithLoad> newSlices;
        auto sliceIt = slices.begin();
        while (sliceIt != slices.end()) {
            auto shards = GetSliceShards(sliceIt->Slice);
            TLoadInfo sliceLoad{};
            for (auto numId: shards) {
                sliceLoad += ShardsInfo_.at(numId);
            }
            if (sliceLoad != sliceIt->Load) {
                newSlices.emplace_back(*sliceIt);
                hasWrongValues = true;
                newSlices.back().Load = sliceLoad;
                slices.erase(sliceIt);
            } else {
                sliceIt++;
            }
        }
        for (const auto& slice: newSlices) {
            slices.insert(slice);
        }
    }
    return hasWrongValues;
}

/**
 * Record load distribution statistics over the relative node loads returned by GetNodeLoads():
 * Stats_.MeanLoadPercent is their mean and Stats_.LoadRSDPercent their relative (population) standard deviation,
 * both in percent. Both are 0 when there are no node loads, and the RSD is 0 when the mean load is 0.
 */
template <EReassignmentType RType>
void TSliceLoadBalancer::DescribeDistribution()
{
    Stats_.MeanLoadPercent = 0.0;
    Stats_.LoadRSDPercent = 0.0;
    TVector<double> nodeLoads = GetNodeLoads<RType>();
    if (nodeLoads.empty()) {
        return;
    }

    const double loadSum = std::accumulate(nodeLoads.begin(), nodeLoads.end(), 0.0);
    const double meanLoad = loadSum / nodeLoads.size();
    Stats_.MeanLoadPercent = meanLoad * 100;
    if (meanLoad == 0) {
        // Node loads are non-negative, so a zero mean means every node is idle: there is no spread, and
        // dividing by the mean below would produce 0 / 0 = NaN.
        return;
    }

    double squareDeviationsSum = 0.0;
    for (auto load: nodeLoads) {
        double diff = load - meanLoad;
        squareDeviationsSum += diff * diff;
    }

    const double dispersion = squareDeviationsSum / nodeLoads.size();
    const double standardDeviation = std::sqrt(dispersion);
    const double relativeStandardDeviation = standardDeviation / meanLoad;
    Stats_.LoadRSDPercent = relativeStandardDeviation * 100;
}

} // namespace NSolomon::NSlicer::NBalancer
