#include "backends_factory.h"
#include "base_algorithm.h"
#include "leastconn.h"

#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>

#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h>

namespace NSrvKernel::NModBalancer {


namespace NLeastConn {

void TBackend::DoOnCompleteRequest(const TDuration&) noexcept {
    Tls_.BackendsTree.Erase(BackendNode_);
    BackendNode_.Connections_ -= 1;
    Tls_.BackendsTree.Insert(BackendNode_);
}

void TBackend::DoOnFailRequest(const TError&, const TDuration&) noexcept {
    Tls_.BackendsTree.Erase(BackendNode_);
    BackendNode_.Connections_ -= 1;
    Tls_.BackendsTree.Insert(BackendNode_);
}

TNode* TBackend::GetBackendNode() noexcept {
    return &BackendNode_;
}

void TBackend::PrintInfo(NJson::TJsonWriter& out) const noexcept {
    out.Write("connections", BackendNode_.Connections_);
    PrintSuccFailRate(out);
    PrintProxyInfo(out);
}
}

BACKENDS_WITH_TLS(leastconn), public TModuleParams {
private:
    // Per-request selection algorithm. It walks the per-TLS ordering tree from the
    // beginning and returns the first backend not yet removed for this request.
    // NOTE(review): "least connections" relies on the tree being ordered by
    // TNode::Connections_ (comparator in leastconn.h) — confirm.
    class TLeastConnAlgorithm : public TAlgorithmWithRemovals {
    public:
        TLeastConnAlgorithm(const TStepParams& params, TTls& tls) noexcept
            : TAlgorithmWithRemovals(&params.Descr->Process())
            , Tls_(tls)
        {}

        // Marks `backend` as used for this request (so Next() skips it) and counts
        // one more active connection on it.
        void RemoveSelected(IBackend* backend) noexcept override {
            TAlgorithmWithRemovals::RemoveSelected(backend);
            IncConns(backend);
        }

        // Returns the first backend, in tree order, that has not been removed.
        // Returns nullptr when every backend has been removed.
        IBackend* Next() noexcept override {
            for (auto it = Tls_.BackendsTree.Begin(); it != Tls_.BackendsTree.End(); ++it) {
                if (!IsRemoved(it->Backend_)) {
                    return it->Backend_;
                }
            }
            return nullptr;
        }

        // Looks up a backend by its configured name. Returns nullptr when no backend
        // has that name. Removal state and weights are not consulted here.
        IBackend* NextByName(TStringBuf name, bool /*allowZeroWeights*/) noexcept override {
            auto it = Tls_.NamedBackends.find(name);
            if (it != Tls_.NamedBackends.end()) {
                return it->second;
            }

            return nullptr;
        }

    private:
        // Adds one active connection and one usage to `backend`, then re-positions
        // its node so the tree stays ordered.
        void IncConns(IBackend* backend) noexcept {
            // static_cast, not reinterpret_cast: a base-to-derived conversion may need
            // a pointer adjustment (when IBackend is not at offset 0 in TBackend),
            // and reinterpret_cast silently skips it. Every backend this algorithm
            // hands out points into Tls_.Backends, so the downcast is expected to be
            // valid.
            auto* back = static_cast<NLeastConn::TBackend*>(backend);
            auto* node = back->GetBackendNode();
            // Unlink before changing the counters the ordering presumably uses, then
            // re-insert so the node lands at its new position.
            Tls_.BackendsTree.Erase(*node);
            node->Connections_ += 1;
            node->Usage_ += 1;
            Tls_.BackendsTree.Insert(*node);
        }

    private:
        TTls& Tls_;
    };

// Initialization
// --------------------------------------------------------------------------------
public:
    // Parses the module config; every entry is handled by the START_PARSE block below.
    TBackends(const TModuleParams& mp, const TBackendsUID& uid)
        : TBackendsWithTLS(mp)
        , TModuleParams(mp)
        , BackendsId_(uid.Value)
    {
        Config->ForEach(this);
    }

private:
    // Builds a fresh TTls: one TBackend per configured descriptor, each indexed by
    // name and inserted into the ordering tree. The IWorkerCtl argument is unused.
    THolder<TTls> DoInit(IWorkerCtl*) noexcept override {
        auto tls = MakeHolder<TTls>();
        for (auto& i : BackendDescriptors()) {
            tls->Backends.emplace_back(i, *tls.Get());
        }
        // Addresses are taken only after all elements exist, so growth of
        // tls->Backends in the loop above cannot invalidate them.
        // NOTE(review): if TBackend's constructor stores `this` (e.g. in its TNode),
        // tls->Backends must not relocate elements on emplace_back — confirm its type.
        for (auto& i : tls->Backends) {
            tls->NamedBackends[i.Name()] = &i;
            tls->BackendsTree.Insert(i.GetBackendNode());
        }
        return tls;
    }

    // Hash-based policies are not supported by this algorithm: warn once and ignore.
    void ProcessPolicyFeatures(const TPolicyFeatures& features) override {
        if (features.WantsHash) {
            PrintOnce("WARNING in balancer2/leastconn: "
                      "policies with hash are not supported and will be ignored");
        }
    }

    // Each config entry becomes a backend descriptor built from its sub-config and
    // its key (presumably the backend name used by NextByName).
    START_PARSE {
        Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));

        return;
    } END_PARSE

// --------------------------------------------------------------------------------


    // Dumps {"id": <uid>, "backends": [<per-backend info>...]} for the given TLS.
    void DumpBackends(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        out.OpenMap();
        out.Write("id", BackendsId_);
        out.OpenArray("backends");
        for (const auto& backend : tls.Backends) {
            out.OpenMap();
            backend.PrintInfo(out);
            out.CloseMap();
        }
        out.CloseArray();
        out.CloseMap();
    }
// --------------------------------------------------------------------------------


// Functionality
// --------------------------------------------------------------------------------
    // Creates a selection algorithm bound to the TLS of the requesting process.
    THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
        return MakeHolder<TLeastConnAlgorithm>(params, GetTls(params.Descr->Process()));
    }
// --------------------------------------------------------------------------------


// State
// --------------------------------------------------------------------------------
private:
    size_t BackendsId_ = 0;  // Reported as "id" by DumpBackends.
// --------------------------------------------------------------------------------
};

// Returns the node handle of the leastconn backends module defined above, which is
// presumably generated under the Nleastconn namespace by BACKENDS_WITH_TLS(leastconn).
INodeHandle<IBackends>* NLeastConn::Handle() {
    INodeHandle<IBackends>* const handle = Nleastconn::TBackends::Handle();
    return handle;
}

}  // namespace NSrvKernel::NModBalancer
