#include "backends_factory.h"
#include "base_algorithm.h"
#include "by_location.h"
#include "weights_file.h"

#include <balancer/kernel/balancing/per_worker_backend.h>
#include <util/random/fast.h>
#include <util/random/random.h>

namespace {

using namespace NSrvKernel;

// Backend record used by the by_location balancer. On top of TPerWorkerBaseBackend it
// keeps two pieces of mutable state:
//   * an "effective" weight, which starts equal to OriginalWeight() and may be
//     overridden or reset later;
//   * an availability flag, refreshed from the result of the backend module's
//     CheckBackends() call.
class TBackend : public TPerWorkerBaseBackend {
public:
    explicit TBackend(TBackendDescriptor::TRef backend)
        : TPerWorkerBaseBackend(std::move(backend))
        , EffectiveWeight_(TPerWorkerBaseBackend::OriginalWeight())
    {}

    // --- weight -----------------------------------------------------------------

    // Stores the override verbatim; clamping happens on read in EffectiveWeight().
    void SetEffectiveWeight(double weight) noexcept {
        EffectiveWeight_ = weight;
    }

    // Drops any override and goes back to the configured weight.
    void ResetEffectiveWeight() noexcept {
        EffectiveWeight_ = OriginalWeight();
    }

    // Current weight, never reported below zero.
    double EffectiveWeight() const noexcept {
        return Max(EffectiveWeight_, 0.0);
    }

    // --- availability -----------------------------------------------------------

    // Runs the backend module's check (second argument `true`) and marks the backend
    // unavailable only when the check status is Failed; any other status counts as
    // available.
    void UpdateAvailability(IWorkerCtl& ctl) {
        const auto checkResult = Backend_->Module()->CheckBackends(ctl, true);
        Available_ = checkResult.Status != TBackendCheckResult::EStatus::Failed;
    }

    // Result of the most recent UpdateAvailability(); true until the first check.
    bool Available() const noexcept {
        return Available_;
    }

    // --- diagnostics ------------------------------------------------------------

    // Writes this backend's fields into an already opened JSON map.
    // `preferred` is supplied by the caller and is emitted as is.
    void PrintInfo(NJson::TJsonWriter& out, bool preferred) const noexcept {
        out.Write("name", Name());
        out.Write("preferred", preferred);
        out.Write("original_weight", OriginalWeight());
        out.Write("weight", EffectiveWeight());
        out.Write("available", Available());
        PrintSuccFailRate(out);
        PrintProxyInfo(out);
    }

private:
    double EffectiveWeight_;
    bool Available_ = true;
};

}

namespace NSrvKernel::NModBalancer {

// State of the by_location balancer that DoInit() builds for the IWorkerCtl it is
// given; the Update*() methods of TBackends mutate it before each algorithm is
// constructed.
BACKENDS_TLS(by_location) {
    // Reader for the external weights file; assigned in DoInit() only when
    // "weights_file" is configured.
    TExternalWeightsFileReReader WeightsFileChecker;
    // All backends, sorted by name in DoInit(). Nothing in this file resizes the
    // vector afterwards, so the raw pointers kept in PreferredBackend and
    // NamedBackends stay valid.
    TVector<TBackend> Backends;
    // Weight table passed to SelectBackend()/Round(); entry i refers to Backends[i].
    TBackendsGroupWeights Weights;
    // Random source handed to TBaseWrrAlgorithm by TByLocationAlgorithm.
    THolder<TWrrRng> Rng;
    // Backend to try first, or nullptr when there is none; points into Backends.
    TBackend* PreferredBackend = nullptr;
    // Re-reader of the "preferred_location_switch" file; when its data exists it
    // overrides the configured preferred location (see UpdatePreferences()).
    TSharedFileReReader PreferredBackendSwitch;
    // Upper bound on how many backends the algorithm may exclude up front;
    // recomputed from the quorum parameters in UpdateAvailability().
    size_t PessimizationLimit = 0;
    // Time of the last availability refresh; UpdateAvailability() refreshes only
    // when more than one second has passed since this instant.
    TInstant LastAvailabilityCheck;
    // Backend name -> pointer into Backends.
    THashMap<TString, TBackend*> NamedBackends;
};

BACKENDS_WITH_TLS(by_location), public TModuleParams {
private:

    // Selection algorithm built for one TStepParams from the current TTls state.
    //
    // On construction it excludes ("pessimizes") backends up front:
    //   1. every backend whose effective weight is zero;
    //   2. while fewer than Tls_.PessimizationLimit backends are excluded, backends
    //      that are currently unavailable, leaving the preferred backend for last.
    //
    // RemoveSelected(), IsRemoved(), SelectBackend(), Round() and Excluded_ come from
    // TBaseWrrAlgorithm, which is not defined in this file.
    class TByLocationAlgorithm : public TBaseWrrAlgorithm {
    public:
        TByLocationAlgorithm(const TStepParams& stepParams, TTls& tls) noexcept
            : TBaseWrrAlgorithm(&stepParams.Descr->Process(), tls.Rng.Get())
            , Tls_(tls)
            , PreferredBackend_(Tls_.PreferredBackend)
        {
            size_t pessimizedCount = 0;
            // first, count zero-weighted
            for (auto &backend: Tls_.Backends) {
                if (backend.EffectiveWeight() <= 0) {
                    ++pessimizedCount;
                    RemoveSelected(&backend);
                }
            }
            if (pessimizedCount < Tls_.PessimizationLimit) {
                // if still can pessimize some locations, use Available() == false,
                // but try keep PreferredBackend_ used
                for (auto &backend: Tls_.Backends) {
                    if (!backend.Available() && backend.EffectiveWeight() > 0 && &backend != PreferredBackend_) {
                        ++pessimizedCount;
                        RemoveSelected(&backend);
                        if (pessimizedCount == Tls_.PessimizationLimit) {
                            break;
                        }
                    }
                }
                // if still can pessimize something, check PreferredBackend_.
                // The weight guard mirrors the loop above: a zero-weighted preferred
                // backend was already removed by the first pass and must not be
                // handed to RemoveSelected() a second time.
                if (pessimizedCount < Tls_.PessimizationLimit
                    && PreferredBackend_
                    && PreferredBackend_->EffectiveWeight() > 0
                    && !PreferredBackend_->Available())
                {
                    RemoveSelected(PreferredBackend_);
                }
            }
        }

        // Returns the next backend to try, or nullptr when all backends are excluded.
        // The preferred backend is offered at most once (on the first call that sees
        // it) and only if it has not been removed; afterwards selection goes through
        // SelectBackend() over the current weights.
        IBackend* Next() noexcept override {
            if (PreferredBackend_) {
                TBackend* backend = PreferredBackend_;
                // Consume the preference so it is not offered again by this algorithm.
                PreferredBackend_ = nullptr;
                if (!IsRemoved(backend)) {
                    return backend;
                }
            }
            if (Excluded_.size() == Tls_.Backends.size()) {
                return nullptr;
            }

            return SelectBackend(Tls_.Weights);
        }

        // Looks a backend up by name. A removed backend is returned only when
        // `allowZeroWeights` is set; unknown names yield nullptr.
        IBackend* NextByName(TStringBuf name, bool allowZeroWeights) noexcept override {
            auto it = Tls_.NamedBackends.find(name);
            if (it != Tls_.NamedBackends.end()) {
                if (allowZeroWeights || !IsRemoved(it->second)) {
                    return it->second;
                }
            }
            return nullptr;
        }

        // Hash-driven selection.
        //  * If `hash` already carries a value, it is scaled into [0, WeightsSum] and
        //    mapped to a backend via Round(); the result may be nullptr when that
        //    backend has been removed.
        //  * Otherwise a backend is picked through SelectBackend() and the dice roll
        //    that produced it is stored back into `hash`, scaled to the TRequestHash
        //    range.
        // Returns nullptr when every backend is excluded.
        IBackend* NextByHash(IHashProcessor& hash) noexcept override {
            if (Excluded_.size() == Tls_.Backends.size()) {
                return nullptr;
            }

            if (hash.Has()) {
                TRequestHash hashValue = hash.Get();
                double hashTarget = static_cast<double>(hashValue) / MaxCeil<TRequestHash>();
                hashTarget *= Tls_.Weights.WeightsSum;
                size_t index = Round(Tls_.Weights, hashTarget);
                return Resolve(index, hashTarget);
            }

            double hashTarget = 0;
            // Wrap Resolve() so that the dice roll used for the last resolve attempt
            // is remembered in hashTarget.
            std::function<IBackend*(size_t, double)> resolve = [this, &hashTarget](size_t index, double diceRoll) {
                hashTarget = diceRoll;
                return Resolve(index, diceRoll);
            };
            IBackend* backend = SelectBackend(Tls_.Weights, resolve);
            if (backend) {
                hash.Set(ClampDoubleToUI64(hashTarget / Tls_.Weights.WeightsSum * MaxCeil<TRequestHash>()));
            }
            return backend;
        }
    private:
        // Maps a weight-table index to Tls_.Backends[index]; returns nullptr if that
        // backend has been removed. The dice-roll argument is unused here.
        IBackend* Resolve(size_t index, double) const noexcept override {
            IBackend* backend = &Tls_.Backends[index];
            if (!IsRemoved(backend)) {
                return backend;
            }
            return nullptr;
        }
        TTls& Tls_;
        // Preferred backend still to be offered by Next(); nullptr once consumed.
        TBackend* PreferredBackend_;
    };

public:
    // Builds the backends group from module parameters: initialises both bases from
    // `mp`, remembers `uid.Value` as the id reported by DumpBackends(), and then walks
    // the config with Config->ForEach(this) (presumably dispatching each key to the
    // START_PARSE handler of this class — confirm against the macro definition).
    TBackends(const TModuleParams& mp, const TBackendsUID& uid)
        : TBackendsWithTLS(mp)
        , TModuleParams(mp)
        , BackendsId_(uid.Value)
    {
        Config->ForEach(this);
    }

private:
    // Creates the TTls state for `process`:
    //   * attaches the weights-file reader when "weights_file" is configured;
    //   * seeds a fresh RNG;
    //   * materialises one TBackend per descriptor, sorted by name;
    //   * fills the weight table and the name -> backend index;
    //   * resolves the configured preferred location, if any;
    //   * attaches the preferred-location switch file when configured.
    THolder<TTls> DoInit(IWorkerCtl* process) noexcept override {
        auto state = MakeHolder<TTls>();

        if (WeightsFilename_) {
            state->WeightsFileChecker = TExternalWeightsFileReReader(*process, WeightsFilename_, TDuration::Seconds(1), NAME);
        }

        state->Rng = MakeHolder<TReallyFastRng32>(RandomNumber<ui32>());

        const size_t count = BackendDescriptors().size();
        state->Backends.reserve(count);
        state->Weights.BoundaryAndIndex.reserve(count);
        for (const auto& descriptor : BackendDescriptors()) {
            state->Backends.emplace_back(descriptor);
        }

        // Sorting happens before any pointer into Backends is taken.
        Sort(state->Backends, [](const TBackend& lhs, const TBackend& rhs) {
            return lhs.Name() < rhs.Name();
        });

        // Weight entry i and the name index both refer to the sorted position i.
        for (size_t i = 0; i < count; ++i) {
            TBackend& backend = state->Backends[i];
            state->Weights.Add(backend.EffectiveWeight(), i);
            state->NamedBackends[backend.Name()] = &backend;
        }

        // Without a configured location PreferredBackend keeps its default of nullptr.
        if (PreferredLocation_) {
            state->PreferredBackend = state->NamedBackends.Value(PreferredLocation_, nullptr);
        }

        if (PreferredLocationSwitch_) {
            state->PreferredBackendSwitch = process->SharedFiles()->FileReReader(PreferredLocationSwitch_, TDuration::Seconds(1));
        }

        return state;
    }

    // Config handler for a single key/value pair. `key` and `value` are not declared
    // here; they are presumably introduced by the START_PARSE macro — confirm against
    // its definition.
    START_PARSE {
        // Recognised option keys: each ON_KEY names the key and the member that
        // receives it (parsing itself is done inside the macro), then returns.
        ON_KEY("weights_file", WeightsFilename_) {
            return;
        }
        ON_KEY("preferred_location", PreferredLocation_) {
            return;
        }
        ON_KEY("preferred_location_switch", PreferredLocationSwitch_) {
            return;
        }

        // Any other key is registered as a backend: the key becomes the backend's
        // name and the value's sub-config is copied into its descriptor.
        Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));

        return;
    } END_PARSE

    // Serialises the group as a JSON map: {"id": ..., "backends": [ {...}, ... ]}.
    // Each backend entry carries the fields written by TBackend::PrintInfo() followed
    // by "update_weights_id" taken from the latest weights-file data.
    void DumpBackends(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        out.OpenMap();
        out.Write("id", BackendsId_);

        out.OpenArray("backends");
        for (size_t i = 0; i < tls.Backends.size(); ++i) {
            const TBackend& current = tls.Backends[i];
            // Preference is decided by pointer identity.
            const bool isPreferred = (&current == tls.PreferredBackend);

            out.OpenMap();
            current.PrintInfo(out, isPreferred);
            out.Write("update_weights_id", tls.WeightsFileChecker.LatestData().Id());
            out.CloseMap();
        }
        out.CloseArray();

        out.CloseMap();
    }

    // Writes the weights-file tag into `out` by delegating to the weights-file reader
    // held in `tls`.
    void DumpWeightsFileTags(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        tls.WeightsFileChecker.WriteWeightsFileTag(out);
    }

    // Runs CheckBackends() on every child backend module and aggregates the outcome.
    //   * Returns Skipped immediately when ShouldSkipCheck() says so.
    //   * Collects all child errors into the aggregated result.
    //   * Downgrades a Success result to Skipped when every child was skipped.
    //   * Hands the number of failed children and the total backend count to
    //     PostProcessCheckResult() before returning.
    TBackendCheckResult CheckBackends(IWorkerCtl& proc, bool runtimeCheck) noexcept override {
        using EStatus = TBackendCheckResult::EStatus;

        TBackendCheckParameters parameters = ActualCheckParameters(proc);
        if (ShouldSkipCheck(parameters)) {
            return TBackendCheckResult{EStatus::Skipped};
        }

        size_t failuresCount = 0;
        bool allSkipped = true;
        TBackendCheckResult result = VisitBackends(proc, runtimeCheck, [&allSkipped, &failuresCount](
                IWorkerCtl& worker, TBackendDescriptor::TRef descriptor, TBackendCheckResult& aggregate, bool isRuntime) {
            TBackendCheckResult child = descriptor->Module()->CheckBackends(worker, isRuntime);

            // Stays true only while every child seen so far reported Skipped.
            allSkipped = allSkipped && child.Status == EStatus::Skipped;
            if (child.Status == EStatus::Failed) {
                ++failuresCount;
            }

            for (auto& error : child.Errors) {
                aggregate.Errors.emplace_back(std::move(error));
            }
        });

        if (allSkipped && result.Status == EStatus::Success) {
            result.Status = EStatus::Skipped;
        }
        PostProcessCheckResult(result, parameters, BackendDescriptors().size(), failuresCount);
        return result;
    }

// Functionality
// --------------------------------------------------------------------------------
    // Refreshes the TTls state (weights, then preferred backend, then availability and
    // pessimization limit) and builds a TByLocationAlgorithm on top of it.
    THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
        auto& process = params.Descr->Process();
        auto& state = GetTls(process);

        UpdateWeights(*params.Descr, state);
        UpdatePreferences(state);
        UpdateAvailability(state, process);

        return MakeHolder<TByLocationAlgorithm>(params, state);
    }

    void UpdateWeights(const TConnDescr& descr, TTls& tls) noexcept {
        if (tls.WeightsFileChecker.UpdateWeights(descr)) {
            const auto& entries = tls.WeightsFileChecker.Entries();
            tls.Weights.Clear();
            size_t idx = 0;
            for (auto& backend : tls.Backends) {
                const auto entry = entries.find(backend.Name());
                if (entry != entries.end()) {
                    backend.SetEffectiveWeight(entry->second);
                } else {
                    backend.ResetEffectiveWeight();
                }
                tls.Weights.Add(backend.EffectiveWeight(), idx++);
            }
        }
    }

    // Recomputes tls.PreferredBackend.
    //
    // The wanted name is the content of the preferred-location switch file when that
    // data exists, otherwise the configured "preferred_location". An empty name clears
    // the preference. A non-empty name is resolved only when it differs from the
    // currently preferred backend (or there is none); an unknown name yields nullptr.
    //
    // Resolution goes through tls.NamedBackends — the same index DoInit() uses —
    // instead of scanning tls.Backends, so a name that matches no backend does not
    // trigger a linear scan on every call.
    void UpdatePreferences(TTls& tls) noexcept {
        const TString preferredBackend = tls.PreferredBackendSwitch.Data().Exists()
            ? tls.PreferredBackendSwitch.Data().Data()
            : PreferredLocation_;

        if (!preferredBackend) {
            tls.PreferredBackend = nullptr;
            return;
        }

        if (!tls.PreferredBackend || tls.PreferredBackend->Name() != preferredBackend) {
            tls.PreferredBackend = tls.NamedBackends.Value(preferredBackend, nullptr);
        }
    }

    // Two independent steps:
    //   1. Refresh every backend's availability flag, but only when more than one
    //      second has passed since the previous refresh.
    //   2. Recompute tls.PessimizationLimit = backends - required quorum (never below
    //      zero). The required quorum defaults to "all backends" (nothing may be
    //      switched off); AmountQuorum gives an absolute count, Quorum a fraction of
    //      the backends rounded up, and when both are set the larger one wins.
    void UpdateAvailability(TTls& tls, IWorkerCtl& ctl) noexcept {
        const bool refreshDue = Now() - tls.LastAvailabilityCheck > TDuration::Seconds(1);
        if (refreshDue) {
            for (TBackend& backend : tls.Backends) {
                backend.UpdateAvailability(ctl);
            }
            tls.LastAvailabilityCheck = Now();
        }

        auto params = ActualCheckParameters(ctl);
        const size_t total = tls.Backends.size();

        size_t quorumCount = total; // by default forbid to switch off any backends
        if (params.AmountQuorum) {
            quorumCount = *params.AmountQuorum;
        }
        if (params.Quorum) {
            const size_t fromFraction = static_cast<size_t>(std::ceil(*params.Quorum * total));
            quorumCount = params.AmountQuorum ? Max(fromFraction, quorumCount) : fromFraction;
        }

        tls.PessimizationLimit = (quorumCount >= total) ? 0 : total - quorumCount;
    }

private:
    // Value of the "weights_file" option; empty means no external weights file.
    TString WeightsFilename_;
    // Value of the "preferred_location" option: name of the backend to try first.
    TString PreferredLocation_;
    // Value of the "preferred_location_switch" option: file whose content, when
    // present, overrides PreferredLocation_ (see UpdatePreferences()).
    TString PreferredLocationSwitch_;
    // Identifier taken from TBackendsUID in the constructor; reported as "id" by
    // DumpBackends().
    size_t BackendsId_ = 0;
};

// Exposes the by_location backends implementation by forwarding to
// TBackends::Handle().
INodeHandle<IBackends>* NByLocation::Handle() {
    return TBackends::Handle();
}

}  // namespace NSrvKernel::NModBalancer
