#include "module.h"

#include <balancer/kernel/coro/cleanable_coro_storage.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/httpencoder.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/requester/requester.h>

using namespace NSrvKernel;
using namespace NModRpsLimiter;

// Validator for configured quota names: an NRegExp FSM compiled from the
// pattern `[-a-z0-9_]+`. One shared instance is obtained through
// TWithDefaultInstance, so the pattern is compiled once and reused.
class TQuotaNameFsm : public NRegExp::TFsm, public TWithDefaultInstance<TQuotaNameFsm> {
public:
    TQuotaNameFsm()
        : NRegExp::TFsm("[-a-z0-9_]+", NRegExp::TFsm::TOptions())
    {
    }

    // Feeds `s` through a matcher over the shared FSM and reports whether the
    // matcher ended in a final (accepting) state.
    static bool Match(const TString& s) {
        NRegExp::TMatcher matcher(Instance());
        const bool accepted = matcher.Match(s).Final();
        return accepted;
    }
};

// Checker response header that names the quota matched for the request
// (logged when log_quota is enabled, stripped before forwarding to the client).
static constexpr TStringBuf QUOTA_HEADER = "X-Yandex-Rpslimiter-Matched-Quota"sv;
// Checker response header; a value accepted by TTrueFsm means "deny and forward
// this response to the client" (also stripped before forwarding).
static constexpr TStringBuf FORWARD_TO_USER = "X-ForwardToUser-Y"sv;

Y_TLS(rps_limiter) {
    // Watcher for the optional `disable_file`; while that file exists the
    // module bypasses the checker. Left default-constructed when no
    // disable_file is configured.
    TSharedFileExistsChecker DisableFileChecker;
    // Background coroutines that send asynchronous (register_only) checker requests.
    TCleanableCoroStorage Runners;

    // Reports DisableFileChecker.Exists(); presumably false for a
    // default-constructed checker (no disable_file configured) — TODO confirm.
    bool DisableFileExists() const noexcept {
        const bool exists = DisableFileChecker.Exists();
        return exists;
    }
};

// rps_limiter module.
//
// Consults an external quota service through the `checker` submodule. The
// request sent is `POST /quota.acquire[?quota=<quota_name>]`, and its body is
// the serialized client request head. Operating modes:
//  * default: wait for the verdict.
//    - On allow, run `module`.
//    - On deny, run `on_limited_request`, or else forward the checker's response
//      to the client.
//    - Checker failures are handled by `on_error` / `skip_on_error`.
//  * `register_only`: notify the checker from a background coroutine and always
//    run `module`.
//  * `register_only` + `register_backend_attempts`: run `module` first, then send
//    one background checker request per backend attempt it made.
// While `disable_file` exists, the checker is bypassed and `module` runs directly.
MODULE_WITH_TLS_BASE(rps_limiter, TModuleWithSubModule) {
public:
    // Parses the module config and validates option combinations.
    // Throws TConfigParseError on an invalid configuration.
    TModule(const TModuleParams& mp) : TModuleBase(mp)
    {
        Config->ForEach(this);
        if (!Submodule_) {
            ythrow TConfigParseError() << "no module configured";
        }

        if (!Checker_) {
            ythrow TConfigParseError() << "no checker configured";
        }

        if (OnError_ && SkipOnError_) {
            ythrow TConfigParseError() << "try to use both on_error and skip_on_error";
        }

        // The quota name is pasted into the checker URL without escaping
        // (see InitCheckerRequest), so only [-a-z0-9_]+ is accepted.
        if (QuotaName_ && !TQuotaNameFsm::Match(*QuotaName_)) {
            ythrow TConfigParseError() << "Invalid quota name \"" << *QuotaName_ << "\"";
        }

        if (RegisterBackendAttempts_ && !RegisterOnly_) {
            ythrow TConfigParseError() << "Using register_backend_attempts without register_only";
        }

        InitCheckerRequest();
    }

private:
    // Config keys: disable_file, skip_on_error, quota_name, register_only,
    // register_backend_attempts, namespace, log_quota, fail_on_quota_deny and
    // the submodule sections checker, module, on_error, on_limited_request.
    START_PARSE {
        ON_KEY("disable_file", DisableFile_) {
            return;
        }

        ON_KEY("skip_on_error", SkipOnError_) {
            return;
        }

        ON_KEY("quota_name", QuotaName_) {
            return;
        }

        ON_KEY("register_only", RegisterOnly_) {
            return;
        }

        ON_KEY("register_backend_attempts", RegisterBackendAttempts_) {
            return;
        }

        TString namespaceName;
        ON_KEY("namespace", namespaceName) {
            Namespace_ = namespaceName;
            return;
        }

        ON_KEY("log_quota", LogQuota_) {
            return;
        }

        ON_KEY("fail_on_quota_deny", FailOnQuotaDeny_) {
            return;
        }

        if (key == "checker") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(Checker_);
            return;
        } else if (key == "module") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(Submodule_);
            return;
        }  else if (key == "on_error") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(OnError_);
            return;
        }  else if (key == "on_limited_request") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(OnLimited_);
            return;
        }
    } END_PARSE

    // Pre-parses the request template sent to the checker:
    // "POST /quota.acquire[?quota=<name>] HTTP/1.1" with "Host: localhost".
    void InitCheckerRequest() {
        TString request = "POST /quota.acquire";
        if (QuotaName_) {
            request += "?quota=" + *QuotaName_;
        }
        request += " HTTP/1.1\r\n\r\n";
        TryRethrowError(CheckerRequest_.Parse(request));
        CheckerRequest_.Headers().Add("Host", "localhost");
    }

    // Creates per-worker state. The disable-file watcher is only set up when
    // disable_file is configured. The 1 second argument is presumably the
    // re-check interval — TODO confirm.
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>();
        if (!!DisableFile_) {
            tls->DisableFileChecker = process->SharedFiles()->FileChecker(DisableFile_, TDuration::Seconds(1));
        }
        return tls;
    }

    // Request entry point. Reaps finished background checker coroutines,
    // then dispatches according to the configured mode.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        tls.Runners.EraseFinished();
        if (tls.DisableFileExists()) {
            return Submodule_->Run(descr);
        }

        if (!RegisterOnly_) {
            return CheckedRun(descr);
        }

        if (!RegisterBackendAttempts_) {
            // Fire-and-forget: the checker request runs in a background coroutine,
            // and its result is ignored; the submodule runs immediately.
            tls.Runners.Emplace("async_rpslimiter_request", &descr.Process().Executor(),
                [this, &process = descr.Process(), start=descr.Properties->Start]
                    (TRequest&& request, TChunksOutputStream&& requestStream)
                {
                    TAsyncRequester requester{*Checker_, nullptr, process, 0, start};
                    Y_UNUSED(requester.Requester().Request(std::move(request), std::move(requestStream.Chunks()), false));
                },
                CheckerRequest_,
                BuildRequestBody(descr)
            );
            return Submodule_->Run(descr);
        } else {
            // Registration happens only after the submodule succeeded; one
            // background checker request is sent per backend attempt it made.
            Y_PROPAGATE_ERROR(Submodule_->Run(descr));
            size_t backendAttempts = descr.Properties->ConnStats.BackendAttempt;
            // NOTE(review): unlike the branch above, no request start time is
            // passed to TAsyncRequester here — confirm this is intended.
            tls.Runners.Emplace("async_rpslimiter_requests", &descr.Process().Executor(),
                [this, &process = descr.Process(), backendAttempts] (TRequest&& request, TChunksOutputStream&& requestStream) {
                    TAsyncRequester requester{*Checker_, nullptr, process};
                    for (size_t i = 0; i < backendAttempts; ++i) {
                        Y_UNUSED(requester.Requester().Request(TRequest(request), requestStream.Chunks().Copy(), false));
                    }
                },
                CheckerRequest_,
                BuildRequestBody(descr)
            );
        }

        return {};
    }

    // NOTE(review): BuildRequestBody dereferences descr.Request unconditionally.
    // Confirm that a request is always present whenever this module runs
    // outside an HTTP context.
    bool DoCanWorkWithoutHTTP() const noexcept override {
        return true;
    }

    // Builds the checker request body:
    //  * the serialized client request head (descr.Request->BuildTo);
    //  * the optional X-Rpslimiter-Balancer (namespace) header;
    //  * the X-Yandex-Rpslimiter-Log-Quota header (when log_quota is enabled);
    //  * a terminating blank line.
    // Appending headers after BuildTo presumes BuildTo does not emit the final
    // CRLF itself — TODO confirm.
    TChunksOutputStream BuildRequestBody(const TConnDescr& descr) const {
        TChunksOutputStream requestStream;
        descr.Request->BuildTo(requestStream);
        if (Namespace_.Defined()) {
            requestStream << "X-Rpslimiter-Balancer" << HDRD << *Namespace_ << CRLF;
        }
        if (LogQuota_) {
            requestStream << "X-Yandex-Rpslimiter-Log-Quota" << HDRD << "1" << CRLF;
        }
        requestStream << CRLF;
        return requestStream;
    }

    // Synchronous mode: query the checker and enforce its verdict.
    //
    // The request is allowed unless the checker response carries
    // X-ForwardToUser-Y with a value accepted by TTrueFsm.
    // On a checker error:
    //  * run on_error if configured;
    //  * else pass the request through if skip_on_error is set;
    //  * else fail.
    // On deny:
    //  * run on_limited_request if configured;
    //  * else forward the checker's response (minus the internal headers) to
    //    the client. With fail_on_quota_deny, an error is returned after a
    //    successful forward.
    // A failure while writing the forwarded response is returned to the caller.
    TError CheckedRun(const TConnDescr& descr) const noexcept {
        bool allow = false;
        TResponse response;
        TChunkList body;
        Y_TRY(TError, error) {
            TRequester requester(*Checker_, descr);
            TRequest request = CheckerRequest_;
            TChunksOutputStream requestStream = BuildRequestBody(descr);

            Y_PROPAGATE_ERROR(requester.Request(std::move(request), std::move(requestStream.Chunks()), false, response, body));

            if (LogQuota_) {
                TStringBuf quotaNameHeaderValue = response.Headers().GetFirstValue(QUOTA_HEADER);
                descr.ExtraAccessLog << " quota:"
                    << (quotaNameHeaderValue ? quotaNameHeaderValue : TStringBuf("(no_quota)"));
            }

            TStringBuf forwardToUserHeaderValue = response.Headers().GetFirstValue(FORWARD_TO_USER);
            if (!forwardToUserHeaderValue || !Match(TTrueFsm::Instance(), forwardToUserHeaderValue)) {
                allow = true;
            }

            descr.ExtraAccessLog << (allow ? " allow" : " deny");

            return {};
        } Y_CATCH {
            descr.ExtraAccessLog << " error";
            if (descr.AttemptsHolder) {
                descr.AttemptsHolder->NotifyRpsLimiterError();
            }
            if (OnError_) {
                error = OnError_->Run(descr);
                TAccessLogSummary *summary = descr.ExtraAccessLog.Summary();
                if (summary) {
                    summary->AnsweredModule = GetHandle()->Name() + " | " + summary->AnsweredModule;
                    summary->AnswerReason = "rps limiter fail, on_error | " + summary->AnswerReason;
                }
                return error;
            } else if (!SkipOnError_) {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "rps limiter fail");
                return Y_MAKE_ERROR(yexception{} << "rps limiter fail");
            }
            return Submodule_->Run(descr);
        }

        if (allow) {
            return Submodule_->Run(descr);
        }

        if (OnLimited_) {
            TError error = OnLimited_->Run(descr);
            TAccessLogSummary *summary = descr.ExtraAccessLog.Summary();
            if (summary) {
                summary->AnsweredModule = GetHandle()->Name() + " | " + summary->AnsweredModule;
                summary->AnswerReason = "quota deny, on_limited_request | " + summary->AnswerReason;
            }
            return error;
        }

        descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "quota deny");
        Y_TRY(TError, error) {
            // The checker's internal headers are not meant for the client.
            response.Headers().Delete(FORWARD_TO_USER);
            response.Headers().Delete(QUOTA_HEADER);
            Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(response), false, TInstant::Max()));
            if (!body.Empty()) {
                Y_PROPAGATE_ERROR(descr.Output->Send(std::move(body), TInstant::Max()));
            }
            return descr.Output->SendEof(TInstant::Max());
        } Y_CATCH {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client write error");
            // Propagate the write failure instead of reporting success (or the
            // generic quota-deny error) for a response the client never got.
            return error;
        }

        if (FailOnQuotaDeny_) {
            return Y_MAKE_ERROR(yexception{} << "rps limiter quota deny");
        }
        return {};
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

    TMaybe<TString> QuotaName_;            // quota_name: appended as ?quota=... to the checker URL
    THolder<IModule> Checker_;             // checker: submodule the quota request is sent through
    THolder<IModule> OnError_;             // on_error: runs when the checker request fails
    THolder<IModule> OnLimited_;           // on_limited_request: runs when the checker denies
    TString DisableFile_;                  // disable_file: while it exists the checker is bypassed
    TMaybe<TString> Namespace_;            // namespace: sent to the checker as X-Rpslimiter-Balancer

    bool SkipOnError_ = false;             // skip_on_error: pass requests through on checker failure
    bool RegisterOnly_ = false;            // register_only: notify the checker asynchronously, never block
    bool RegisterBackendAttempts_ = false; // register_backend_attempts: one notification per backend attempt
    bool FailOnQuotaDeny_ = false;         // fail_on_quota_deny: return an error after forwarding a deny

    bool LogQuota_ = true;                 // log_quota: ask for and access-log the matched quota name

    TRequest CheckerRequest_;              // Pre-parsed request line + headers sent to the checker
};

// Exposes the handle of the rps_limiter TModule defined in this file
// (the declaration presumably lives in module.h).
IModuleHandle* NModRpsLimiter::Handle() {
    IModuleHandle* const handle = TModule::Handle();
    return handle;
}
