#pragma once

// TODO(carzil): move this to helpers?
#include <balancer/kernel/balancer/average.h>
#include <balancer/kernel/balancing/per_worker_backend.h>
#include <balancer/kernel/balancer/backends.h>
#include <balancer/kernel/module/iface.h>

namespace NSrvKernel {

class IBackends;
template <class> class INodeHandle;

namespace NModBalancer::Nweighted2 {

// Name string of the weighted2 algorithm (presumably the config/module key —
// confirm against the module registration code).
// `inline` (C++17) gives every translation unit one shared definition instead
// of a private internal-linkage copy per TU.
inline constexpr char NAME[] = "weighted2";

// NOTE(review): this namespace is spelled `Nweighted2` (lower-case `w`), while
// TBackend / Handle() live in `NWeighted2`; they are two distinct namespaces.
struct TTls;
struct TBackends;

}  // namespace NModBalancer::Nweighted2

namespace NModBalancer::NWeighted2 {

// Converts a reply time into a quality score.
//
// A reply no slower than SlowReplyTime_ scores 1 / (1 + replyTime / SlowReplyTime_):
// 1.0 for an instant reply, 0.5 exactly at the threshold. Anything slower than
// the threshold, and any failed request, scores 0.
class TReplyTimeQuality {
public:
    // @param slowReplyTime  Reply times strictly above this value get quality 0.
    TReplyTimeQuality(const TDuration& slowReplyTime) noexcept
        : SlowReplyTime_(slowReplyTime)
    {}

    // @return quality in [0.5, 1] for replies within the threshold, 0 otherwise.
    double OnCompleteRequest(const TDuration& replyTime) const noexcept {
        if (replyTime > SlowReplyTime_) {
            return 0.;
        }

        const auto slowUs = SlowReplyTime_.MicroSeconds();
        if (slowUs == 0) {
            // Zero threshold: only a zero reply time gets here. Avoid 0 / 0 (NaN),
            // which would poison quality averages, and score it as the fastest reply.
            return 1.;
        }

        // Divide in double: a float divisor rounds microsecond counts to a
        // 24-bit mantissa for no benefit.
        return 1. / (1. + double(replyTime.MicroSeconds()) / double(slowUs));
    }

    // Failed requests always score 0, regardless of how long they took.
    double OnFailRequest(const TDuration& /*replyTime*/) const noexcept {
        return 0.;
    }

private:
    const TDuration SlowReplyTime_;
};

// Tunables for weight correction, parsed from the module config.
// Defaults live in the member initializers below; each one can be overridden by
// the config key shown next to it (see the ON_KEY entries in the parser).
class TCorrectionParams : public IConfig::IFunc {
public:
    TCorrectionParams() = default;

private:
    START_PARSE
        {
            ON_KEY("min_weight", MinWeight_) {
                return;
            }

            ON_KEY("max_weight", MaxWeight_) {
                return;
            }

            ON_KEY("plus_diff_per_sec", PlusDiffWeightPerSec_) {
                return;
            }

            ON_KEY("minus_diff_per_sec", MinusDiffWeightPerSec_) {
                return;
            }

            ON_KEY("history_time", HistoryTime_) {
                return;
            }

            ON_KEY("feedback_time", FeedbackTime_) {
                return;
            }
        }
    END_PARSE

public:
    // NOTE(review): the meanings below are inferred from the field names; the
    // consumer (TBackend::Correct) is defined out of line — confirm there.
    double MinWeight_ = 0.05;                                  // "min_weight": presumably lower clamp for a corrected weight
    double MaxWeight_ = 5.;                                    // "max_weight": presumably upper clamp for a corrected weight
    double PlusDiffWeightPerSec_ = 0.05;                       // "plus_diff_per_sec": presumably max weight increase rate
    double MinusDiffWeightPerSec_ = 0.1;                       // "minus_diff_per_sec": presumably max weight decrease rate
    TDuration HistoryTime_ = TDuration::Seconds(20);           // "history_time"
    TDuration FeedbackTime_ = TDuration::Seconds(100);         // "feedback_time"
};

// Backend of the weighted2 balancer as seen by one worker.
//
// It keeps three weights:
//   * ConfigWeight_  - read via GetBackendWeight();
//   * Weight_        - read via GetWeight(); changed by SetWeight(),
//                      RestoreWeight() and Normalize();
//   * CurrentWeight_ - read via GetCurrentWeight(); grown by AddCurrentWeight().
// Request outcomes arrive through the DoOn*Request overrides. QualityData_ holds
// the quality samples that TTls aggregates via GetQuality().
// Non-inline members are defined out of line (NOTE(review): presumably in the .cpp).
class TBackend : public TPerWorkerBaseBackend {
public:
    TBackend(TBackendDescriptor::TRef descr, TCorrectionParams* correctionParams,
             TReplyTimeQuality* estimator, Nweighted2::TTls& tls);

    // Outcome hooks overriding the base backend.
    void DoOnCompleteRequest(const TDuration& answerTime) noexcept override;
    void DoOnFailRequest(const TError&, const TDuration& answerTime) noexcept override;

    // Disabled backends are skipped by TTls::GetAvgQuality() and TTls::Normalize().
    bool Disabled() const noexcept { return Disabled_; }

    // Value of ConfigWeight_ (presumably the configured weight).
    double GetBackendWeight() const noexcept { return ConfigWeight_; }

    double GetWeight() const noexcept { return Weight_; }
    void SetWeight(double weight) noexcept;

    // Resets the weight to the inherited OriginalWeight().
    void RestoreWeight() noexcept {
        const auto original = OriginalWeight();
        SetWeight(original);
    }

    // Receives the factor computed by TTls::Normalize().
    void Normalize(double c) noexcept;

    double GetCurrentWeight() const noexcept { return CurrentWeight_; }
    void AddCurrentWeight(double extra) noexcept { CurrentWeight_ += extra; }

    // Pair from QualityData_.Get(); TTls treats .first as a quality sum and
    // .second as a sample count.
    std::pair<double, size_t> GetQuality() const noexcept { return QualityData_.Get(); }

    void PrintInfo(NJson::TJsonWriter& out) const noexcept;
    double Qps() const noexcept;

private:
    double Correct(double quality, double avgQuality, double curWeight) noexcept;

    // Data members: declaration order is kept as-is (it fixes initialization order).
    const TCorrectionParams* CorrectionParams_{nullptr};
    const TReplyTimeQuality* Estimator_{nullptr};
    Nweighted2::TTls& Tls_;
    double ConfigWeight_{0.0};
    double Weight_{1.0};
    double CurrentWeight_{0.0};
    TTimeLimitedAvgTracker<1000> QualityData_;
    double LastCheckedWeight_{1.0};
    TInstant LastCheckedTime_;
    bool Disabled_{false};
};


// Returns the INodeHandle<IBackends> for the weighted2 algorithm (defined out of line).
// NOTE(review): ownership and lifetime of the returned pointer are not visible
// here — presumably a module-level static; confirm before storing it.
INodeHandle<IBackends>* Handle();

}  // namespace NModBalancer::NWeighted2

namespace NModBalancer::Nweighted2 {

struct TTls {
    // Applies weights keyed by backend name. A backend found in `namedWeights`
    // gets that weight via SetWeight(); every other backend is reset via
    // RestoreWeight().
    void SetWeights(const TMap<TString, double>& namedWeights) noexcept {
        for (auto& backend : Backends) {
            const auto named = namedWeights.find(backend.Name());
            if (named != namedWeights.end()) {
                backend.SetWeight(named->second);
            } else {
                backend.RestoreWeight();
            }
        }
    }

    // Average quality over enabled backends: the summed quality divided by the
    // total sample count, with the sum clamped to [0, total].
    double GetAvgQuality() const noexcept {
        double sum = 0.;
        size_t total = 0;

        for (const auto& backend : Backends) {
            if (!backend.Disabled()) {
                const std::pair<double, size_t> r = backend.GetQuality();
                sum += r.first;
                total += r.second;
            }
        }

        if (total) {
            sum = Max(Min(double(total), sum), 0.); // avoid calculation errors
            return sum / total;
        } else {
            // No samples from enabled backends. NOTE(review): the original author
            // flagged this fallback as uncertain; 1. is the highest score
            // TReplyTimeQuality produces.
            return 1.;
        }
    }

    // Passes factor startSum / sum to every enabled backend's Normalize(), where
    // startSum is the sum of their GetBackendWeight() and sum the sum of their
    // GetWeight(). NOTE(review): presumably rescales weights so they sum to
    // startSum — confirm against TBackend::Normalize.
    void Normalize() noexcept {
        double sum = 0.;
        double startSum = 0.;

        for (const auto& backend : Backends) {
            if (!backend.Disabled()) {
                sum += backend.GetWeight();
                startSum += backend.GetBackendWeight();
            }
        }

        // Nothing to scale against: no enabled backends, or all their weights
        // are zero. Dividing would hand inf/NaN to every backend. `!(sum > 0.)`
        // also rejects a NaN sum.
        if (!(sum > 0.)) {
            return;
        }

        const double factor = startSum / sum;
        for (auto& backend : Backends) {
            if (!backend.Disabled()) {
                backend.Normalize(factor);
            }
        }
    }

    // Adds a random fraction of each backend's weight to its current weight so
    // workers do not start in lock-step (RandomNumber<double>() presumably
    // uniform in [0, 1) — confirm).
    void RandomizeInitialState() noexcept {
        for (auto& b : Backends) {
            b.AddCurrentWeight(RandomNumber<double>() * b.GetWeight());
        }
    }

    // Backends owned by this TLS. A deque does not relocate existing elements
    // on push/emplace at the ends, which the raw pointers in NamedBackends
    // presumably rely on.
    TDeque<NWeighted2::TBackend> Backends;
    // Name -> backend lookup (non-owning; presumably points into Backends).
    THashMap<TString, NWeighted2::TBackend*> NamedBackends;

    // Weights-file re-reader and its last data (driven from outside this header).
    TSharedFileReReader WeightsFileChecker;
    TSharedFileReReader::TData WeightsFileData;
};
}  // namespace NModBalancer::Nweighted2

}  // namespace NSrvKernel
