#include "wrr.h"
#include "hashing.h"

#include <balancer/kernel/balancer/backends.h>
#include <balancer/kernel/balancing/config.cfgproto.pb.h>
#include <balancer/kernel/balancing/updater.h>
#include <balancer/kernel/helpers/misc.h>
#include <balancer/kernel/pinger/pinger.h>
#include <balancer/kernel/ctl/children_process_common.h>
#include <balancer/modules/balancer/backends_factory.h>
#include <balancer/modules/balancer/base_algorithm.h>
#include <balancer/modules/proxy/module.h>

#include <library/cpp/proto_config/config.h>

#include <util/random/fast.h>
#include <util/random/random.h>
#include <util/stream/file.h>
#include <util/system/fs.h>

#include <contrib/libs/xxhash/xxhash.h>

#include <cmath>

namespace NSrvKernel {

namespace {
    using namespace NDynamicBalancing;

    // Statistics stub for the dynamic backends module. Both constructors
    // deliberately ignore their arguments:
    //  * the first is used once per TDynamicBackends instance with the shared
    //    stats manager;
    //  * the second derives a per-worker copy when a TDynamicTls is created
    //    (the size_t is the worker id).
    // The single-argument constructor is explicit so a TSharedStatsManager is
    // never silently converted into a TStats.
    struct TStats {
        explicit TStats(TSharedStatsManager&) {}
        TStats(const TStats&, size_t) {}
    };

    // Per-worker state (the module's TLS object) of a dynamic backends group.
    struct TDynamicTls {
        // @param stats   module-level stats from which the per-worker stats are derived.
        // @param process worker owning this state; its WorkerId() is forwarded to TStats.
        TDynamicTls(const TStats& stats, IWorkerCtl* process)
            : Stats(stats, process->WorkerId())
            , Rng(MakeSimpleShared<TWrrRng>(RandomNumber<ui32>()))
        {
        }

        // Per-worker statistics.
        TStats Stats;
        // Randomly seeded generator handed to every algorithm constructed on
        // this worker (see ConstructAlgorithm).
        TSimpleSharedPtr<TWrrRng> Rng;
        // Updater coroutine; it is only started on the updater worker (see DoInit).
        TCoroutine UpdaterCoroutine_;
    };

    constexpr ui64 FIFTY_THREE_ONES = (0xFFFFFFFFFFFFFFFF >> (64 - 53));
    constexpr double FIFTY_THREE_ZEROS = 1L << 53;

    constexpr double Ui64ToDouble(const ui64 x) noexcept {
        return (x & FIFTY_THREE_ONES) / FIFTY_THREE_ZEROS;
    }

    // Weighted rendezvous-hashing score of `backend` for one request.
    //
    // The backend name is hashed with XXH64, seeded with the request hash, and
    // mapped to u in [0, 1) by Ui64ToDouble. The score is -weight / ln(u);
    // ChooseBackend keeps the highest score. This makes the winner a
    // deterministic function of (request hash, backend names, weights), and
    // each backend wins with probability proportional to its weight.
    //
    // The logarithm must be taken in double precision. Narrowing u to float
    // (as `logf` did) rounds values within about 2^-25 of 1 up to exactly 1.0f.
    // ln(u) then becomes 0, and the score becomes -inf (or NaN when weight == 0)
    // instead of a very large value. In double, u < 1 always holds, so ln(u) < 0.
    // u == 0 yields ln(u) == -inf and a score of 0, the lowest possible priority.
    double ComputeWeightedScore(TRequestHash reqHash, const TDynamicBackend* backend, double weight) noexcept {
        const double uniform = Ui64ToDouble(XXH64(backend->Name().data(), backend->Name().size(), reqHash));
        return -weight / std::log(uniform);
    }

    // True when the backend's Selections counter is strictly greater than its
    // Expectations counter scaled by `maxSkew`.
    bool IsSkewed(const TDynamicBackend& backend, double maxSkew) noexcept {
        const auto allowedSelections = backend.Expectations * maxSkew;
        return allowedSelections < backend.Selections;
    }

    // Per-request backend selection algorithm for dynamic backends.
    //
    // Works on a TBalancingState handed in by the owning module (acquired from
    // its backend group updater).
    //  * Without a request hash, selection is delegated to
    //    TBaseWrrAlgorithm::SelectBackend over all weights.
    //  * With a hash, the snapshot entry with the highest ComputeWeightedScore
    //    wins. A backend is then drawn from that entry's group, or the entry
    //    itself is used when it has no group.
    //
    // "Skewed" backends (IsSkewed: Selections > Expectations * maxSkew) are
    // avoided first. Next() falls back to them only when nothing else is left.
    class TAlgorithm : public TBaseWrrAlgorithm {
    private:
        // Common constructor.
        // @param maxSkew already shifted by +1 by the public constructors. Values
        //        below 1 disable skew tracking.
        TAlgorithm(IWorkerCtl* process, TIntrusivePtr<TBalancingState> state, TSimpleSharedPtr<TWrrRng> rng, TMaybe<TRequestHash> hash, double maxSkew)
            // The base is initialized before any member, so `rng` has not been
            // moved into Rng_ yet when rng.Get() is evaluated.
            : TBaseWrrAlgorithm(process, rng.Get())
            , RequestHash_(std::move(hash))
            , BalancingState_(std::move(state))
            , Rng_(std::move(rng))
        {
            Y_VERIFY(BalancingState_);
            if (maxSkew >= 1 || !BalancingState_->GroupsWeights().empty()) {
                const auto &snapshots = BalancingState_->StateSnapshots();
                for (const auto &snapshot: snapshots) {
                    // Skewed_ is updated before RemoveSelected() below, because
                    // RemoveSelected() consults it to maintain RemovedInGroup_.
                    if (maxSkew >= 1 && IsSkewed(*static_cast<TDynamicBackend *>(snapshot.first), maxSkew)) {
                        Skewed_.insert(snapshot.first);
                        SkewedInGroup_[snapshot.second.GroupName]++;
                    }
                    // Zero-weight backends are excluded up front.
                    // NOTE(review): this only happens inside the enclosing `if`,
                    // i.e. not when skew tracking is off AND no groups are
                    // configured. Confirm this is intentional.
                    if (snapshot.second.Weight == 0) {
                        RemoveSelected(snapshot.first);
                    }
                }
            }
        }
    public:
        // Builds the algorithm for a new request. The request hash is used only
        // when `useHashing` is set and the request carries a hash. `maxSkew` is
        // the raw config value; it is shifted by +1 before use.
        TAlgorithm(const TStepParams& params, TIntrusivePtr<TBalancingState> state, TSimpleSharedPtr<TWrrRng> rng, bool useHashing, double maxSkew)
            : TAlgorithm(
                params.Descr ? &params.Descr->Process() : nullptr,
                std::move(state),
                std::move(rng),
                (useHashing && params.Hash) ? TMaybe<TRequestHash>{*params.Hash} : TMaybe<TRequestHash>{},
                maxSkew + 1
            )
        {
        }

        // Rebuilds `source` on top of a newer balancing state. The process, RNG
        // and request hash are kept. Every backend `source` had excluded that
        // still exists in `state` is excluded again.
        TAlgorithm(const TAlgorithm* source, TIntrusivePtr<TBalancingState> state, double maxSkew)
            : TAlgorithm(source->Process_, std::move(state), source->Rng_, source->RequestHash_, maxSkew + 1)
        {
            const auto& snapshots = BalancingState_->StateSnapshots();
            for (auto backend : source->Excluded_) {
                if (snapshots.contains(backend)) {
                    RemoveSelected(backend);
                }
            }
        }

        IBackend* SelectNext() noexcept override {
            return Next();
        }

        // Picks a backend, first avoiding skewed ones (if any exist), then
        // ignoring skew. The returned backend is recorded via Track().
        IBackend* Next() noexcept override {
            if (!Skewed_.empty()) {
                IBackend* backend = ChooseBackend(true);
                if (backend) {
                    return Track(backend);
                }
            }
            return Track(ChooseBackend(false));
        }

        // Picks a backend from the group `name`. Returns nullptr when the group
        // is unknown or exhausted.
        // NOTE(review): unlike Next(), there is no fallback that ignores skew.
        // When every non-removed backend of the group is skewed, this returns
        // nullptr. Confirm this asymmetry is intended.
        IBackend* NextByName(TStringBuf name, bool) noexcept override {
            const TBackendsGroupWeights* group = BalancingState_->GroupsWeights().FindPtr(name);
            if (group && !IsExhaustedGroup(*group, name, !SkewedInGroup_.empty())) {
                return Track(ChooseBackendFromGroup(*group, !Skewed_.empty()));
            }
            return nullptr;
        }

    private:
        // Core selection; with `checkSkew`, skewed backends count as unavailable.
        // Without a request hash: weighted selection over all backends, or
        // nullptr once every backend is excluded.
        // With a hash: weighted rendezvous hashing over the snapshot entries
        // whose group is not exhausted; the winner's group is then sampled.
        IBackend* ChooseBackend(bool checkSkew) noexcept {
            if (!RequestHash_) {
                if (Excluded_.size() == BalancingState_->Backends().size()) {
                    return nullptr;
                }
                return ChooseBackendFromGroup(BalancingState_->AllWeights(), checkSkew);
            }
            TMaybe<double> maxScore;
            const auto& snapshots = BalancingState_->StateSnapshots();
            auto result = snapshots.end();

            for (auto snapshot = snapshots.begin(); snapshot != snapshots.end(); ++snapshot) {
                if (IsExhaustedGroup(*snapshot, checkSkew)) {
                    continue;
                }

                TDynamicBackend* backend = static_cast<TDynamicBackend*>(snapshot->first);
                double score = ComputeWeightedScore(*RequestHash_, backend, snapshot->second.Weight);
                if (!maxScore.Defined() || score > *maxScore) {
                    maxScore = score;
                    result = snapshot;
                }
            }

            if (result == snapshots.end()) {
                return nullptr;
            }

            return ChooseBackendFromGroup(*result, checkSkew);
        }

        // Records a selection. Round() draws an index over all weights and that
        // backend gets AddExpectation(); the backend actually chosen gets
        // AddSelection(). These presumably drive the Expectations/Selections
        // counters read by IsSkewed().
        // NOTE(review): Round() presumably consumes the shared RNG, so tracking
        // affects later random draws.
        IBackend* Track(IBackend* backend) {
            if (backend) {
                size_t index = Round(BalancingState_->AllWeights());
                BalancingState_->Backends()[index]->AddExpectation();
                TDynamicBackend* dyn = static_cast<TDynamicBackend*>(backend);
                dyn->AddSelection();
            }
            return backend;
        }

        // Excludes `backend`. When groups are configured, it is also counted
        // once in RemovedInGroup_ for its group:
        //   .first  = removed backends;
        //   .second = removed backends that are also skewed.
        void RemoveSelected(IBackend* backend) noexcept override {
            if (!BalancingState_->GroupsWeights().empty() && !IsRemoved(backend)) {
                if (auto *snapshot = BalancingState_->StateSnapshots().FindPtr(backend)) {
                    auto& p = RemovedInGroup_[snapshot->GroupName];
                    p.first++;
                    if (Skewed_.contains(backend)) {
                        p.second++;
                    }
                }
            }
            TBaseWrrAlgorithm::RemoveSelected(backend);
        }
        // Clears the per-group removal counters together with the base state.
        // Skew bookkeeping (Skewed_, SkewedInGroup_) is kept.
        // NOTE(review): if the base Reset() clears exclusions, zero-weight
        // backends removed in the constructor become eligible again. Confirm.
        void Reset() noexcept override {
            RemovedInGroup_.clear();
            TBaseWrrAlgorithm::Reset();
        }

        // A backend is unavailable if it was removed or, with `checkSkew`, if
        // it is skewed.
        bool IsExhaustedBackend(IBackend* backend, bool checkSkew) const noexcept {
            if (IsRemoved(backend)) {
                return true;
            }
            if (checkSkew) {
                return Skewed_.contains(backend);
            }
            return false;
        }

        // Base-class resolution hook: maps an index to a backend, ignoring skew.
        IBackend* Resolve(size_t index, double) const noexcept override {
            return Resolve(index, false);
        }

        // Maps an index into Backends() to a backend, or nullptr when that
        // backend is unavailable (see IsExhaustedBackend).
        IBackend* Resolve(size_t index, bool checkSkew) const noexcept {
            IBackend* backend = BalancingState_->Backends()[index].Get();
            if (!IsExhaustedBackend(backend, checkSkew)) {
                return backend;
            }
            return nullptr;
        }

        // Weighted selection inside `group` via SelectBackend. With `checkSkew`,
        // a resolver rejecting skewed backends is supplied. Otherwise an empty
        // std::function is passed (presumably meaning SelectBackend's default
        // resolution; confirm).
        IBackend* ChooseBackendFromGroup(const TBackendsGroupWeights& group, bool checkSkew) noexcept {
            std::function<IBackend*(size_t, double)> resolve;
            if (checkSkew) {
                resolve = [this](size_t index, double){
                    return Resolve(index, true);
                };
            }
            return SelectBackend(group, resolve);
        }

        // Samples the snapshot entry's group when it has one. If that yields
        // nothing, or there is no group, the entry's own backend is returned
        // when it is still available.
        IBackend* ChooseBackendFromGroup(const std::pair<IBackend* const, TBackendStateSnapshot>& snapshot, bool checkSkew) noexcept {
            const TBackendsGroupWeights* group = BalancingState_->GroupsWeights().FindPtr(snapshot.second.GroupName);
            if (group) {
                if (IBackend* backend = ChooseBackendFromGroup(*group, checkSkew)) {
                    return backend;
                }
            }
            if (!IsExhaustedBackend(snapshot.first, checkSkew)) {
                return snapshot.first;
            }
            return nullptr;
        }

        // Exhaustion check for a snapshot entry: uses its group when the group
        // is known, otherwise the entry's backend alone.
        bool IsExhaustedGroup(const std::pair<IBackend* const, TBackendStateSnapshot>& snapshot, bool checkSkew) const noexcept {
            const TBackendsGroupWeights* group = BalancingState_->GroupsWeights().FindPtr(snapshot.second.GroupName);
            if (group) {
                return IsExhaustedGroup(*group, snapshot.second.GroupName, checkSkew);
            }
            return IsExhaustedBackend(snapshot.first, checkSkew);
        }
        // A group is exhausted when every backend in it was removed. With
        // `checkSkew`, it is also exhausted when every backend is either skewed
        // or removed: skewed count + removed-but-not-skewed count covers the
        // whole group.
        bool IsExhaustedGroup(const TBackendsGroupWeights& group, TStringBuf name, bool checkSkew) const noexcept {
            const size_t total = group.BoundaryAndIndex.size();
            size_t removedNotSkewed = 0;
            if (auto* removed = RemovedInGroup_.FindPtr(name)) {
                if (total <= removed->first) {
                    return true;
                }
                removedNotSkewed = removed->first - removed->second;
            }
            if (checkSkew) {
                return total <= SkewedInGroup_.Value(name, 0) + removedNotSkewed;
            }
            return false;
        }

        // Present only for hashing balancers whose request carries a hash.
        TMaybe<TRequestHash> RequestHash_;
        TIntrusivePtr<TBalancingState> BalancingState_;
        // group name -> (removed backends, removed backends that are skewed)
        THashMap<TString, std::pair<size_t, size_t>> RemovedInGroup_;
        // group name -> number of skewed backends
        THashMap<TString, size_t> SkewedInGroup_;
        // Shared ownership of the RNG whose raw pointer was given to the base.
        TSimpleSharedPtr<TWrrRng> Rng_;
        THashSet<IBackend*> Skewed_;
    };

    // Backends module whose backend set and weights are driven by a
    // TBackendGroupUpdater registered through DynamicBalancingUpdaterManager().
    //
    // @tparam NAME    module name ("dynamic" / "dynamic_hashing").
    // @tparam HASHING when true, algorithms use the request hash (weighted
    //                 rendezvous hashing); otherwise the hash is ignored.
    template <const char* NAME, bool HASHING>
    class TDynamicBackends :
        public TBackendsWithTLS<TDynamicBackends<NAME, HASHING>, TDynamicTls, NAME>,
        public TModuleParams,
        public NConfig::IConfig::IFunc
    {
    private:
        using TBackendsWithTlsBase = TBackendsWithTLS<TDynamicBackends<NAME, HASHING>, TDynamicTls, NAME>;

        // Common constructor:
        //  * parses the config, routing unknown keys (the backends) to `cb`;
        //  * checks that every backend is a supported module type;
        //  * registers the backends with the updater manager.
        // `updater` is nullptr for a freshly configured module and the previous
        // instance's updater when called from Update().
        // @throws TConfigParseError if a backend module type is unsupported.
        TDynamicBackends(const TModuleParams& mp, const TBackendsUID& uid, NProtoConfig::TStackUnknownFieldCb cb, TBackendGroupUpdater* updater)
            : TBackendsWithTlsBase(mp)
            , TModuleParams(mp)
            , Uid_(uid)
            , Stats_(Control->SharedStatsManager())
        {
            NProtoConfig::ParseConfig(*Config, Config_, cb);

            for (const auto& backend : TBackendsWithTlsBase::BackendDescriptors()) {
                // Fetched once instead of once per suffix check.
                const auto& typeName = backend->Module()->TypeName();
                const bool supported =
                    typeName.EndsWith("Nproxy::TModule") ||
                    typeName.EndsWith("Napp_host_backend::TModule") ||
                    typeName.EndsWith("Nerrordocument::TModule") ||
                    typeName.EndsWith("Ndummy::TModule");
                if (!supported) {
                    ythrow TConfigParseError{} << "one of dynamic backends is not a proxy or app_host_backend or errordocument or dummy module";
                }
            }

            BackendGroupUpdater_ = Control->DynamicBalancingUpdaterManager().RegisterBackends(
                Uid_.Value,
                TBackendsWithTlsBase::BackendDescriptors(),
                Config_,
                updater
            );
        }

        // Reload constructor: a backend whose key matches an entry of `backends`
        // reuses that descriptor; any other key gets a new descriptor built from
        // its sub-config. The callback only runs synchronously inside
        // ParseConfig, so capturing `backends` by reference is safe.
        TDynamicBackends(
            const TModuleParams& mp, const TBackendsUID& uid, const THashMap<TString,
            TBackendDescriptor::TRef>& backends, TBackendGroupUpdater* updater
        ) : TDynamicBackends(mp, uid, [this, &backends](const NProtoConfig::TKeyStack&, const TString& key, NConfig::IConfig::IValue* value) {
            auto it = backends.find(key);
            if (it != backends.end()) {
                TBackendsWithTlsBase::Add(it->second);
            } else {
                TBackendsWithTlsBase::Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));
            }
        }, updater)
        {
        }
    public:
        // Initial constructor: each unknown config key becomes a new backend
        // descriptor, and no updater is reused.
        TDynamicBackends(const TModuleParams& mp, const TBackendsUID& uid)
            : TDynamicBackends(mp, uid, [this](const NProtoConfig::TKeyStack&, const TString& key, NConfig::IConfig::IValue* value) {
            TBackendsWithTlsBase::Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));
        }, nullptr)
        {
        }

        bool CanUpdate() const noexcept override {
            return true;
        }

        // Builds the replacement module for a config update. Existing
        // descriptors are reused by name, and the UID, updater and check
        // parameters are carried over.
        THolder<IBackends> Update(const TModuleParams& mp) const override {
            THashMap<TString, TBackendDescriptor::TRef> backends;
            for (const auto& backend : TBackendsWithTlsBase::BackendDescriptors()) {
                backends.emplace(backend->Name(), backend);
            }
            // Raw `new`: the reload constructor is private, so MakeHolder cannot reach it.
            THolder<IBackends> updated(new TDynamicBackends(mp, Uid_, backends, BackendGroupUpdater_.Get()));
            updated->SetCheckParameters(TBackendsWithTlsBase::CheckParameters_);
            return updated;
        }

        // Re-targets an in-flight algorithm to the updater's current state.
        // Returns nullptr when that state no longer contains a backend in use.
        THolder<IAlgorithm> InvalidateAlgorithm(IAlgorithm& algorithm, const THashSet<IBackend*>& backendsInUse) noexcept override {
            auto state = BackendGroupUpdater_->AcquireState();
            const auto& snapshots = state->StateSnapshots();
            for (const auto backend : backendsInUse) {
                if (!snapshots.contains(backend)) {
                    return nullptr; // can not invalidate: new state doesn't contain backend in use
                }
            }
            TAlgorithm* source = dynamic_cast<TAlgorithm*>(&algorithm);
            Y_VERIFY(source);
            // `snapshots` is no longer used, so the state pointer can be handed over.
            return MakeHolder<TAlgorithm>(source, std::move(state), Config_.max_skew());
        }

        // Creates the worker's TLS. The updater coroutine is started only on
        // the updater worker.
        THolder<TDynamicTls> DoInit(IWorkerCtl* process) noexcept override {
            auto type = process->WorkerType();

            auto tls = MakeHolder<TDynamicTls>(Stats_, process);

            if (type == NProcessCore::TChildProcessType::Updater) {
                tls->UpdaterCoroutine_ = BackendGroupUpdater_->Start(process);
            }

            return tls;
        }

        // Stops the updater on the updater worker (counterpart of DoInit).
        void DoDispose(IWorkerCtl* process, TDynamicTls&) noexcept override {
            if (process->WorkerType() == NProcessCore::TChildProcessType::Updater) {
                BackendGroupUpdater_->Stop();
            }
        }

        // New per-request algorithm over the current state. The RNG comes from
        // the worker's TLS when a step descriptor is available; otherwise it
        // stays empty.
        THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
            TSimpleSharedPtr<TWrrRng> rng;
            if (params.Descr) {
                rng = TBackendsWithTlsBase::GetTls(params.Descr->Process()).Rng;
            }
            return MakeHolder<TAlgorithm>(params, BackendGroupUpdater_->AcquireState(), rng, HASHING, Config_.max_skew());
        }

        void DumpBackends(NJson::TJsonWriter& out, const TDynamicTls&) const noexcept override {
            out.OpenMap();
            out.Write("id", Uid_.Value);
            BackendGroupUpdater_->AcquireState()->Dump(out);
            out.CloseMap();
        }

        // Health check over every non-blacklisted backend. A backend is
        // reported as failing when:
        //  * service discovery marks it not ready; or
        //  * in runtime mode, it is disabled or has a non-positive ping weight; or
        //  * otherwise, an active check (optionally using the pinger's request)
        //    fails.
        TBackendCheckResult CheckBackends(IWorkerCtl& proc, bool runtimeCheck) noexcept override {
            TBackendCheckParameters parameters = TBackendsWithTlsBase::ActualCheckParameters(proc);
            if (TBackendsWithTlsBase::ShouldSkipCheck(parameters)) {
                return TBackendCheckResult{TBackendCheckResult::EStatus::Skipped};
            }
            TIntrusivePtr<TNotReadyBackendSet> notReady = BackendGroupUpdater_->GetNotReadySet();
            TIntrusivePtr<TBlacklist> blacklist = BackendGroupUpdater_->GetCurrentBlacklist();
            TMaybe<TRequest> request;
            if (TMaybe<TPingerConfig> config = BackendGroupUpdater_->PingerConfig()) {
                request = config->PingRequest;
            }
            size_t checkedBackendsCount = 0;
            // The visitor's parameters are named distinctly from this function's
            // `proc`, `runtimeCheck` and `result`, so they no longer shadow them.
            TBackendCheckResult result = TBackendsWithTlsBase::VisitBackends(
                proc, runtimeCheck,
                [this, &notReady, &blacklist, &request, &checkedBackendsCount](
                        IWorkerCtl& workerProc, TBackendDescriptor::TRef backend, TBackendCheckResult& backendResult, bool isRuntimeCheck) {
                    if (!blacklist || !blacklist->Contains(backend->Name())) {
                        ++checkedBackendsCount;
                        if (notReady && notReady->contains(backend->Name())) {
                            backendResult.Errors.emplace_back(Y_MAKE_ERROR(yexception{} << "SD claims this backend is not ready"));
                        } else if (isRuntimeCheck) {
                            if (!backend->Enabled() || backend->WeightFromPing() <= 0) {
                                backendResult.Errors.emplace_back(Y_MAKE_ERROR(yexception{} << "this backend is unhealthy"));
                            }
                            // Also, we can treat backends that marked themselves degraded (backend->Degraded()) as ill,
                            // but for now let's count them as available
                        } else if (TError err = TBackendsWithTlsBase::CheckBackend(workerProc, backend, request)) {
                            backendResult.Errors.emplace_back(std::move(err));
                        }
                    }
                });
            TBackendsWithTlsBase::PostProcessCheckResult(result, parameters, checkedBackendsCount, result.Errors.size());
            return result;
        }

        void DumpBalancingState(NJson::TJsonWriter& out, const TDynamicTls&) const noexcept override {
            out.OpenMap();
            BackendGroupUpdater_->DumpBalancingState(out);
            out.CloseMap();
        }

        void ChangeNotReadySet(TIntrusivePtr<TNotReadyBackendSet> set) const noexcept override {
            BackendGroupUpdater_->ChangeNotReadySet(std::move(set));
        }

    private:
        NSrvKernel::NDynamicBalancing::TConfig Config_;
        // NOTE(review): reference member. Update() passes it on to the next
        // instance, so the referenced TBackendsUID must outlive every generation
        // of this module. Confirm the owner guarantees that.
        const TBackendsUID& Uid_;
        THolder<NSrvKernel::NDynamicBalancing::TBackendGroupUpdater> BackendGroupUpdater_;
        TStats Stats_;
    };
}


namespace NModBalancer::NDynamic {
    // Module names. They are non-type template arguments of TDynamicBackends,
    // hence namespace-scope arrays with static storage duration.
    constexpr char DYNAMIC[] = "dynamic";
    constexpr char DYNAMIC_HASHING[] = "dynamic_hashing";

    // Node handle of the "dynamic" module; the request hash is not used for selection.
    INodeHandle<IBackends>* HandleWrr() {
        using TWrrBackends = TDynamicBackends<DYNAMIC, false>;
        return TWrrBackends::Handle();
    }

    // Node handle of the "dynamic_hashing" module; selection uses the request hash when present.
    INodeHandle<IBackends>* HandleHashing() {
        using THashingBackends = TDynamicBackends<DYNAMIC_HASHING, true>;
        return THashingBackends::Handle();
    }
}

} // namespace NSrvKernel
