#include "module.h"

#include <balancer/modules/rate_limiter/rate_limiter.cfgproto.pb.h>

#include <balancer/kernel/http/parser/response_builder.h>
#include <balancer/kernel/module/conn_descr.h>
#include <balancer/kernel/module/module.h>

#include <library/cpp/config/config.h>
#include <library/cpp/http/misc/httpcodes.h>
 
#include <util/datetime/base.h>
#include <util/generic/maybe.h>
#include <util/generic/scope.h>
#include <util/system/mutex.h>


using namespace NSrvKernel;

namespace {
    /// Outcome of TRateLimiter::Take().
    ///
    /// Scoped (`enum class`) so the short enumerator names do not leak into the
    /// enclosing namespace; all uses are already qualified as `ETakeResult::X`.
    enum class ETakeResult {
        Ok,              // the caller may proceed
        Canceled,        // the caller's wait was interrupted with ECANCELED
        TooManyRequests, // the wait queue is full; the request was not admitted
    };

    /// Admits at most `max_requests` requests per `interval`, spacing admitted
    /// requests `interval / max_requests` apart.
    ///
    /// A caller that arrives before its slot sleeps (on its own coroutine) until
    /// the slot comes up. At most `max_requests_in_queue` callers may be asleep at
    /// once; further callers are turned away with TooManyRequests.
    ///
    /// All scheduling state (`NextTakeAt_`, `Pending_`) is guarded by `Mutex_`.
    class TRateLimiter {
    public:
        /// Stores its own copy of `config`, so the limiter does not depend on the
        /// caller's config object outliving it.
        /// @throws NConfig::TConfigParseError when `max_requests` is zero (it is
        ///         used as a divisor in Take()).
        explicit TRateLimiter(const NModRateLimiter::TModuleConfig& config)
            : Config_(config)
        {
            if (Config_.max_requests() == 0) {
                ythrow NConfig::TConfigParseError{} << "max_requests should be positive";
            }
        }

        /// Reserves the next rate-limit slot for the caller, sleeping until it if
        /// necessary.
        ///
        /// @param cont the caller's coroutine; it is put to sleep when the
        ///             reserved slot lies in the future.
        /// @return Ok when the caller may proceed; TooManyRequests when the wait
        ///         queue is already full (no slot is reserved in that case);
        ///         Canceled when the sleep returned ECANCELED.
        ETakeResult Take(TCont* cont) noexcept {
            // NOTE(review): division truncates; if interval / max_requests rounds
            // down to zero, no request is ever delayed. Confirm configs avoid this.
            const auto delay = Config_.interval() / Config_.max_requests();
            TMaybe<TInstant> sleepUntil;

            {
                TGuard guard{Mutex_};
                const TInstant now = TInstant::Now();

                // After an idle period, pull the schedule up to the present.
                // Otherwise every request arriving before the stale NextTakeAt_
                // caught up with the clock would be admitted without waiting.
                if (NextTakeAt_ < now) {
                    NextTakeAt_ = now;
                }

                if (now < NextTakeAt_) {
                    if (Pending_ >= Config_.max_requests_in_queue()) {
                        return ETakeResult::TooManyRequests;
                    }

                    sleepUntil = NextTakeAt_;
                    ++Pending_;
                }

                // Reserve the slot. A caller whose sleep is later cancelled does
                // not give it back.
                NextTakeAt_ += delay;
            }

            if (sleepUntil.Defined()) {
                // Leave the queue on every exit path, including cancellation.
                Y_DEFER {
                    TGuard guard{Mutex_};
                    --Pending_;
                };

                if (cont->SleepD(sleepUntil.GetRef()) == ECANCELED) {
                    return ETakeResult::Canceled;
                }
            }

            return ETakeResult::Ok;
        }

    private:
        // Owned copy, so the limiter's validity does not hinge on the lifetime of
        // the object it was constructed from.
        const NModRateLimiter::TModuleConfig Config_;
        TMutex Mutex_;
        // Earliest instant at which the next slot may be handed out.
        TInstant NextTakeAt_ = TInstant::Zero();
        // Number of callers currently sleeping inside Take().
        ui64 Pending_ = 0;
    };
}

namespace NModRateLimiter {
    /// `rate_limiter` module: passes each request through a TRateLimiter before
    /// handing it to the single configured submodule. Requests rejected by the
    /// limiter receive `429 Too Many Requests`.
    MODULE_BASE(rate_limiter, TModuleWithSubModule) {
    public:
        /// Parses the module config and builds the limiter from it.
        /// @throws NConfig::TConfigParseError when no submodule is configured or
        ///         when TRateLimiter rejects the config.
        TModule(const TModuleParams& mp)
            : TModuleBase(mp)
        {
            ParseConfig();

            if (!Submodule_) {
                ythrow NConfig::TConfigParseError{} << "no submodule configured";
            }

            // Built after ParseConfig() so that it sees the fully parsed config.
            RateLimiter_ = MakeHolder<TRateLimiter>(Config_);
        }

    private:
        /// Sends a 429 response head followed by EOF, then skips whatever is left
        /// of the request input. On any failure, sets the access-log summary to
        /// "client error" and returns the error.
        TError OnTooManyRequests(const TConnDescr& descr) const noexcept {
            TResponse resp = TResponseBuilder().Code(HTTP_TOO_MANY_REQUESTS).Version11();
            Y_TRY(TError, error) {
                Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(resp), false, TInstant::Max()));
                Y_PROPAGATE_ERROR(descr.Output->SendEof(TInstant::Max()));

                return SkipAll(descr.Input, TInstant::Max());
            } Y_CATCH {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client error");
                return error;
            };
            return {};
        }

        /// Waits for a rate-limit slot on the current coroutine, then forwards the
        /// request to the submodule, or rejects it / gives up as reported by Take().
        TError DoRun(const TConnDescr& descr) const noexcept override {
            // Every path either returns inside the switch or breaks out to the
            // single forwarding return below. This way the function never falls
            // off its end, and -Wswitch still flags any enumerator left unhandled.
            switch (RateLimiter_->Take(descr.Process().Executor().Running())) {
                case ETakeResult::TooManyRequests:
                    descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "too many requests");
                    return OnTooManyRequests(descr);
                case ETakeResult::Canceled:
                    // NOTE(review): a cancelled wait is reported as success (empty
                    // TError) without sending anything; confirm this is intended.
                    descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "cancelled");
                    return {};
                case ETakeResult::Ok:
                    break;
            }

            return Submodule_->Run(descr);
        }

        /// Fills Config_ from the module's config section. The callback loads a
        /// submodule for each key it is given. Presumably these are the keys that
        /// are not fields of TModuleConfig; verify against ParseProtoConfig. If
        /// several such keys are present, the last one loaded wins.
        void ParseConfig() {
            Config_ = ParseProtoConfig<TModuleConfig>(
                [&](const TString& key, NConfig::IConfig::IValue* value) {
                    Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
                }
            );
        }

    private:
        TModuleConfig Config_;
        THolder<TRateLimiter> RateLimiter_;
    };

    /// Exposes the handle of the `rate_limiter` module by forwarding to
    /// TModule::Handle().
    IModuleHandle* Handle() {
        IModuleHandle* const moduleHandle = TModule::Handle();
        return moduleHandle;
    }
}
