#include "module.h"

#include <balancer/modules/quota/quota.cfgproto.pb.h>

#include <balancer/kernel/http/parser/response_builder.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/rpslimiter/quota_manager.h>
#include <balancer/kernel/stats/manager.h>
#include <balancer/kernel/tvm/tvm.h>

#include <library/cpp/cgiparam/cgiparam.h>

#include <util/string/join.h>
#include <util/string/split.h>
#include <util/string/strip.h>
#include <util/string/subst.h>

#include <cmath>

using namespace NSrvKernel;
using namespace NRpsLimiter;
using namespace NModQuota;

namespace {
    struct TSharedStats {
        TSharedStats(const TString& uuid, TSharedStatsManager& statsManager)
            : Requests(statsManager.MakeCounter("quotas-" + uuid + "-Requests").AllowDuplicate().Build())
            // , AsyncRequests(statsManager.MakeCounter("quotas-" + uuid + "-AsyncRequests").AllowDuplicate().Build())
            , InvalidRequests(statsManager.MakeCounter("quotas-" + uuid + "-InvalidRequests").AllowDuplicate().Build())
            , LimitedRequests(statsManager.MakeCounter("quotas-" + uuid + "-LimitedRequests").AllowDuplicate().Build())
            , Consumed(statsManager.MakeCounter("quotas-" + uuid + "-Consumed").AllowDuplicate().Build())
            , Limited(statsManager.MakeCounter("quotas-" + uuid + "-Limited").AllowDuplicate().Build())
            , Limit(statsManager.MakeGauge("quotas-" + uuid + "-Limit")
                .AllowDuplicate().Aggregation(TWorkerAggregation::Max).Build())
            // , IntervalMs(statsManager.MakeGauge("quotas-" + uuid + "-Interval_ms")
            //    .AllowDuplicate().Aggregation(TWorkerAggregation::Max).Build())
        {}

        TSharedCounter Requests;
        // TSharedCounter AsyncRequests;
        TSharedCounter InvalidRequests;
        TSharedCounter LimitedRequests;
        TSharedCounter Consumed;
        TSharedCounter Limited;
        TSharedCounter Limit;
        // TSharedCounter IntervalMs;
    };

    struct TQuotaInfo {
        const TQuota* Quota;
        TMaybe<size_t> Idx;
        bool IsByQuota = false;
        TMaybe<TQuota> QuotaCopy;

        TQuotaInfo() {}
        TQuotaInfo(const TQuota* quota, TMaybe<size_t> idx, bool isByQuota = false,
                   TMaybe<TQuota> quotaCopy = Nothing())
            : Quota(quota)
            , Idx(idx)
            , IsByQuota(isByQuota)
            , QuotaCopy(quotaCopy)
        {
            if (QuotaCopy) {
                Quota = &*QuotaCopy;
            }
        }
        TQuotaInfo(const TQuotaInfo& other) {
            *this = other;
            if (QuotaCopy) {
                Quota = &*QuotaCopy;
            }
        }
        TQuotaInfo& operator=(const TQuotaInfo&) = default;
    };

    // Extracts the value of one "by" selector (TByQuotaCfg) from a request.
    // data_type() selects the source: "headers", "cgi", "urls", "cookies" or
    // "tvm-service". value() names the header, CGI parameter or cookie; for
    // "tvm-service" it is the id passed to TClientMap::GetClient.
    // Holds references to `by` and `descr` and must not outlive them.
    class TByQuotaValue {
    private:
        // First value of header `name` as returned by GetFirstValue.
        TString GetHeader(TStringBuf name) const {
            return TString(Descr.Request->Headers().GetFirstValue(name));
        }

        // Value of CGI parameter `name` from the request line; "" when the CGI
        // part is empty.
        TString GetCgi(TStringBuf name) const {
            TString cgi(Descr.Request->RequestLine().CGI.AsStringBuf());
            if (!cgi) {
                return "";
            }
            // The CGI part may start with the '?' separator; drop it before parsing.
            if (cgi[0] == '?') {
                cgi.erase(0, 1);
            }
            TQuickCgiParam params(cgi);
            return TString(params.Get(name));
        }

        // Request path from the request line.
        TString GetUrl() const {
            return TString(Descr.Request->RequestLine().Path.AsStringBuf());
        }

        // Value of cookie `name` from the first Cookie header. Pairs are split
        // on ';', then on the first '=' only, and both parts are stripped.
        // Entries without '=' are ignored. Returns "" when nothing matches.
        TString GetCookie(TStringBuf name) const {
            const TString cookieHeader = GetHeader("Cookie");
            if (!cookieHeader) {
                return "";
            }
            const TVector<TString> cookies = StringSplitter(cookieHeader).Split(';');
            for (const TString& cookieString : cookies) {
                const TVector<TString> cookie = StringSplitter(cookieString).Split('=').Limit(2);
                if (cookie.size() >= 2 && Strip(cookie[0]) == name) {
                    return Strip(cookie[1]);
                }
            }
            return "";
        }

        // Source TVM id of the first x-ya-service-ticket header value, checked
        // with the TVM client for `clientId`. Returns "" when the header is
        // missing or empty, or when the check does not succeed.
        // NOTE(review): IntFromString throws on a malformed id, and this runs
        // under the module's noexcept request path (DoRun -> GetQuotaInfo), so a
        // bad config value would terminate the process. Confirm the value is
        // validated at config load.
        TString GetTvmId(TStringBuf clientId) const {
            const auto& requestHeaders = Descr.Request->Headers();
            const auto ptr = requestHeaders.FindValues("x-ya-service-ticket");
            if (ptr == requestHeaders.end() || ptr->second.empty()) {
                return "";
            }
            const TStringBuf ticket = ptr->second[0].AsStringBuf();
            NTvmAuth::TTvmClient& tvmClient = TClientMap::Instance().GetClient(IntFromString<NTvmAuth::TTvmId, 10>(clientId));
            NTvmAuth::TCheckedServiceTicket checkedTicket = tvmClient.CheckServiceTicket(ticket);
            if (!checkedTicket) {
                return "";
            }
            return IntToString<10>(checkedTicket.GetSrc());
        }

    public:
        TByQuotaValue(const TByQuotaCfg& by, const TConnDescr& descr)
            : By(by)
            , Descr(descr)
        {}

        // Selector value for the configured data type; "" for unknown types.
        TString Get() const {
            const auto& dataType = By.data_type();
            if (dataType == "headers") {
                return GetHeader(By.value());
            }
            if (dataType == "cgi") {
                return GetCgi(By.value());
            }
            if (dataType == "urls") {
                return GetUrl();
            }
            if (dataType == "cookies") {
                return GetCookie(By.value());
            }
            if (dataType == "tvm-service") {
                return GetTvmId(By.value());
            }
            return "";
        }

    private:
        const TByQuotaCfg& By;
        const TConnDescr& Descr;
    };
}

// Per-worker state of the quota module; one instance is created per worker
// in TModule::DoInitTls.
Y_TLS(quota) {
    // Per-worker counter handles. Each member is initialized from the
    // same-named counter of TSharedStats and `workerId` through
    // Y_BALANCER_INIT_WORKER_COUNTER. The member list mirrors TSharedStats,
    // including the commented-out AsyncRequests / IntervalMs.
    struct TStats {
        TStats(const TSharedStats& sharedStats, size_t workerId)
            : Y_BALANCER_INIT_WORKER_COUNTER(Requests, sharedStats, workerId)
            // , Y_BALANCER_INIT_WORKER_COUNTER(AsyncRequests, sharedStats, workerId)
            , Y_BALANCER_INIT_WORKER_COUNTER(InvalidRequests, sharedStats, workerId)
            , Y_BALANCER_INIT_WORKER_COUNTER(LimitedRequests, sharedStats, workerId)
            , Y_BALANCER_INIT_WORKER_COUNTER(Consumed, sharedStats, workerId)
            , Y_BALANCER_INIT_WORKER_COUNTER(Limited, sharedStats, workerId)
            , Y_BALANCER_INIT_WORKER_COUNTER(Limit, sharedStats, workerId)
            // , Y_BALANCER_INIT_WORKER_COUNTER(IntervalMs, sharedStats, workerId)
        {}

        TSharedCounter Requests;
        // TSharedCounter AsyncRequests;
        TSharedCounter InvalidRequests;
        TSharedCounter LimitedRequests;
        TSharedCounter Consumed;
        TSharedCounter Limited;
        // Set to the first quota's limit in DoInitTls.
        TSharedCounter Limit;
        // TSharedCounter IntervalMs;
    };

    // Created in TModule::DoInitTls.
    THolder<TStats> Stats;
};

MODULE_WITH_TLS(quota) {
public:
    // Parses the module config and resolves every quota listed in `name`
    // (a comma-separated list) in the rpslimiter quota storage.
    // Throws TConfigParseError on an unknown config key, when the control has
    // no quota manager, or when one of the listed quotas is not found.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config_ = ParseProtoConfig<TModuleConfig>([](const auto&, const TString& key, NConfig::IConfig::IValue*) {
            ythrow TConfigParseError() << "unknown key " << key.Quote();
        });

        QuotaManager_ = mp.Control->GetQuotaManager();

        Y_ENSURE_EX(QuotaManager_,
            TConfigParseError() << "no rpslimiter_instance found");

        const TVector<TString> quotas = ParseQuotasFromName(Config_.name());
        QuotaIndices_.reserve(quotas.size());
        for (const TString& quota : quotas) {
            const TMaybe<size_t> quotaIdx = QuotaManager_->Storage.QuotaIdx(quota);
            Y_ENSURE_EX(quotaIdx,
                TConfigParseError() << "quota " << quota.Quote() << " not found");
            // Only resolved indices are stored, so every entry is engaged.
            QuotaIndices_.push_back(quotaIdx);
        }

        // Stats are named after the quota list with ',' replaced by '/'.
        TString name = Config_.name();
        SubstGlobal(name, ",", "/");
        SharedStats_ = MakeHolder<TSharedStats>(name, mp.Control->SharedStatsManager());
    }

private:
    // Splits the comma-separated quota list from the config `name`.
    TVector<TString> ParseQuotasFromName(TStringBuf name) const {
        return StringSplitter(name).Split(',');
    }

    // Creates per-worker stats and publishes the limit of the first quota in
    // the Limit gauge. Only the first quota is reported when several are listed.
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>();
        tls->Stats = MakeHolder<TTls::TStats>(*SharedStats_, process->WorkerId());
        if (!QuotaIndices_.empty()) {
            const auto& quotaIdx = QuotaIndices_.front();
            const auto& quota = QuotaManager_->Storage.QuotaInfo(*quotaIdx);
            tls->Stats->Limit.Set(quota.Limit);
            // tls->Stats->IntervalMs.Set(quota.Window.MilliSeconds());
        }
        return tls;
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

    // Resolves the quota at `quotaIdx` for this request.
    // A quota without `By` selectors is returned as a pointer into storage.
    // Otherwise the selector values are extracted from the request, joined with
    // '-', and the per-value quota from Storage.ByQuotaInfo is returned as an
    // owned copy. `empty` is set to true (never cleared) when a selector that
    // does not allow empty values yields an empty value.
    TQuotaInfo GetQuotaInfo(TMaybe<size_t> quotaIdx, const TConnDescr& descr, bool& empty) const noexcept {
        const auto& quota = QuotaManager_->Storage.QuotaInfo(*quotaIdx);
        if (!quota.By) {
            return TQuotaInfo(&quota, quotaIdx);
        }
        TString byValue;
        for (const auto& by : quota.By) {
            const TString current = TByQuotaValue(by, descr).Get();
            // NOTE(review): the separator is added only once byValue is
            // non-empty, so leading empty selector values contribute no '-'.
            // The format is kept as-is because this string is the key passed to
            // Storage.ByQuotaInfo; changing it would change which bucket counts.
            if (byValue) {
                byValue += "-";
            }
            byValue += current;
            empty |= !current && !by.allow_empty();
        }
        const auto byQuota = QuotaManager_->Storage.ByQuotaInfo(quota, byValue);
        return TQuotaInfo(nullptr, quotaIdx, true, byQuota);
    }

    // Resolves every configured quota for this request, in config order.
    TVector<TQuotaInfo> GetQuotasInfo(const TConnDescr& descr, bool& empty) const noexcept {
        TVector<TQuotaInfo> result;
        result.reserve(QuotaIndices_.size());
        for (const auto& quotaIdx : QuotaIndices_) {
            result.emplace_back(GetQuotaInfo(quotaIdx, descr, empty));
        }
        return result;
    }

    // Answers a quota check request:
    //  - 400 when x-yandex-rpslimiter-custom-value is present but unparsable;
    //  - 200 "ok" when every matched quota has room for `incr` units (those
    //    units are then added to every matched quota);
    //  - 429 "limited" otherwise.
    // x-yandex-rpslimiter-async-request: 1 bypasses the rate comparison but not
    // the check for forbidden empty "by" values. x-yandex-rpslimiter-log-quota
    // and x-yandex-rpslimiter-print-state add debug headers to the response.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        const auto now = descr.Properties->Start;
        descr.ExtraAccessLog << " name:" << Config_.name();

        TResponse response;
        TChunkList output;

        const auto& headers = descr.Request->Headers();
        const bool asyncRequest = ("1" == headers.GetFirstValue("x-yandex-rpslimiter-async-request"));
        const bool printState = ("1" == headers.GetFirstValue("x-yandex-rpslimiter-print-state"));
        const bool logQuota = ("1" == headers.GetFirstValue("x-yandex-rpslimiter-log-quota"));

        // Units to consume: 1 unless a custom value is supplied.
        ui32 incr = 1;
        const auto customValue = headers.GetFirstValue("x-yandex-rpslimiter-custom-value");
        const bool invalidCustomValue = customValue && !TryFromString(customValue, incr);

        tls.Stats->Requests.Add(1);
        /*
        if (asyncRequest) {
            tls.Stats->AsyncRequests.Add(1);
        }
        */

        auto& storage = QuotaManager_->Storage;

        bool forbiddenEmptyByQuotas = false;
        const auto quotas = GetQuotasInfo(descr, forbiddenEmptyByQuotas);

        if (logQuota || printState) {
            response.Headers().Add("x-yandex-rpslimiter-matched-quota", Config_.name());

            // Names of the resolved by-quotas, comma-separated like the other
            // multi-quota headers.
            TString byQuotaString;
            bool firstByQuota = true;
            for (const auto& quotaInfo : quotas) {
                if (!quotaInfo.IsByQuota) {
                    continue;
                }
                if (!firstByQuota) {
                    byQuotaString += ",";
                }
                firstByQuota = false;
                byQuotaString += quotaInfo.Quota->Name;
            }
            if (byQuotaString) {
                response.Headers().Add("x-yandex-rpslimiter-by-quota", byQuotaString);
            }
        }

        THashMap<TString, TVector<TString>> quotaHeaders;
        // Allowed only if no forbidden empty "by" value was seen and every
        // quota has room.
        bool allow = !forbiddenEmptyByQuotas;
        for (const auto& quotaInfo : quotas) {
            const auto& quota = *quotaInfo.Quota;
            double rate = 0;
            if (quotaInfo.IsByQuota) {
                rate = storage.TotalByQuotaRate(quota, now);
            } else {
                rate = storage.TotalQuotaRate(*quotaInfo.Idx, now);
            }
            const bool quotaAllows = (rate + incr) <= quota.Limit || asyncRequest;
            allow = allow && quotaAllows;

            if (printState) {
                quotaHeaders["x-yandex-rpslimiter-quota-limit"].push_back(ToString(quota.Limit));
                quotaHeaders["x-yandex-rpslimiter-quota-interval-ms"].push_back(ToString(quota.Window.MilliSeconds()));
                quotaHeaders["x-yandex-rpslimiter-current-value"].push_back(ToString(rate + incr));
            }
        }

        for (const auto& header : quotaHeaders) {
            const TString value = JoinSeq(",", header.second);
            response.Headers().Add(header.first, value);
        }

        // An unparsable custom value is rejected before the quota decision, so
        // it is neither consumed as the default 1 unit nor masked by a 429.
        if (invalidCustomValue) {
            tls.Stats->InvalidRequests.Add(1);

            const auto message = "invalid custom value " + TString(customValue).Quote();
            descr.ExtraAccessLog << ' ' << message;
            response.ResponseLine().StatusCode = 400;
            output = TChunkList(message);
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "invalid custom value");
        } else if (allow) {
            for (const auto& quotaInfo : quotas) {
                if (quotaInfo.IsByQuota) {
                    storage.IncLocalByQuota(*quotaInfo.Quota, now, incr);
                } else {
                    storage.IncLocalQuota(*quotaInfo.Idx, now, incr);
                }
            }
            tls.Stats->Consumed.Add(incr);

            const auto message = "consumed:" + ToString(incr);
            descr.ExtraAccessLog << ' ' << message;
            response.ResponseLine().StatusCode = 200;
            output = TChunkList("ok");
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "allow");
        } else {
            tls.Stats->LimitedRequests.Add(1);
            tls.Stats->Limited.Add(incr);

            const auto message = "limited:" + ToString(incr);
            descr.ExtraAccessLog << ' ' << message;
            response.ResponseLine().StatusCode = 429;
            response.Headers().Add("x-forwardtouser-y", "1");
            output = TChunkList("limited");
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "limited");
        }

        Y_TRY(TError, error) {
            Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(response), false, TInstant::Max()));
            Y_PROPAGATE_ERROR(descr.Output->Send(std::move(output), TInstant::Max()));
            Y_PROPAGATE_ERROR(descr.Output->SendEof(TInstant::Max()));
            return SkipAll(descr.Input, TInstant::Max());
        } Y_CATCH {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client error");
            return error;
        };
        return {};
    }

private:
    TModuleConfig Config_;
    // One engaged index per quota listed in the config name, in config order.
    TVector<TMaybe<size_t>> QuotaIndices_;
    TQuotaManager* QuotaManager_ = nullptr;
    THolder<TSharedStats> SharedStats_;
};

IModuleHandle* NModQuota::Handle() {
    // Public entry point: exposes the handle of the quota module class.
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
