#include "base_algorithm.h"

#include <cmath>

namespace NSrvKernel {

    namespace {
        // Default-constructible holder for a TWrrRng, so the generator can be
        // created through a singleton helper. The generator is seeded once, at
        // construction, with a value obtained from RandomNumber<ui32>().
        struct TWrrRngWrapper {
            TWrrRng Rng;

            TWrrRngWrapper()
                : Rng(RandomNumber<ui32>())
            {
            }
        };

        // Returns the shared default generator, used when an algorithm was
        // constructed without its own TWrrRng.
        // NOTE(review): judging by the helper's name the instance is presumably
        // per-thread (TLS) — confirm against FastTlsSingletonWithPriority.
        TWrrRng* GetWrrRng() noexcept {
            auto* const holder = FastTlsSingletonWithPriority<TWrrRngWrapper, 2>();
            return &holder->Rng;
        }

        // Upper bound on the number of random draws SelectBackend makes before
        // giving up and switching to the fallback path.
        //
        // For groups of 0 or 1 entries the limit equals the group size; for
        // larger groups it is 1 + floor(n / log2(n)), capped at 200.
        size_t AttemptsLimit(const TBackendsGroupWeights& group) noexcept {
            const size_t count = group.BoundaryAndIndex.size();
            if (count <= 1) {
                return count;
            }
            // count >= 2 here, so log2(count) >= 1 and the division is well defined.
            const auto scaled = static_cast<size_t>(count / std::log2(count));
            return Min<size_t>(200, 1 + scaled);
        }
    }

// Constructs the algorithm; `process` is simply forwarded to the IAlgorithm base.
TAlgorithmWithRemovals::TAlgorithmWithRemovals(IWorkerCtl* process)
    : IAlgorithm(process)
{}

// Marks `backend` as excluded by adding it to Excluded_; IsRemoved(backend)
// reports true afterwards until Reset() is called.
void TAlgorithmWithRemovals::RemoveSelected(IBackend* backend) noexcept {
    Excluded_.insert(backend);
}

// Forgets every exclusion recorded by RemoveSelected().
void TAlgorithmWithRemovals::Reset() noexcept {
    Excluded_.clear();
}

// Tells whether `backend` is currently in the exclusion set, i.e. was passed
// to RemoveSelected() since the last Reset().
bool TAlgorithmWithRemovals::IsRemoved(IBackend* backend) const noexcept {
    const bool excluded = Excluded_.contains(backend);
    return excluded;
}

// Constructs the weighted-random algorithm.
//   process - forwarded to TAlgorithmWithRemovals.
//   rng     - generator stored in Rng_; may be null, in which case Round() and
//             SelectBackend() use the generator returned by GetWrrRng() instead.
TBaseWrrAlgorithm::TBaseWrrAlgorithm(IWorkerCtl* process, TWrrRng* rng)
    : TAlgorithmWithRemovals(process)
    , Rng_(rng)
{}

size_t TBaseWrrAlgorithm::Round(const TBackendsGroupWeights& group) const noexcept {
    TWrrRng* rng = Rng_ ?: GetWrrRng();
    return FindInGroup(group, group.WeightsSum * rng->GenRandReal4());
}

// Deterministic variant of Round(): maps a caller-supplied `diceRoll` onto an
// entry index. The roll must not exceed group.WeightsSum; violating that
// precondition aborts via Y_VERIFY.
size_t TBaseWrrAlgorithm::Round(const TBackendsGroupWeights& group, double diceRoll) const noexcept {
    Y_VERIFY(diceRoll <= group.WeightsSum);
    const size_t chosen = FindInGroup(group, diceRoll);
    return chosen;
}

size_t TBaseWrrAlgorithm::FindInGroup(const TBackendsGroupWeights& group, double diceRoll) const noexcept {
    auto it = std::lower_bound(
        group.BoundaryAndIndex.begin(),
        group.BoundaryAndIndex.end(),
        diceRoll,
        [](const std::pair<double, size_t>& backend, double value) {
            return backend.first < value;
        }
    );
    Y_ENSURE(it != group.BoundaryAndIndex.end());
    return it->second;
}

// Picks a backend from `group` by weighted random choice.
//
//   group   - cumulative boundaries plus their total (WeightsSum).
//   resolve - optional callback turning (index, diceRoll) into a backend; when
//             empty, the member Resolve() is used. A null result means "this
//             draw is not usable, try again".
//
// Returns nullptr when the total weight is zero. Otherwise makes up to
// AttemptsLimit(group) random draws and returns the first non-null backend;
// if every draw is rejected, the decision is delegated to Fallback().
IBackend* TBaseWrrAlgorithm::SelectBackend(const TBackendsGroupWeights& group, std::function<IBackend*(size_t, double)> resolve) const noexcept {
    if (group.WeightsSum == 0) {
        return nullptr;
    }

    // Prefer the generator given at construction; otherwise use the shared one.
    TWrrRng* const generator = Rng_ ? Rng_ : GetWrrRng();

    for (size_t remaining = AttemptsLimit(group); remaining > 0; --remaining) {
        const double roll = group.WeightsSum * generator->GenRandReal4();
        const size_t picked = FindInGroup(group, roll);
        IBackend* const candidate = resolve ? resolve(picked, roll) : Resolve(picked, roll);
        if (candidate) {
            return candidate;
        }
    }

    // All draws were rejected: choose among the entries that do resolve.
    return Fallback(group, resolve, generator);
}

// Slow path of SelectBackend(): builds a reduced group holding only the entries
// that currently resolve to a backend, then makes a single weighted draw in it.
//
//   group   - the original group.
//   resolve - optional resolver callback; the member Resolve() is used when empty.
//   rng     - generator for the single draw; must not be null.
//
// Each surviving entry keeps its original weight, recovered as the difference
// between its boundary and the preceding boundary (clamped at zero). The winner
// is resolved with its ORIGINAL boundary as the dice roll. Returns nullptr when
// nothing resolves or the surviving weight is not positive.
IBackend* TBaseWrrAlgorithm::Fallback(const TBackendsGroupWeights& group, std::function<IBackend*(size_t, double)> resolve, TWrrRng* rng) const noexcept {
    const auto tryResolve = [&](size_t index, double roll) -> IBackend* {
        return resolve ? resolve(index, roll) : Resolve(index, roll);
    };

    TBackendsGroupWeights survivors;
    survivors.BoundaryAndIndex.reserve(group.BoundaryAndIndex.size());
    THashMap<size_t, double> originalBoundary;

    double previousBoundary = 0;
    for (const auto& [boundary, index] : group.BoundaryAndIndex) {
        // Weight of this entry in the original group.
        const double weight = Max<double>(0, boundary - previousBoundary);
        previousBoundary = boundary;

        if (!tryResolve(index, boundary)) {
            continue;
        }
        survivors.Add(weight, index);
        originalBoundary[index] = boundary;
    }

    if (survivors.WeightsSum <= 0) {
        return nullptr;
    }

    const double roll = survivors.WeightsSum * rng->GenRandReal4();
    const size_t winner = Round(survivors, roll);
    return tryResolve(winner, originalBoundary[winner]);
}

}  // namespace NSrvKernel
