#include "module.h"

#include <balancer/kernel/coro/cleanable_coro_storage.h>
#include <balancer/kernel/custom_io/rewind.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/net/address.h>
#include <balancer/kernel/requester/requester.h>

#include <util/random/easy.h>

#include <cmath>

using namespace NConfig;
using namespace NSrvKernel;

Y_TLS(request_replier) {
    // Per-worker state of the request_replier module.
    //
    // configRate is the "rate" value from the module config. It is the
    // initial Rate and the fallback whenever the rate file is empty or holds
    // an invalid value.
    explicit TTls(double configRate)
        : Rate(configRate)
        , ConfigRate(configRate)
    {}

    // Reloads Rate from RateFileReader, but only when the file data id differs
    // from the one applied on the previous call.
    //
    // The file content is parsed as a double after stripping surrounding
    // whitespace. An empty, unparsable, negative or non-finite (NaN/inf) value
    // resets Rate to ConfigRate.
    void UpdateRate() noexcept {
        const auto& data = RateFileReader.Data();
        if (data.Id() == Id) {
            return;
        }
        Id = data.Id();

        const TStringBuf content = StripString(TStringBuf(data.Data()));
        if (content.empty() || !TryFromStringWithDefault(content, Rate, ConfigRate)) {
            Rate = ConfigRate;
            return;
        }

        // A non-finite rate would make the size_t conversion in CountCopiesFor
        // undefined behavior, so it is rejected just like a negative one.
        if (!std::isfinite(Rate) || Rate < 0.0) {
            Rate = ConfigRate;
        }
    }

    // Returns how many copies of the request should be sent to the sink:
    // floor(Rate) copies, plus one more when Rate > floor(Rate) + Random().
    // If Random() is uniform in [0, 1) (TODO confirm), the extra copy is sent
    // with probability equal to the fractional part of Rate.
    //
    // Returns 0 in three cases:
    //   - there is no request (descr.Request is null);
    //   - the request has UpgradeRequested set;
    //   - Rate is not finite.
    size_t CountCopiesFor(const TConnDescr& descr) {
        const TRequest* const request = descr.Request;

        if (!request || request->Props().UpgradeRequested) {
            return 0;
        }

        //TODO(nocomer): here, we can restrict copies of non-idempotent requests

        // Defensive: ConfigRate is used unchecked as a fallback above, and
        // converting NaN/inf to size_t is undefined behavior.
        if (!std::isfinite(Rate)) {
            return 0;
        }

        const double whole = std::floor(Rate);
        size_t count = static_cast<size_t>(whole);
        if (Rate > whole + Random()) {
            ++count;
        }
        return count;
    }

    // Optional runtime override for the rate. It is only attached when
    // "rate_file" is configured (see DoInitTls).
    TSharedFileReReader RateFileReader;
    // Id of the rate file data last applied by UpdateRate.
    size_t Id = 0;
    // Background coroutines that send request copies to the sink.
    TCleanableCoroStorage Repliers;
    // Effective replication rate used by CountCopiesFor.
    double Rate = 0.0;
    // The "rate" value from the config; fallback for invalid file contents.
    const double ConfigRate = 0.0;
};

MODULE_WITH_TLS(request_replier) {
    // Parses the module config. A main submodule and a "sink" submodule are
    // both required; TConfigParseError is thrown if either is missing.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (!MainModule_) {
            ythrow TConfigParseError() << "submodule must be provided";
        }

        if (!SinkModule_) {
            ythrow TConfigParseError() << "\"sink\" section with submodule must be provided";
        }
    }

    // Recognized keys:
    //   rate                               - base replication rate (finite, >= 0.0);
    //   sink                               - submodule that receives request copies;
    //   rate_file                          - file whose content may override the rate;
    //   enable_failed_requests_replication - replicate even if the main submodule fails.
    // Any other key is loaded as the main submodule.
    // NOTE(review): if several such keys are present, the last one silently wins.
    START_PARSE {
        ON_KEY("rate", ConfigRate_) {
            // NaN is not caught by the "< 0.0" comparison below. NaN/inf rates
            // would make the copy count conversion to size_t undefined.
            if (!std::isfinite(ConfigRate_)) {
                ythrow TConfigParseError() << "\"rate\" must be a finite number";
            }
            if (ConfigRate_ < 0.0) {
                ythrow TConfigParseError() << "\"rate\" must be not less than 0.0";
            }
            return;
        }

        if (key == "sink") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(SinkModule_);
            return;
        }

        ON_KEY("rate_file", RateFileName_) {
            return;
        }

        ON_KEY("enable_failed_requests_replication", FailedRequestsReplicationEnabled_) {
            return;
        }

        MainModule_ = Loader->MustLoad(key, Copy(value->AsSubConfig()));
        return;
    } END_PARSE

    // Handles one request. It first drops finished replier coroutines and
    // refreshes the rate. It then either runs the main submodule directly or,
    // when at least one copy is due, goes through MakeSinkRequest.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        tls.Repliers.EraseFinished();

        tls.UpdateRate();
        size_t count = tls.CountCopiesFor(descr);
        if (count > 0) {
            return MakeSinkRequest(descr, tls, count);
        }
        return MainModule_->Run(descr);
    }

    // Runs the main submodule over a rewindable wrapper of the client input.
    // It then rewinds that wrapper, takes the buffered body bytes, and spawns
    // `count` coroutines, each sending a copy of the request and body to the
    // sink submodule.
    //
    // If the main submodule fails, copies are still sent when either:
    //   - the error is a TForceStreamClose, or
    //   - "enable_failed_requests_replication" is set.
    // Otherwise the error is returned at once and nothing is replicated.
    // Returns the main submodule's result.
    TError MakeSinkRequest(const TConnDescr& descr, TTls& tls, size_t count) const {
        // TODO(velavokr): we are copying the request here because other modules might corrupt it
        TRequest requestCopy = *descr.Request;

        // TODO(velavokr): BALANCER-2328 replicating the previous behavior for now, should rewrite later
        TLimitedRewindableInput in(*descr.Input, Max<size_t>());
        TConnDescr mainDescr = descr.CopyIn(in);

        TError ret;

        Y_TRY(TError, error) {
            return MainModule_->Run(mainDescr);
        } Y_CATCH {
            if (error.GetAs<TForceStreamClose>() || FailedRequestsReplicationEnabled_) {
                ret = std::move(error);
            } else {
                return error;
            }
        }

        // NOTE(review): presumably this yields only the body bytes the main
        // submodule actually read, not necessarily the whole body.
        in.Rewind();
        TChunkList body = in.RecvBuffered();

        for (size_t i = 0; i < count; ++i) {
            // The last copy takes ownership of requestCopy/body; earlier ones
            // receive copies.
            // NOTE(review): the coroutine outlives this call and captures
            // SinkModule_ and the worker process by reference. This relies on
            // both outliving the coroutine.
            tls.Repliers.Emplace("replier_cont", &descr.Process().Executor(),
                 [&module = *SinkModule_, &process = descr.Process()](TRequest request, TChunkList body, TRequestHash hash) {
                     TAsyncRequester requester{module, nullptr, process, hash};
                     Y_UNUSED(requester.Requester().Request(std::move(request), std::move(body), false));
                 },
                 (i + 1 == count) ? std::move(requestCopy) : requestCopy,
                 (i + 1 == count) ? std::move(body) : body.Copy(),
                 descr.Hash
            );
        }

        return ret;
    }

    // Creates the per-worker state. If "rate_file" is configured, it also
    // attaches a shared file re-reader for it.
    // TDuration::Seconds(1) is presumably the re-read interval.
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>(ConfigRate_);
        if (RateFileName_) {
            tls->RateFileReader = process->SharedFiles()->FileReReader(RateFileName_, TDuration::Seconds(1));
        }
        return tls;
    }

private:
    // Path from "rate_file"; empty when not configured.
    TString RateFileName_;
    // Validated "rate" value (finite, >= 0.0).
    double ConfigRate_ = 0.0;
    // Submodule that serves the request.
    THolder<IModule> MainModule_;
    // Submodule that receives the request copies.
    THolder<IModule> SinkModule_;
    // Value of "enable_failed_requests_replication".
    bool FailedRequestsReplicationEnabled_ = false;
};

IModuleHandle* NModRequestReplier::Handle() {
    // Exposes the handle of the request_replier module defined in this file.
    IModuleHandle* const handle = TModule::Handle();
    return handle;
}
