#include "backends_factory.h"
#include "base_algorithm.h"
#include "weighted2.h"

#include <balancer/kernel/balancing/per_worker_backend.h>
#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/fs/kv_file_consumer.h>
#include <balancer/kernel/fs/shared_files.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>

namespace NSrvKernel::NModBalancer {

namespace NWeighted2 {
using namespace Nweighted2;

// Per-worker state of one weighted2 backend.
//
// descr            - backend descriptor; Backend_->Weight() seeds ConfigWeight_.
// correctionParams - tuning knobs for weight correction; HistoryTime_ is passed to the
//                    QualityData_ constructor. In this file it points at the enclosing
//                    TBackends' CorrectionParams_ member (see DoInit), i.e. not owned here.
// estimator        - turns a reply time (success or failure) into a quality score; in this
//                    file it is the enclosing TBackends' Estimator_, also not owned here.
// tls              - worker-local state shared by all backends of the module; used for the
//                    average quality and for weight normalization after each request.
TBackend::TBackend(TBackendDescriptor::TRef descr, NWeighted2::TCorrectionParams* correctionParams,
    TReplyTimeQuality* estimator, Nweighted2::TTls& tls)
    : TPerWorkerBaseBackend(std::move(descr))
    , CorrectionParams_(correctionParams)
    , Estimator_(estimator)
    , Tls_(tls)
    , ConfigWeight_(Backend_->Weight())
    // NOTE(review): relies on CorrectionParams_ being declared before QualityData_ in the header.
    , QualityData_(CorrectionParams_->HistoryTime_)
    , LastCheckedTime_(Now())
{
    RestoreWeight(); // To call SetWeight to make sure that disabled state is set properly
}

// Records a successful reply: scores it, appends the score to the quality history and
// recomputes the weight relative to the worker-wide average quality. The correction
// works on the weight expressed in units of ConfigWeight_, so it is divided out before
// Correct() and multiplied back afterwards. Finally the worker re-normalizes all weights.
void TBackend::DoOnCompleteRequest(const TDuration& answerTime) noexcept {
    const double replyQuality = Estimator_->OnCompleteRequest(answerTime);
    QualityData_.Add(replyQuality);

    const double relativeWeight = Weight_ / ConfigWeight_;
    const double averageQuality = Tls_.GetAvgQuality();
    Weight_ = ConfigWeight_ * Correct(replyQuality, averageQuality, relativeWeight);

    Tls_.Normalize();
}

// Records a failed request. The error itself is not inspected; only the elapsed time is
// scored (by the estimator's failure rule). The score then goes through exactly the same
// weight correction and normalization as a successful reply.
void TBackend::DoOnFailRequest(const TError&, const TDuration& answerTime) noexcept {
    const double failQuality = Estimator_->OnFailRequest(answerTime);
    QualityData_.Add(failQuality);

    const double relativeWeight = Weight_ / ConfigWeight_;
    const double averageQuality = Tls_.GetAvgQuality();
    Weight_ = ConfigWeight_ * Correct(failQuality, averageQuality, relativeWeight);

    Tls_.Normalize();
}

// Applies an externally configured weight.
// A positive weight becomes the new ConfigWeight_ and enables the backend; anything else
// (zero, negative, NaN) disables it while keeping the last positive ConfigWeight_.
// The dynamic state (Weight_, CurrentWeight_) is reset in both cases, and when the backend
// was disabled before this call, the step limit used by Correct() is restarted from 1.0.
void TBackend::SetWeight(double weight) noexcept {
    const bool wasDisabled = Disabled_;

    Disabled_ = !(weight > 0.0);
    if (!Disabled_) {
        ConfigWeight_ = weight;
    }

    // NOTE(review): Normalize() treats Weight_ as scaled by ConfigWeight_, so 1.0 is the
    // relative baseline only when ConfigWeight_ == 1 — confirm intent.
    Weight_ = 1.0; // TODO: Weight_ = weight?
    CurrentWeight_ = 0.0;

    if (!wasDisabled) {
        return;
    }
    LastCheckedWeight_ = 1.0;
    LastCheckedTime_ = Now();
}

// Scales the weight by `c` and clamps it to [MinWeight_, MaxWeight_] expressed in units of
// ConfigWeight_. The upper bound is applied last, so it wins if the bounds ever cross.
void TBackend::Normalize(double c) noexcept {
    const double lowerBound = CorrectionParams_->MinWeight_ * ConfigWeight_;
    const double upperBound = CorrectionParams_->MaxWeight_ * ConfigWeight_;
    Weight_ = Min(upperBound, Max(lowerBound, Weight_ * c));
}

// Dumps this backend's balancing state into the currently open JSON map.
// Keep the key order unchanged: consumers see the fields in this order.
void TBackend::PrintInfo(NJson::TJsonWriter& out) const noexcept {
    // weighted2-specific state
    out.Write("weight", Weight_);
    out.Write("config_weight", ConfigWeight_);
    out.Write("disabled", Disabled_);
    out.Write("current_weight", CurrentWeight_);

    // generic counters (presumably provided by TPerWorkerBaseBackend)
    PrintSuccFailRate(out);
    PrintProxyInfo(out);
}

// Estimated request rate of this backend over its quality history.
// With fewer than two recorded samples a neutral 1.0 is reported; otherwise the rate is
// (samples - 1) divided by the seconds elapsed since the history's start time.
double TBackend::Qps() const noexcept {
    const std::pair<double, size_t> history = QualityData_.Get();
    const size_t samples = history.second;

    if (samples < 2) {
        return 1.0;
    }

    const double elapsedSeconds = (Now() - QualityData_.StartT()).SecondsFloat();
    return (samples - 1) / elapsedSeconds;
}

// Computes the next weight of this backend, relative to ConfigWeight_ (callers multiply the
// result by ConfigWeight_).
//
// quality    - score of the reply just recorded, asserted to be in [0, 1]. It is already part
//              of QualityData_ (callers Add() it first) and is only used for the assertion.
// avgQuality - average quality across the worker's backends, asserted to be in [0, 1].
// curWeight  - current weight divided by ConfigWeight_.
//
// Steps:
//   1. Compare the mean recorded quality with avgQuality and move the weight up (rate
//      PlusDiffWeightPerSec_) or down (rate MinusDiffWeightPerSec_), divided by qps —
//      presumably so the *PerSec_ knobs act per second regardless of request rate.
//   2. Pull the weight back toward 1.0 over FeedbackTime_.
//   3. Clamp to [MinWeight_, MaxWeight_] and to a step limit around LastCheckedWeight_,
//      which is refreshed at most once per second.
double TBackend::Correct(double quality, double avgQuality, double curWeight) noexcept {
    Y_ASSERT(quality >= 0. && quality <= 1.);
    Y_ASSERT(avgQuality >= 0. && avgQuality <= 1.);

    const std::pair<double, size_t> history = QualityData_.Get();

    // calculate qps (same formula as Qps(), reusing the snapshot taken above)
    double qps;
    if (history.second < 2) {
        qps = 1.0;
    } else {
        qps = (history.second - 1) / (Now() - QualityData_.StartT()).SecondsFloat();
    }

    // Mean recorded quality relative to the average; 1.0 means "on par". An empty history
    // carries no evidence, so treat it as on par instead of computing 0 / 0 = NaN.
    const double relQuality = history.second > 0
        ? history.first / history.second / Max(avgQuality, 0.00001)
        : 1.;

    // calculate new weight: better-than-average backends grow, worse ones shrink.
    const double diffWeightPerSec = relQuality > 1.
        ? CorrectionParams_->PlusDiffWeightPerSec_
        : CorrectionParams_->MinusDiffWeightPerSec_;
    double newWeight = curWeight * ((relQuality - 1.) * diffWeightPerSec / qps + 1.);

    // apply feedback: drift back toward 1.0 over FeedbackTime_. Fractional seconds are
    // required here — the integral Seconds() truncates sub-second values to 0, and dividing
    // by it produced inf/NaN weights. A non-positive feedback time disables the drift.
    const double feedbackSeconds = CorrectionParams_->FeedbackTime_.SecondsFloat();
    if (feedbackSeconds > 0.) {
        newWeight = (1. - newWeight) / feedbackSeconds / qps + newWeight;
    }

    // apply restrictions
    newWeight = Min(newWeight,
                    Min(LastCheckedWeight_ * (1. + CorrectionParams_->PlusDiffWeightPerSec_),
                        CorrectionParams_->MaxWeight_));
    newWeight = Max(newWeight,
                    Max(LastCheckedWeight_ * (1. - CorrectionParams_->MinusDiffWeightPerSec_),
                        CorrectionParams_->MinWeight_));

    // Move the step-limit anchor at most once per second.
    if ((Now() - LastCheckedTime_).MicroSeconds() > 1000000) {
        LastCheckedWeight_ = newWeight;
        LastCheckedTime_ = Now();
    }

    return newWeight;
}
}  // NWeighted2

using namespace NWeighted2;
BACKENDS_WITH_TLS(weighted2), public TModuleParams {
private:
    // Per-request backend selection: smooth weighted round-robin over the worker-local
    // backend list (tls.Backends), honoring the per-request removals of
    // TAlgorithmWithRemovals.
    class TWeighted2Algorithm : public TAlgorithmWithRemovals {
    public:
        TWeighted2Algorithm(TBackends* parent, const TStepParams& params)
            : TAlgorithmWithRemovals(&params.Descr->Process())
            , Parent_(parent)
        {}

        // Picks the next backend. A backend is eligible when it is neither Disabled() nor
        // IsRemoved(). Each eligible backend has its weight added to its current weight.
        // The eligible backend with the largest current weight wins (the first one on ties)
        // and is charged the sum of the weights handed out in this round.
        // Returns nullptr only when no eligible backend exists.
        //
        // Removed backends never seed the maximum. Otherwise a removed backend with a high
        // current weight would make Next() return nullptr while eligible backends remain,
        // and would leave the round's credit uncharged.
        IBackend* Next() noexcept override {
            auto& tls = Parent_->GetTls(*Process_);

            TBackend* retval = nullptr;
            double max = 0.;
            double total = 0.;

            for (auto& backend : tls.Backends) {
                if (backend.Disabled() || IsRemoved(&backend)) {
                    continue;
                }

                backend.AddCurrentWeight(backend.GetWeight());
                total += backend.GetWeight();

                const double cur = backend.GetCurrentWeight();
                if (!retval || cur > max) {
                    max = cur;
                    retval = &backend;
                }
            }

            if (!retval) {
                return nullptr;
            }

            retval->AddCurrentWeight(-total);
            return retval;
        }

        // Looks a backend up by name in the worker-local index built in DoInit.
        // NOTE(review): Disabled() is not consulted and allowZeroWeights is ignored, so a
        // disabled backend can still be returned by name — confirm this is intended.
        IBackend* NextByName(TStringBuf name, bool /*allowZeroWeights*/) noexcept override {
            auto& namedBackends = Parent_->GetTls(*Process_).NamedBackends;
            auto it = namedBackends.find(name);
            if (it != namedBackends.end()) {
                return it->second;
            }
            return nullptr;
        }

    private:
        TBackends* Parent_ = nullptr;
    };

// Initialization
// --------------------------------------------------------------------------------
public:
    TBackends(const TModuleParams& mp, const TBackendsUID& uid)
        : TBackendsWithTLS(mp)
        , TModuleParams(mp)
        , BackendsId_(uid.Value)
    {
        Config->ForEach(this);
        // Created after parsing so that a configured slow_reply_time is taken into account.
        Estimator_.Reset(MakeHolder<TReplyTimeQuality>(SlowReplyTime_));
    }

private:
    // Builds the worker-local state: one TBackend per configured descriptor, a name index
    // over them, a randomized initial state and, when weights_file is configured, a file
    // re-reader (created with a 1s period, presumably the re-read interval).
    THolder<TTls> DoInit(IWorkerCtl* process) noexcept override {
        auto tls = MakeHolder<TTls>();

        for (auto& i : BackendDescriptors()) {
            tls->Backends.emplace_back(i, &CorrectionParams_, Estimator_.Get(), *tls.Get());
        }
        // Pointers into tls->Backends are taken only after all backends were added.
        for (auto& i : tls->Backends) {
            tls->NamedBackends[i.Name()] = &i;
        }

        tls->RandomizeInitialState();

        if (!!WeightsFilename_) {
            tls->WeightsFileChecker = process->SharedFiles()->FileReReader(WeightsFilename_, TDuration::Seconds(1));
        }

        return tls;
    }

    void ProcessPolicyFeatures(const TPolicyFeatures& features) override {
        if (features.WantsHash) {
            PrintOnce("WARNING in balancer2/weighted2: policies with hash are not supported and will be ignored");
        }
    }

    // Config keys: "correction_params" (sub-config parsed into CorrectionParams_),
    // "slow_reply_time", "weights_file"; any other key is a backend section.
    START_PARSE {
            if (key == "correction_params") {
                value->AsSubConfig()->ForEach(&CorrectionParams_);
                return;
            }

            ON_KEY("slow_reply_time", SlowReplyTime_) {
                return;
            }

            ON_KEY("weights_file", WeightsFilename_) {
                return;
            }

            Add(MakeHolder<TBackendDescriptor>(Copy(value->AsSubConfig()), key));

            return;
        } END_PARSE

// --------------------------------------------------------------------------------


    // Writes {"id": ..., "backends": [ {per-backend PrintInfo}, ... ]} for one worker.
    void DumpBackends(NJson::TJsonWriter& out, const TTls& tls) const noexcept override {
        out.OpenMap();
        out.Write("id", BackendsId_);
        out.OpenArray("backends");
        for (const auto& backend : tls.Backends) {
            out.OpenMap();
            backend.PrintInfo(out);
            out.CloseMap();
        }
        out.CloseArray();
        out.CloseMap();
    }


// Functionality
// --------------------------------------------------------------------------------
    THolder<IAlgorithm> ConstructAlgorithm(const TStepParams& params) noexcept override {
        UpdateWeights(*params.Descr);

        return MakeHolder<TWeighted2Algorithm>(this, params);
    }

    // Applies the weights file when the re-reader reports data with a new Id.
    // The new data is remembered before parsing, so a malformed file is reported once and
    // not re-parsed until its contents change.
    void UpdateWeights(const TConnDescr& descr) noexcept {
        auto& tls = GetTls(descr.Process());
        const auto& data = tls.WeightsFileChecker.Data();
        if (data.Id() != tls.WeightsFileData.Id()) {
            tls.WeightsFileData = data;
            ReReadWeights(tls.WeightsFileData.Data(), descr);
        }
    }

    // Parses `contents` as key/value weights and hands them to the worker's TTls::SetWeights.
    // A yexception during parsing or applying is logged; it is not propagated.
    void ReReadWeights(const TStringBuf contents, const TConnDescr& descr) noexcept {
        NSrvKernel::TExternalWeightsFileConsumer consumer;
        try {
            ProcessKvData(consumer, contents);
            GetTls(descr.Process()).SetWeights(consumer.Storage());
        } catch (const yexception& e) {
            LOG_ERROR(TLOG_ERR, descr, "ReReadWeights Error parsing weight for weights_file: " << e.what());
        }
    }
// --------------------------------------------------------------------------------


// State
// --------------------------------------------------------------------------------
private:
    TCorrectionParams CorrectionParams_;
    THolder<TReplyTimeQuality> Estimator_;
    TDuration SlowReplyTime_ = TDuration::MicroSeconds(500000);
    TString WeightsFilename_;
    size_t BackendsId_ = 0;
// --------------------------------------------------------------------------------
};

}

namespace NSrvKernel::NModBalancer {


// Module registration hook: exposes the node handle of the weighted2 TBackends class.
INodeHandle<IBackends>* NWeighted2::Handle() {
    INodeHandle<IBackends>* const weighted2Handle = TBackends::Handle();
    return weighted2Handle;
}

}  // namespace NSrvKernel::NModBalancer
