#include "backends_factory.h"
#include "base_algorithm.h"
#include "weights_file.h"
#include "rrobin.h"

#include <balancer/kernel/balancing/per_worker_backend.h>
#include <balancer/kernel/fs/kv_file_consumer.h>
#include <balancer/kernel/fs/shared_file_rereader.h>
#include <balancer/kernel/fs/shared_files.h>
#include <balancer/kernel/helpers/cast.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/module/iface.h>

#include <library/cpp/containers/intrusive_avl_tree/avltree.h>

#include <algorithm>
#include <exception>
#include <limits>
#include <utility>

namespace {

// Comparator adaptor used by TAvlTree: the ordering is not defined here but
// delegated to the item type's own static Compare(lhs, rhs).
struct TBackendCompare {
    template <typename TItem>
    static bool Compare(const TItem& lhs, const TItem& rhs) noexcept {
        const bool lhsGoesFirst = TItem::Compare(lhs, rhs);
        return lhsGoesFirst;
    }
};

// Thrown by TBackend when it is asked to apply a weight it cannot represent.
struct TBackendException: public yexception {
};

using namespace NSrvKernel;

class TBackend
    : public TPerWorkerBaseBackend
    , public TAvlTreeItem<TBackend, TBackendCompare>
    , public TIntrusiveListItem<TBackend>
{
public:
    // All weight state is initialized from OriginalWeight() through SetWeight();
    // a configured weight that UpdateWeight() rejects makes construction throw
    // TBackendException.
    explicit TBackend(TBackendDescriptor::TRef descr)
        : TPerWorkerBaseBackend(std::move(descr))
    {
        SetWeight(OriginalWeight());
    }

    // Tree order: by current Weight(); ties are broken by address so that two
    // distinct backends never compare equal.
    static bool Compare(const TBackend& l, const TBackend& r) noexcept {
        return (l.Weight() < r.Weight()) || ((l.Weight() == r.Weight()) && (&l < &r));
    }

    // Amount added to Weight() per usage: 1 / StaticWeight(), or 0 when disabled.
    double WeightStep() const noexcept {
        return WeightStep_;
    }

    // Charges countOfUsages usages to this backend. Returns whether Weight()
    // actually changed (it does not for disabled backends, whose step is 0, or
    // when the increment is lost to rounding).
    bool IncUsageCount(ui32 countOfUsages = 1) noexcept {
        const double prevWeight = Weight_;
        Weight_ += countOfUsages * WeightStep();
        return Abs(Weight_ - prevWeight) > std::numeric_limits<double>::epsilon();
    }

    // Request outcomes do not influence round-robin weights.
    void DoOnCompleteRequest(const TDuration&) noexcept override {}

    void DoOnFailRequest(const TError&, const TDuration&) noexcept override {}

    // Current (usage-accumulated) weight; this is the AVL tree key.
    double Weight() const noexcept {
        return Weight_;
    }

    // Weight currently applied to the backend; values that are not > 0 mean disabled.
    double StaticWeight() const noexcept {
        return StaticWeight_;
    }

    // Applies val and remembers it as the file weight.
    // Throws TBackendException if UpdateWeight() rejects val; in that case the
    // backend is left completely unchanged (FileWeight_ included).
    void SetWeight(double val) {
        UpdateWeight(val);
        FileWeight_ = val;
    }

    // Resets both the applied and the file weight to OriginalWeight().
    void RestoreWeight() noexcept {
        UpdateWeight(OriginalWeight());
        FileWeight_ = OriginalWeight();
    }

    // Re-applies the last weight accepted by SetWeight()/RestoreWeight().
    // FileWeight_ only ever holds accepted weights, so this does not throw.
    void RestoreFileWeight() noexcept {
        UpdateWeight(FileWeight_);
    }

    void PrintInfo(NJson::TJsonWriter& out) const noexcept {
        out.Write("original_weight", OriginalWeight());
        out.Write("weight", Weight_);
        out.Write("static_weight", StaticWeight_);
        PrintSuccFailRate(out);
        PrintProxyInfo(out);
    }

private:
    // Applies val with the strong exception guarantee: the new step is
    // validated before any member is modified.
    // Positive val: Weight_ and WeightStep_ become 1 / val. A step below
    // double epsilon (val > ~4.5e15, or +inf) is rejected with TBackendException.
    // Any other val (0, negative, NaN) disables the backend: Weight_ is set to
    // the maximum double and the step to 0, so usage never moves it.
    void UpdateWeight(double val) {
        if (val > 0) {
            const double step = 1.0 / val;
            if (step < std::numeric_limits<double>::epsilon()) {
                ythrow TBackendException() << "bad new weight " << val << Endl; // TODO: fix me
            }
            StaticWeight_ = val;
            Weight_ = WeightStep_ = step;
        } else { // disabling backend
            StaticWeight_ = val;
            Weight_ = Max<double>();
            WeightStep_ = 0.0;
        }
    }

private:
    double FileWeight_ = 0.;   // last weight accepted via SetWeight()/RestoreWeight()
    double StaticWeight_ = 0.; // weight currently applied
    double Weight_ = 0.;       // accumulated weight, AVL tree key
    double WeightStep_ = 0.;   // per-usage increment of Weight_ (0 when disabled)
};

}  // namespace

namespace NSrvKernel::NModBalancer {

BACKENDS_TLS(rr) {
    // Rebuilds HashedBackends. Backends with a positive static weight are taken
    // in (sorted) Names order. Each one stores its cumulative weight divided by
    // the grand total, i.e. the right edge of its share of [0, 1].
    void FillHashedWeights() noexcept {
        Y_ASSERT(IsSorted(Names.begin(), Names.end()));

        HashedBackends.clear();

        double runningSum = 0.0;
        for (const TString& backendName : Names) {
            TBackend* const candidate = NamedBackends[backendName];
            Y_ASSERT(candidate);
            const double share = candidate->StaticWeight();
            if (!(share > 0.0)) {
                continue; // backends without a positive weight are skipped
            }
            runningSum += share;
            HashedBackends.emplace_back(runningSum, candidate);
        }

        // Normalize every prefix sum by the grand total.
        for (size_t pos = 0; pos < HashedBackends.size(); ++pos) {
            HashedBackends[pos].first /= runningSum;
        }
    }

    // Set when the initial randomization is deferred to the first request
    // handled by this worker.
    bool RandomizeOnFirstRequest = false;

    // Holds every TBackend created for this worker (auto-deleting list).
    // NOTE(review): members are destroyed in reverse order, so Servers (which
    // links list-owned nodes) goes before List; keep this declaration order.
    TIntrusiveListWithAutoDelete<TBackend, TDelete> List;
    // Backends with a positive static weight, ordered by TBackend::Compare.
    TAvlTree<TBackend, TBackendCompare> Servers;
    // Backend name -> backend.
    THashMap<TString, TBackend*> NamedBackends;
    // Backend names; sorted when hashing is used, emptied otherwise.
    TVector<TString> Names;
    // (normalized cumulative weight, backend) pairs in ascending order.
    TVector<std::pair<double, TBackend*>> HashedBackends;

    TExternalWeightsFileReReader WeightsFileChecker;

    TSharedCounter DisablingAllBackendsUpdateCount;
};

BACKENDS_WITH_TLS(rr), public TModuleParams {
private:

    // Round-robin algorithm object created by every ConstructAlgorithm() call.
    // It owns no state: it picks from and updates the worker-local TTls
    // (the weight-ordered Servers tree) through the parent TBackends.
    //
    // NOTES:
    // backends should not have a tree, it must have only a list
    // tree has to be constructed for every algorithm instance
    class TRrAlgorithm : public TAlgorithmWithRemovals {
    public:
        TRrAlgorithm(TBackends* parent, const TStepParams& stepParams, TTls& tls) noexcept
            : TAlgorithmWithRemovals(&stepParams.Descr->Process())
            , Parent_(parent)
            , Tls_(tls)
        {}

        // Charges one usage to the backend (moving it later in the round-robin
        // order) and then marks it as removed for this request.
        void RemoveSelected(IBackend* backend) noexcept override {
            DoRemove(backend);
            TAlgorithmWithRemovals::RemoveSelected(backend);
        }

        // Returns the first not-yet-removed backend in tree order (smallest
        // current Weight()), or nullptr when every backend has been removed.
        IBackend* Next() noexcept override {
            for (auto& backend: Tls_.Servers) {
                if (!IsRemoved(&backend)) {
                    return &backend;
                }
            }

            return nullptr;
        }

        // Charges one usage to the chosen backend.
        void Select(IBackend* backend) noexcept override {
            Y_ASSERT(backend);
            Parent_->IncUsageCount(static_cast<TBackend*>(backend), Tls_);
        }

        // Looks a backend up by name. A backend whose static weight is not
        // positive is returned only when allowZeroWeights is set.
        IBackend* NextByName(TStringBuf name, bool allowZeroWeights) noexcept override {
            auto it = Tls_.NamedBackends.find(name);
            if (it != Tls_.NamedBackends.end()) {
                if (it->second && (allowZeroWeights || it->second->StaticWeight() > 0.0)) {
                    TAlgorithmWithRemovals::RemoveSelected(it->second);
                    return it->second;
                }
            }
            return nullptr;
        }

        // With a hash: scales it by MaxCeil<TRequestHash>() and picks the backend
        // whose normalized cumulative-weight interval contains the result.
        // Without a hash: picks the head of the Servers tree and stores, via
        // hash.Set(), a random hash that falls inside that backend's interval.
        IBackend* NextByHash(IHashProcessor& hash) noexcept override {
            if (Tls_.HashedBackends.empty()) {
                return nullptr;
            }

            if (hash.Has()) {
                struct TComparer { // TODO: return me in if
                    bool operator()(double lhs, const std::pair<double, TBackend*>& rhs) noexcept {
                        return lhs < rhs.first;
                    }
                    bool operator()(const std::pair<double, TBackend*>& lhs, double rhs) noexcept {
                        return lhs.first < rhs;
                    }
                };

                TRequestHash hashValue = hash.Get();
                double hashTarget = double(hashValue) / MaxCeil<TRequestHash>();

                auto it = std::upper_bound(Tls_.HashedBackends.begin(), Tls_.HashedBackends.end(), hashTarget, TComparer());
                // A target at (or past) the last boundary maps to the last backend.
                if (it == Tls_.HashedBackends.end()) {
                    --it;
                }
                TAlgorithmWithRemovals::RemoveSelected(it->second);
                return it->second;
            } else {
                TBackend* const next = &*Tls_.Servers.Begin();
                double prev = 0.0;
                for (const auto& b : Tls_.HashedBackends) {
                    if (b.second == next) {
                        const TRequestHash acceptableHash = ClampDoubleToUI64(
                                (prev + (b.first - prev) * RandomNumber<double>()) * MaxCeil<TRequestHash>()
                        );

                        hash.Set(acceptableHash);
                        break;
                    }
                    prev = b.first;
                }
                TAlgorithmWithRemovals::RemoveSelected(next);
                return next;
            }
        }

    private:
        // Charges one usage to a backend being removed. The dynamic_cast result
        // is checked because IncUsageCount dereferences its argument.
        void DoRemove(IBackend* backend) noexcept {
            if (auto* const rrBackend = dynamic_cast<TBackend*>(backend)) {
                Parent_->IncUsageCount(rrBackend, Tls_);
            }
        }

    private:
        TBackends* Parent_ = nullptr;
        TTls& Tls_;
    };

// Initialization
// --------------------------------------------------------------------------------
public:
    TBackends(const TModuleParams& mp, const TBackendsUID& uid)
        : TBackendsWithTLS(mp)
        , TModuleParams(mp)
        , BackendsId_(uid.Value)
        , DisablingAllBackendsUpdateCount_(
            mp.Control->SharedStatsManager().MakeCounter(
                "rr-disabling_all_backends_update_count").AllowDuplicate().Build())
    {
        Config->ForEach(this);

        if (RandomizeInitialState_ && CountOfRandomizedRequestsOnWeightsApplication_ > 0) {
            ythrow TConfigParseError{} << "balancer2: both randomize_initial_state and "
                                          "count_of_randomized_requests_on_weights_application are set";
        }
    }

private:
    // Hashing is switched on as soon as any policy asks for it and then stays on.
    void ProcessPolicyFeatures(const TPolicyFeatures& features) override {
        if (UseHash_) {
            return;
        }
        UseHash_ = features.WantsHash;
    }

    // Builds the per-worker state. It creates one TBackend per descriptor, puts
    // positive-weight backends into the Servers tree and builds the name index.
    // When hashing is used it also builds the sorted names and the hash table.
    // With randomize_initial_state the randomization is deferred to the first
    // request. Otherwise RandomizeInitialState runs right away; it then only
    // applies count_of_randomized_requests_on_weights_application.
    // NOTE(review): this is noexcept, so a configured weight that TBackend
    // rejects terminates the worker here.
    THolder<TTls> DoInit(IWorkerCtl* process) noexcept override {
        auto tls = MakeHolder<TTls>();

        for (auto& i : BackendDescriptors()) {
            auto* const b = new TBackend(i); // ownership passes to tls->List right below
            tls->List.PushBack(b);
            if (b->StaticWeight() > 0.0) {
                tls->Servers.Insert(b);
            }
            tls->NamedBackends[i->Name()] = b;
            tls->Names.emplace_back(i->Name());
        }

        if (UseHash_) {
            Sort(tls->Names);
            tls->FillHashedWeights();
        } else {
            tls->Names.clear();
        }

        if (RandomizeInitialState_) {
            tls->RandomizeOnFirstRequest = true;
        }

        if (!tls->RandomizeOnFirstRequest) {
            RandomizeInitialState(*tls.Get());
        }

        if (WeightsFilename_) {
            tls->WeightsFileChecker = TExternalWeightsFileReReader(*process, WeightsFilename_, TDuration::Seconds(1), NAME);
        }

        tls->DisablingAllBackendsUpdateCount = TSharedCounter(DisablingAllBackendsUpdateCount_, process->WorkerId());

        return tls;
    }

    START_PARSE {
        ON_KEY("weights_file", WeightsFilename_) {
            return;
        }

        ON_KEY("count_of_randomized_requests_on_weights_application", CountOfRandomizedRequestsOnWeightsApplication_) {
            return;
        }

        ON_KEY("randomize_initial_state", RandomizeInitialState_) {
            return;
        }

        // Any other key is a backend section.
        Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));

        return;
    } END_PARSE

// --------------------------------------------------------------------------------


    void DumpBackends(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        // Identical for every backend, so it is read once.
        const auto updateWeightsId = tls.WeightsFileChecker.LatestData().Id();

        out.OpenMap();
        out.Write("id", BackendsId_);
        out.OpenArray("backends");
        for (auto backend = tls.List.Begin(); backend != tls.List.End(); ++backend) {
            out.OpenMap();
            backend->PrintInfo(out);
            out.Write("update_weights_id", updateWeightsId);
            out.CloseMap();
        }
        out.CloseArray();
        out.CloseMap();
    }

    void DumpWeightsFileTags(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        tls.WeightsFileChecker.WriteWeightsFileTag(out);
    }


// Functionality
// --------------------------------------------------------------------------------
    // Applies a fresh weights file if there is one, performs the deferred
    // initial randomization on the worker's first request, and builds the
    // per-request algorithm.
    THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
        const bool updated = UpdateWeights(*params.Descr);
        auto& tls = GetTls(params.Descr->Process());
        if (tls.RandomizeOnFirstRequest) {
            tls.RandomizeOnFirstRequest = false;
            if (!updated) { //else already randomized inside UpdateWeights
                RandomizeInitialState(tls);
            }
        }

        return MakeHolder<TRrAlgorithm>(this, params, tls);
    }

    // Charges one usage to an enabled backend. Its weight is the tree key, so
    // it is erased, bumped and re-inserted to keep the tree ordered.
    void IncUsageCount(TBackend* backend, TTls& tls) noexcept {
        if (backend->StaticWeight() > 0.0) {
            auto foundBackend = tls.Servers.Erase(backend);
            if (foundBackend) {
                foundBackend->IncUsageCount();
                tls.Servers.Insert(foundBackend);
            }
        }
    }

    // Re-reads the weights file and applies it. Returns true when new weights
    // were applied; SetWeights has then already randomized the state.
    // A file that would disable every backend is rejected and counted.
    // A file containing a weight TBackend cannot represent is logged, and every
    // backend falls back to its configured weight so the tree stays consistent.
    // Letting the exception escape this noexcept function would call std::terminate.
    bool UpdateWeights(const TConnDescr& descr) noexcept {
        auto& tls = GetTls(descr.Process());
        if (!tls.WeightsFileChecker.UpdateWeights(descr)) {
            return false;
        }

        const auto& entries = tls.WeightsFileChecker.Entries();
        if (!IsAnyBackendAliveIfWeightsUpdated(entries, tls)) {
            ++tls.DisablingAllBackendsUpdateCount;
            LOG_ERROR(TLOG_ERR, descr,
                      "UpdateWeights Error because it disables all backends for weights_file: " << tls.WeightsFileChecker.LatestData().Data());
            return false;
        }

        try {
            SetWeights(entries, tls);
        } catch (const std::exception& e) {
            // SetWeights may have stopped half way through rebuilding the tree.
            // An empty override map restores the configured weight of every
            // backend and rebuilds the tree, hash table and randomization.
            LOG_ERROR(TLOG_ERR, descr,
                      "UpdateWeights Error: " << e.what() << " for weights_file: " << tls.WeightsFileChecker.LatestData().Data()
                      << "; falling back to configured weights");
            SetWeights(TMap<TString, double>(), tls);
        }
        return true;
    }

    // True if at least one backend would keep a positive weight after applying
    // namedWeights. Backends absent from the map keep their original weight.
    bool IsAnyBackendAliveIfWeightsUpdated(const TMap<TString, double>& namedWeights, const TTls& tls) const noexcept {
        for (auto i = tls.List.Begin(); i != tls.List.End(); ++i) {
            const auto named = namedWeights.find(i->Name());
            if ((named != namedWeights.end() && named->second > 0.0) ||
                (named == namedWeights.end() && i->OriginalWeight() > 0.0))
            {
                return true;
            }
        }
        return false;
    }

    // Applies namedWeights: listed backends get the given weight and the rest
    // are restored to their original weight. The Servers tree (and the hash
    // table, when hashing is used) is then rebuilt and the state randomized.
    // May throw TBackendException from TBackend::SetWeight.
    void SetWeights(const TMap<TString, double>& namedWeights, TTls& tls) {
        // Weights are tree keys: detach everything before changing them.
        while (!tls.Servers.Empty()) {
            tls.Servers.Erase(&*tls.Servers.First());
        }
        for (auto i = tls.List.Begin(); i != tls.List.End(); ++i) {
            const auto named = namedWeights.find(i->Name());
            if (named != namedWeights.end()) {
                i->SetWeight(named->second);
            } else {
                i->RestoreWeight();
            }
            if (i->StaticWeight() > 0.0) {
                tls.Servers.Insert(&*i);
            }
        }

        if (UseHash_) {
            tls.FillHashedWeights();
        }
        RandomizeInitialState(tls);
    }

    // Desynchronizes workers by pre-charging usages.
    // With randomize_initial_state: up to 999 single usages, always charged to
    // the current tree head.
    // With count_of_randomized_requests_on_weights_application = N: each enabled
    // backend gets a random number of usages in [0, part), where part is its
    // share of N proportional to static weight.
    void RandomizeInitialState(TTls& tls) noexcept {
        if (RandomizeInitialState_ && !tls.Servers.Empty()) {
            // RandomNumber generates the same sequence for forked workers
            const size_t reqCount = GetCycleCount() % 1000;
            for (size_t i = 0; i < reqCount; ++i) {
                auto* backend = &*tls.Servers.First();
                IncUsageCount(backend, tls);
            }
        }

        if (CountOfRandomizedRequestsOnWeightsApplication_ == 0) {
            return;
        }

        double total = 0.0;
        for (const auto& backend: tls.List) {
            if (backend.StaticWeight() > 0.0) {
                total += backend.StaticWeight();
            }
        }

        for (auto& backend: tls.List) {
            if (backend.StaticWeight() > 0.0) {
                const ui32 part = static_cast<ui32>(
                    backend.StaticWeight() / total * CountOfRandomizedRequestsOnWeightsApplication_);
                if (part > 0) {
                    tls.Servers.Erase(&backend);
                    const ui32 howMany = RandomNumber(part);
                    backend.IncUsageCount(howMany);
                    tls.Servers.Insert(&backend);
                }
            }
        }
    }
// --------------------------------------------------------------------------------


// State
// --------------------------------------------------------------------------------
private:
    bool RandomizeInitialState_ = true;
    bool UseHash_ = false;
    ui32 CountOfRandomizedRequestsOnWeightsApplication_ = 0;
    TString WeightsFilename_;
    size_t BackendsId_ = 0;
    TSharedCounter DisablingAllBackendsUpdateCount_;
// --------------------------------------------------------------------------------
};

INodeHandle<IBackends>* NRRobin::Handle() {
    // Exposes the node handle of the round-robin ("rr") TBackends implementation.
    INodeHandle<IBackends>* const handle = TBackends::Handle();
    return handle;
}

}  // namespace NSrvKernel::NModBalancer
