#include "module.h"

#include <balancer/kernel/balancer/average.h>
#include <balancer/kernel/coro/coro_cond_var.h>
#include <balancer/kernel/custom_io/null.h>
#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/fs/shared_file_exists_checker.h>
#include <balancer/kernel/fs/shared_files.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/net/address.h>
#include <balancer/kernel/thread/threadedqueue.h>

using namespace NSrvKernel;
using namespace NModSmartPinger;

// Per-worker state of the smart_pinger module, created by DoInitTls() for each
// worker (IWorkerCtl).
Y_TLS(smart_pinger) {
    // Success samples (1.0 = success, 0.0 = failure) from every request through
    // the submodule, pings included. Constructed with the module's "ttl".
    // NOTE(review): the 1000 template argument presumably caps the number of
    // retained samples (the config rejects min_samples_to_disable > 1000) — confirm.
    THolder<TTimeLimitedAvgTracker<1000>> Stats;
    // Success samples from ping requests only.
    THolder<TTimeLimitedAvgTracker<1000>> PingStats;
    // Notified by UpdateState() when the module switches from enabled to
    // disabled; the ping coroutine waits on it while the module is enabled.
    TCoroSingleCondVar PingCV;
    // NOTE(review): not read or written anywhere in this file; the module uses
    // its own PingDisableFile_ instead — looks unused.
    TString PingDisableFile;
    // Watches "ping_disable_file"; while that file exists Enabled() reports the
    // module as enabled without consulting the statistics.
    TSharedFileExistsChecker PingDisableChecker;
    // Current enabled/disabled state, switched with hysteresis by UpdateState().
    bool Enabled = true;
    // The ping coroutine. Declared last so that it is destroyed first, before
    // the members above that its body uses through the captured TTls pointer.
    TCoroutine PingTask;
};

// smart_pinger: forwards requests to its submodule while the submodule's
// recent success rate is healthy. When the rate drops below "lo" the module is
// disabled in that worker. Requests then go to "on_disable" if one is
// configured; otherwise they fail with a ping-disabled backend error. A
// per-worker coroutine also starts sending "ping_request_data" through the
// submodule every "delay", until the rate climbs back to at least "hi".
//
// Config keys (defaults from the member initializers below):
//   ping_request_data       raw HTTP request used for pings
//   delay                   pause between pings while disabled (5s)
//   ttl                     time window of the success statistics, must be >= delay (10s)
//   lo / hi                 disable / re-enable thresholds in [0.0, 1.0] (0.5 / 0.7)
//   ping_disable_file       while this file exists the module is always enabled
//   min_samples_to_disable  samples needed before the module may be disabled, <= 1000 (200)
//   on_disable              optional module serving requests while disabled
//   any other key           the submodule itself
MODULE_WITH_TLS_BASE(smart_pinger, TModuleWithSubModule) {
private:
    // Sends one ping: a copy of PingRequest_ is pushed through the submodule
    // over null input/output streams. It is run with isPing = true, so its
    // outcome is recorded in both tls.Stats and tls.PingStats.
    void Ping(IWorkerCtl& process, TTls& tls) const {
        TRandomAddr addrOwner;
        TAddrHolder addr{ &addrOwner };
        TTcpConnProps tcpConnProps(process, addr, addr, nullptr);
        tcpConnProps.SkipKeepalive = true;
        TConnProps properties(tcpConnProps, Now(), 0);

        // Per-ping copy, so the parsed template in PingRequest_ is never handed out.
        TRequest request = PingRequest_;

        TNullStream nullStream;
        TConnDescr descr(nullStream, nullStream, properties);
        descr.Request = &request;
        // Fresh random hash for every ping; presumably so that hash-based
        // routing further down does not send all pings to the same place.
        descr.Hash = RandomNumber<decltype(descr.Hash)>();

        try {
            // The returned error is ignored: RunModule has already recorded the outcome.
            Y_UNUSED(RunModule(descr, tls, true));
        } catch (...) {
            // Best effort: a failing ping must not escape into the ping coroutine.
        }
    }

    // Parses PingRequestData_ into pingRequest.
    // Throws TConfigParseError (with the parser's message) if it is not a valid request.
    void ParsePingRequest(TRequest& pingRequest) {
        try {
            TryRethrowError(pingRequest.Parse(PingRequestData_));
        } catch (...) {
            ythrow TConfigParseError() << "bad ping request: " << CurrentExceptionMessage();
        }
    }

public:
    // Loads and validates the configuration.
    // Throws TConfigParseError on an out-of-range option, "ttl" < "delay", a
    // missing submodule, or an unparseable ping request.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (TtlEnable_ < Delay_) {
            // Report the real config key names ("ttl", "delay").
            ythrow TConfigParseError() << "ttl smaller than delay";
        }

        // NOTE(review): "hi" is not checked against "lo". With hi < lo, a weight
        // between them makes UpdateState() flip the state on every call.

        if (!Submodule_) {
            ythrow TConfigParseError() << "no module configured";
        }

        ParsePingRequest(PingRequest_);
    }

private:
    START_PARSE {
        // TODO: maybe headers
        ON_KEY("ping_request_data", PingRequestData_) {
            return;
        }

        ON_KEY("delay", Delay_) {
            return;
        }

        ON_KEY("ttl", TtlEnable_) {
            return;
        }

        ON_KEY("lo", LowSuccRate_) {
            if (0.0 > LowSuccRate_ || LowSuccRate_ > 1.0) {
                ythrow TConfigParseError() << "\"lo\" must be in [0.0, 1.0]";
            }
            return;
        }

        ON_KEY("hi", HiSuccRate_) {
            if (0.0 > HiSuccRate_ || HiSuccRate_ > 1.0) {
                ythrow TConfigParseError() << "\"hi\" must be in [0.0, 1.0]";
            }
            return;
        }

        ON_KEY("ping_disable_file", PingDisableFile_) {
            return;
        }

        ON_KEY("min_samples_to_disable", MinSamplesToDisable_) {
            if (MinSamplesToDisable_ > 1000) {
                ythrow TConfigParseError() << "\"min_samples_to_disable\" must be in [0, 1000]";
            }
            return;
        }

        if (key == "on_disable") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(OnDisable_);
            return;
        }

        // Any other key is the submodule itself.
        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    } END_PARSE

    // Builds the per-worker state. Only workers of type Default get statistics,
    // the disable-file checker and the ping coroutine; other workers get a bare TTls.
    // NOTE(review): DoRun()/CalcWeight() dereference tls.Stats unconditionally,
    // so this assumes non-default workers never route requests here — confirm.
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        THolder<TTls> tls = MakeHolder<TTls>();

        if (process->WorkerType() != NProcessCore::TChildProcessType::Default) {
            return tls;
        }

        if (!!PingDisableFile_) {
            tls->PingDisableChecker = process->SharedFiles()->FileChecker(PingDisableFile_, TDuration::Seconds(1));
        }

        tls->Stats.Reset(new TTimeLimitedAvgTracker<1000>(TtlEnable_));
        tls->PingStats.Reset(new TTimeLimitedAvgTracker<1000>(TtlEnable_));

        // Ping loop.
        // - While enabled: block on PingCV until UpdateState() reports a switch
        //   to disabled, and stop if the wait is cancelled.
        // - While disabled: ping the submodule once every Delay_.
        tls->PingTask = TCoroutine{"pinger", &process->Executor(), [this, process, tlsPtr = tls.Get()] {
            auto* const cont = process->Executor().Running();
            // TODO(velavokr): not really smart. Will ping only after user traffic sees errors
            do {
                if (Enabled(*tlsPtr)) {
                    if (tlsPtr->PingCV.wait(&process->Executor()) == ECANCELED) {
                        break;
                    }
                } else {
                    Ping(*process, *tlsPtr);
                    cont->SleepT(Delay_);
                }
            } while (!cont->Cancelled());
        }};

        return tls;
    }

    // Handles a request.
    // - Enabled: run the submodule and record the outcome.
    // - Disabled: delegate to on_disable if configured. Otherwise tag the access
    //   log with "ping disabled" and fail with a TPingFailedException wrapped in
    //   a TBackendError.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        if (Enabled(tls)) {
            Y_PROPAGATE_ERROR(RunModule(descr, tls, false));
        } else if (HasOnDisable()) {
            Y_PROPAGATE_ERROR(OnDisable_->Run(descr));
        } else {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "ping disabled");
            return Y_MAKE_ERROR(TBackendError{Y_MAKE_ERROR(TPingFailedException{})});
        }
        return {};
    }

    // Runs the submodule and records the result in tls; isPing also records it
    // in the ping-only statistics. A TForceStreamClose error counts as a
    // success but is still returned to the caller; any other error counts as a failure.
    TError RunModule(const TConnDescr& descr, TTls& tls, bool isPing) const {
        Y_TRY(TError, error) {
            return Submodule_->Run(descr);
        } Y_CATCH {
            if (error.GetAs<TForceStreamClose>()) {
                RegisterSuccess(tls, isPing);
            } else {
                RegisterFailure(tls, isPing);
            }
            return error;
        }
        RegisterSuccess(tls, isPing);
        return {};
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    // Whether requests should reach the submodule. While ping_disable_file
    // exists this is always true and the statistics are not consulted.
    // Otherwise the state is first refreshed from the current statistics.
    bool Enabled(TTls& tls) const noexcept {
        if (tls.PingDisableChecker.Exists()) {
            return true;
        } else {
            UpdateState(tls);
            return tls.Enabled;
        }
    }

    bool HasOnDisable() const noexcept {
        return !!OnDisable_;
    }

    // Records a successful request (and ping, if isPing), then refreshes the state.
    void RegisterSuccess(TTls& tls, bool isPing) const noexcept {
        Y_ASSERT(tls.Stats);
        Y_ASSERT(tls.PingStats);
        tls.Stats->Add(1.0);
        if (isPing) {
            tls.PingStats->Add(1.0);
        }
        UpdateState(tls);
    }

    // Records a failed request (and ping, if isPing), then refreshes the state.
    void RegisterFailure(TTls& tls, bool isPing) const noexcept {
        Y_ASSERT(tls.Stats);
        Y_ASSERT(tls.PingStats);
        tls.Stats->Add(0.0);
        if (isPing) {
            tls.PingStats->Add(0.0);
        }
        UpdateState(tls);
    }

    // Hysteresis between the two thresholds:
    // - disable when the weight falls below LowSuccRate_, waking the ping coroutine;
    // - re-enable once it reaches HiSuccRate_.
    // Only the enabled -> disabled transition notifies PingCV.
    void UpdateState(TTls& tls) const noexcept {
        const double weight = CalcWeight(tls);
        if (tls.Enabled && weight < LowSuccRate_) {
            tls.Enabled = false;
            tls.PingCV.notify();
        } else if (!tls.Enabled && weight >= HiSuccRate_) {
            tls.Enabled = true;
        }
    }

    // Combined success rate in [0.0, 1.0]. It is a weighted mean of the overall
    // rate (weight = overall count / ping count) and the ping-only rate
    // (weight = ping count). It returns 1.0 (healthy) until enough samples exist.
    double CalcWeight(TTls& tls) const noexcept {
        const auto stats = tls.Stats->Get();
        const auto pingStats = tls.PingStats->Get();
        const size_t total = stats.second + pingStats.second;
        // No samples means no evidence of failure. Without the explicit
        // total == 0 check, min_samples_to_disable = 0 would yield weight 0 here
        // and disable the module on its very first request.
        if (total == 0 || total < MinSamplesToDisable_) {
            return 1.0;
        }
        const double pingCount = double(pingStats.second);
        const double mult = double(stats.second) / Max<size_t>(pingStats.second, 1);
        // Normalize by the exact (fractional) sum of the weights. Truncating it
        // to an integer biased the result upward, even past 1.0.
        const double norm = Max<double>(mult + pingCount, 1.0);
        const double weight = (M(stats) * mult + M(pingStats) * pingCount) / norm;

        return weight;
    }

    // Mean of a (sum, count) pair; an empty sample set counts as fully healthy.
    static double M(const std::pair<double, size_t>& data) noexcept {
        if (data.second > 0) {
            return data.first / data.second;
        } else {
            return 1.0;
        }
    }

private:
    // "ping_request_data": raw request sent by Ping(), parsed into PingRequest_.
    TString PingRequestData_{ "GET /robots.txt HTTP/1.1\r\nHost: yandex.ru\r\n\r\n" };
    // "delay": pause between pings while disabled.
    TDuration Delay_{ TDuration::Seconds(5) };
    // "ttl": time window passed to the success-rate trackers.
    TDuration TtlEnable_{ TDuration::Seconds(10) };

    // "lo" / "hi": disable / re-enable thresholds on CalcWeight().
    double LowSuccRate_{ 0.5 };
    double HiSuccRate_{ 0.7 };

    // "ping_disable_file": while it exists the module is always enabled.
    TString PingDisableFile_;

    // "on_disable": optional module serving requests while disabled.
    THolder<IModule> OnDisable_;

    // "min_samples_to_disable": below this many samples CalcWeight() reports 1.0.
    size_t MinSamplesToDisable_{ 200 };

    // Parsed form of PingRequestData_; copied for every ping.
    TRequest PingRequest_;
};

// Exposes the module handle of smart_pinger, as provided by TModule::Handle().
IModuleHandle* NModSmartPinger::Handle() {
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
