#include "choose_policy.h"

#include <library/cpp/containers/stack_vector/stack_vec.h>

#include <util/random/easy.h>

namespace NPassport::NDbPool {
    /**
     * Random logic
     *
     * |start0,weight0,idx0|...|startI,weightI,idxI|...|startN,weightN,idxN|
     *
     * 1) `start...` values are sorted.
     * 2) `startI` is the sum of all previous weights, so host I owns the interval [startI, startI + weightI).
     * 3) We pick a random number in [0, weightSum), so each host is chosen with probability proportional to its weight.
     * 4) We don't add each host to the array `weight` times -
     *      this avoids choosing a bad host more than once.
     * 5) After host I has been tried it is removed, and `startI+1...` are decreased by `weightI`.
     *      Without this correction host I-1 would inherit host I's interval, and its effective weight would grow.
     *     Yes, this is O(N^2) if no host is alive.
     * 6) `weightI` is needed only to know how much to subtract from `startI+1...`.
     * 7) `idxI` is the result of the algorithm.
     * 8) Every host is tried only once.
     */

    // TODO
    // The factory should precompute all possible array states.
    // The chooser should then pick them in O(1), making the whole algorithm O(N*log(N)).

    // Each chooser works on its own copy of the factory-prepared state:
    // TryGetIdx() removes entries from it as hosts are handed out.
    TRandomChooser::TRandomChooser(const TState& state)
        : State_(state) {
    }

    /**
     * Picks the host whose interval contains `rnd` and removes it from the state.
     *
     * @param rnd  A value in [0, WeightSum); must be in range (checked with Y_VERIFY).
     * @return     The original host index (`Idx`) of the chosen entry.
     *
     * After removal, the `Start` of every following entry is lowered by the removed
     * weight, so `Start` values stay prefix sums of the remaining weights.
     */
    IChooser::TOptionalIdx TRandomChooser::TryGetIdx(size_t rnd) {
        Y_VERIFY(!State_.Weights.empty());
        Y_VERIFY(rnd < State_.WeightSum);

        // Entry I owns the half-open interval [StartI, StartI + WeightI).
        // The owner of `rnd` is the LAST entry with Start <= rnd. Taking the last one
        // (rather than the first one with Start == rnd) skips zero-weight entries,
        // which share their Start with the following entry and own an empty interval.
        auto it = std::upper_bound(
            State_.Weights.begin(),
            State_.Weights.end(),
            rnd,
            [](size_t value, const THostWeight& w) { return value < w.Start; });
        // The first entry always has Start == 0 <= rnd, so upper_bound never returns begin().
        Y_VERIFY(it != State_.Weights.begin());
        --it;

        const size_t res = it->Idx;
        const size_t weightToErase = it->Weight;

        // Close the gap left by the removed interval.
        State_.WeightSum -= weightToErase;
        for (auto iter = it; iter != State_.Weights.end(); ++iter) {
            iter->Start -= weightToErase;
        }
        State_.Weights.erase(it);

        return res;
    }

    /**
     * Returns the next host to try, or an empty value once every host has been handed out.
     *
     * Hosts with positive weight are drawn at random, with probability proportional to weight.
     * Zero-weight hosts can never be hit by a random draw. They are handed out in list order
     * only after all weighted hosts are exhausted, so every host is still tried exactly once.
     */
    IChooser::TOptionalIdx TRandomChooser::TryGetIdx() {
        if (State_.Weights.empty()) {
            return {};
        }

        if (State_.WeightSum == 0) {
            // Only zero-weight hosts remain. No rnd satisfies `rnd < WeightSum`, so
            // RandomNumber(0) has nothing valid to return and TryGetIdx(rnd) would abort
            // on its Y_VERIFY. Fall back to the remaining hosts in order instead.
            const size_t res = State_.Weights.front().Idx;
            State_.Weights.erase(State_.Weights.begin());
            return res;
        }

        return TryGetIdx(RandomNumber(State_.WeightSum));
    }

    // Builds the cumulative-weight table once. Every chooser created later starts from a copy of it.
    TRandomFactory::TRandomFactory(const std::vector<TDbHost>& hosts)
        : State_(PrepareState(hosts)) {
    }

    // Creates a fresh chooser that owns its own copy of the prepared state.
    TChooserPtr TRandomFactory::CreateChooser() {
        TChooserPtr chooser = std::make_unique<TRandomChooser>(State_);
        return chooser;
    }

    /**
     * Lays the hosts out back to back on a number line.
     *
     * Each entry starts where the previous one ended, so `Start` is the sum of all
     * weights before it. `WeightSum` ends up as the total weight. `Idx` records the
     * host's position in `hosts`.
     */
    TRandomChooser::TState TRandomFactory::PrepareState(const std::vector<TDbHost>& hosts) {
        TRandomChooser::TState state;
        state.Weights.reserve(hosts.size());

        size_t position = 0;
        for (const TDbHost& host : hosts) {
            const size_t weight = host.Weight;

            TRandomChooser::THostWeight entry{
                .Start = state.WeightSum,
                .Weight = weight,
                .Idx = position,
            };
            state.Weights.push_back(entry);

            state.WeightSum += weight;
            ++position;
        }

        return state;
    }

    // Factory entry point: builds a weighted-random chooser factory for `hosts`.
    TChooseFactoryPtr TRandomFactory::Create(const std::vector<TDbHost>& hosts) {
        TChooseFactoryPtr factory = std::make_unique<TRandomFactory>(hosts);
        return factory;
    }
}
