#include "backends_factory.h"
#include "pwr2.h"

#include <balancer/kernel/balancing/per_worker_backend.h>
#include <util/random/random.h>
#include <util/random/fast.h>

namespace NSrvKernel::NModBalancer {

namespace {
    // Per-worker backend used by the pwr2 balancer. On top of the state
    // inherited from TPerWorkerBaseBackend it carries Load_, the counter that
    // TPowerOfTwoChoices compares between two randomly drawn candidates.
    class TBackend : public TPerWorkerBaseBackend {
    public:
        // `explicit`: a descriptor reference should never silently convert
        // into a backend object.
        explicit TBackend(TBackendDescriptor::TRef descr)
            : TPerWorkerBaseBackend(std::move(descr))
        {}

        // Request outcome hooks are no-ops here; in this file Load_ is only
        // modified by the selection algorithm.
        void DoOnCompleteRequest(const TDuration&) noexcept override {}

        void DoOnFailRequest(const TError&, const TDuration&) noexcept override {}

    public:
        // Current load score of this backend. Modified through the util
        // Atomic* helpers (AtomicIncrement/AtomicAdd/AtomicSub).
        volatile ui64 Load_ = 0;
    };
}

BACKENDS_TLS(pwr2) {
    // Creates the per-worker state; `seed` initialises the random generator
    // used to draw candidate pairs.
    explicit TTls(ui32 seed)
        : Rnd(seed)
    {}

    // Backend objects owned by this worker. std::deque-style storage keeps
    // element references valid when appending, so pointers into it may be kept.
    TDeque<TBackend> Backends;

    // Name -> backend lookup; the values point into Backends.
    THashMap<TString, TBackend*> NamedBackends;

    // Per-worker pseudo-random source.
    TReallyFastRng32 Rnd;
};

BACKENDS_WITH_TLS(pwr2), public TModuleParams {
private:
    // "Power of two choices" backend selection for a single request.
    //
    // Next() draws two distinct backends uniformly at random among those not
    // excluded in the current round and returns the one with the smaller
    // Load_ (ties resolved by a random bit). A backend passed to
    // RemoveSelected() gets +1 load; when a further Next() call is made for
    // the same request (i.e. a retry, presumably after that attempt failed) it
    // additionally gets +FailureWeight_. The destructor removes both
    // contributions again.
    class TPowerOfTwoChoices : public IAlgorithm {
    public:
        TPowerOfTwoChoices(TBackends* parent, const TStepParams& params, TTls& tls)
            : IAlgorithm(&params.Descr->Process())
            , Parent_(parent)
            , Tls_(tls)
        {}

        // Roll back every load contribution made on behalf of this request.
        ~TPowerOfTwoChoices() {
            for (size_t i = 0; i < Used_.size(); i++) {
                AtomicSub(Used_[i]->Load_, 1 + (i < UsedBeforeLastCall_ ? Parent_->FailureWeight_ : 0));
            }
        }

        // Records that `backend` is being used: bumps its load and excludes it
        // from subsequent Next() picks until Reset().
        void RemoveSelected(IBackend* backend) noexcept override {
            TBackend* casted = dynamic_cast<TBackend*>(backend);
            Y_ASSERT(casted);
            AtomicIncrement(casted->Load_);
            Used_.emplace_back(casted);

            // Each Unshifted_ entry is an index into the list that remained
            // after all earlier exclusions, so translate the raw index through
            // every previous removal.
            size_t id = IndexOf(casted);
            for (size_t previous : Unshifted_) {
                Y_ASSERT(id != previous);
                id -= (id > previous);
            }
            Unshifted_.emplace_back(id);
        }

        // Makes every backend eligible again. Used_ is intentionally kept: the
        // destructor still has to roll back the loads recorded in it.
        void Reset() noexcept override {
            Unshifted_.clear();
        }

        IBackend* Next() noexcept override {
            // Backends selected before this call are charged the failure weight.
            while (UsedBeforeLastCall_ < Used_.size()) {
                AtomicAdd(Used_[UsedBeforeLastCall_++]->Load_, Parent_->FailureWeight_);
            }

            // Candidates are the backends not excluded in the current round.
            // Count exclusions via Unshifted_, not Used_: Used_ also holds picks
            // made before the last Reset(), which are eligible again.
            const size_t total = Tls_.Backends.size();
            if (Unshifted_.size() >= total) {
                return nullptr;
            }
            const size_t have = total - Unshifted_.size();

            if (have == 1) {
                size_t result = 0;
                for (size_t i = Unshifted_.size(); i--;)
                    result += (result >= Unshifted_[i]);
                return &Tls_.Backends[result];
            }

            // Draw two distinct positions in the compacted index space, then map
            // them back to raw indices by replaying the exclusions in reverse.
            size_t firstIdx = Tls_.Rnd.Uniform(have * 2); // + 1 bit as a tiebreaker
            size_t secondIdx = Tls_.Rnd.Uniform(have - 1);
            bool tiebreaker = firstIdx % 2;
            firstIdx /= 2;
            secondIdx += (secondIdx >= firstIdx);
            for (size_t i = Unshifted_.size(); i--;) {
                firstIdx += (firstIdx >= Unshifted_[i]);
                secondIdx += (secondIdx >= Unshifted_[i]);
            }
            // `x < y + tiebreaker` <=> `tiebreaker ? x <= y : x < y`.
            bool selectFirst = Tls_.Backends[firstIdx].Load_ < Tls_.Backends[secondIdx].Load_ + tiebreaker;
            return &Tls_.Backends[selectFirst ? firstIdx : secondIdx];
        }

        // Looks a backend up by name; nullptr if there is no such backend.
        IBackend* NextByName(TStringBuf name, bool /*allowZeroWeights*/) noexcept override {
            auto it = Tls_.NamedBackends.find(name);
            return it != Tls_.NamedBackends.end() ? it->second : nullptr;
        }

    private:
        // Position of `backend` inside Tls_.Backends. Deque storage is split
        // into separately allocated chunks, so subtracting element addresses
        // does not yield an index (and is undefined across chunks).
        size_t IndexOf(const TBackend* backend) const noexcept {
            size_t index = 0;
            for (const TBackend& candidate : Tls_.Backends) {
                if (&candidate == backend) {
                    break;
                }
                ++index;
            }
            Y_ASSERT(index < Tls_.Backends.size());
            return index;
        }

        // Every backend passed to RemoveSelected(), in order (kept across Reset()).
        TVector<TBackend*> Used_;
        // Exclusions of the current round, each relative to the earlier ones.
        TVector<size_t> Unshifted_;
        TBackends* Parent_ = nullptr;
        TTls& Tls_;
        // Number of Used_ entries already charged FailureWeight_.
        size_t UsedBeforeLastCall_ = 0;
    };

// Initialization
// --------------------------------------------------------------------------------
public:
    TBackends(const TModuleParams& mp, const TBackendsUID& uid)
        : TBackendsWithTLS(mp)
        , TModuleParams(mp)
        , Seed_(RandomNumber<ui32>())
        , BackendsId_(uid.Value)
    {
        Config->ForEach(this);
    }

private:
    // Builds per-worker state: one TBackend per configured descriptor plus a
    // name index. Appending to the deque keeps element references valid, so
    // the stored pointers stay usable.
    THolder<TTls> DoInit(IWorkerCtl*) noexcept override {
        auto tls = MakeHolder<TTls>(Seed_);
        for (auto& i : BackendDescriptors()) {
            tls->Backends.emplace_back(i);
        }
        for (auto& i : tls->Backends) {
            tls->NamedBackends[i.Name()] = &i;
        }
        return tls;
    }

    // "failure_weight" sets the extra load charged to a retried backend; every
    // other key is parsed as a backend section.
    START_PARSE {
        ON_KEY("failure_weight", FailureWeight_) {
            return;
        }

        Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));

        return;
    } END_PARSE
// --------------------------------------------------------------------------------


// Statistics
// --------------------------------------------------------------------------------
    void DumpBackends(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        out.OpenMap();
        out.Write("id", BackendsId_);
        out.OpenArray("backends");
        for (const auto& backend : tls.Backends) {
            out.OpenMap();
            backend.PrintProxyInfo(out);
            backend.PrintSuccFailRate(out);
            out.CloseMap();
        }
        out.CloseArray();
        out.CloseMap();
    }
// --------------------------------------------------------------------------------

// Functionality
// --------------------------------------------------------------------------------
    THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
        return MakeHolder<TPowerOfTwoChoices>(this, params, GetTls(params.Descr->Process()));
    }
// --------------------------------------------------------------------------------


// State
// --------------------------------------------------------------------------------
private:
    ui32 Seed_ = 0;
    size_t FailureWeight_ = 0;
    size_t BackendsId_ = 0;
// --------------------------------------------------------------------------------
};

INodeHandle<IBackends>* NPowerOfTwoChoices::Handle() {
    // Public entry point: forwards to the node handle of the pwr2 backends
    // implementation defined in this file.
    INodeHandle<IBackends>* const handle = TBackends::Handle();
    return handle;
}

}  // namespace NSrvKernel::NModBalancer
