#include "module.h"

#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/fs/shared_file_exists_checker.h>
#include <balancer/kernel/helpers/default_instance.h>
#include <balancer/kernel/helpers/misc.h>
#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/regexp/regexp_pire.h>
#include <balancer/kernel/requester/requester.h>
#include <balancer/kernel/stats/manager.h>
#include <balancer/modules/exp_common/exp_common.h>

#include <library/cpp/regex/pire/regexp.h>

#include <util/generic/singleton.h>
#include <util/generic/yexception.h>
#include <util/network/address.h>
#include <util/string/cast.h>
#include <util/thread/singleton.h>

#include <utility>


using namespace NConfig;
using namespace NSrvKernel;
using namespace NRegExp;

namespace {
    // Returns the name of the shared "limited_headers" counter.
    // The counter is per-service when a service name is configured; otherwise
    // the common default counter name is used.
    TString BuildCounterName(const TString& serviceName) {
        if (!serviceName) {
            return "exp-default-limited_headers";
        }

        TString counterName = "exp-service-";
        counterName += serviceName;
        counterName += "-limited_headers";
        return counterName;
    }
}

// TODO:
// 1. How should the client IP address be sent?
// 2. How should the service name be sent?
// 3. Decide which headers are necessary.
//
// Add a service name header.
// If the request has service name headers, then for the uaas request (and only for it):
// 1. Delete this header from the original request.
// 2. Insert the new header.


// Case-insensitive matcher for the Content-Length, Transfer-Encoding and
// Connection headers. The module strips these headers from the uaas subrequest
// it builds from the client request.
struct TRestrictedHeadersFsm : TFsm, TWithDefaultInstance<TRestrictedHeadersFsm> {
    TRestrictedHeadersFsm()
        : TFsm(
            "content-length|transfer-encoding|connection",
            TFsm::TOptions().SetCaseInsensitive(true))
    {
    }
};

Y_TLS(exp_getter) {
    // Kill-switch file watcher. It is only assigned by the module's
    // DoInitTls() when "file_switch" is configured.
    TSharedFileExistsChecker KillSwitchChecker;

    // Plain per-worker tallies, incremented by the module:
    // uaas errors, uaas answers without experiment headers, and trusted requests.
    size_t UaasNoAnswer{ 0 };
    size_t UaasNoHeader{ 0 };
    size_t TrustedRequests{ 0 };

    // Worker-bound handles to the module's shared "limited_headers" counters.
    TSharedCounter ServiceCounter_;
    TSharedCounter TotalCounter_;

    // Reports whether the kill-switch file exists, as seen by KillSwitchChecker.
    bool KillSwitchFileExists() const noexcept {
        return KillSwitchChecker.Exists();
    }
};

MODULE_WITH_TLS_BASE(exp_getter, TModuleWithSubModule) {
public:
    // Builds the module from its config. It loads the "uaas" requester module and the
    // downstream submodule, then precompiles the header matchers.
    // Throws TConfigParseError in these cases:
    //   - either module is missing;
    //   - "service_name" is set without "service_name_header".
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (!UaasModule_) {
            ythrow TConfigParseError() << "no uaas module configured for exp_getter";
        }

        if (!Submodule_) {
            ythrow TConfigParseError() << "no module configured for exp_getter";
        }

        // Without an explicit "exp_headers" option, headers copied from the
        // uaas response are selected by the common experiment header set.
        if (ExpHeadersFsm_ == nullptr) {
            ExpHeadersFsm_ = &NExpCommon::ExpHeaders();
        }

        // UberFsm_ classifies the headers of the uaas subrequest in
        // ConstructRequest(). The numbered comments match the `case` labels there.
        if (ServiceName_) {
            if (!ServiceNameHeader_) {
                ythrow TConfigParseError() << "service_name_header is required if service_name is set";
            }
            UberFsm_.Reset(
                    new TFsm(NExpCommon::ExpHeaders()            // 0
                      | TRestrictedHeadersFsm::Instance()   // 1
                      | TFsm(ServiceNameHeader_, TFsm::TOptions().SetCaseInsensitive(true)) // 2
            ));
        } else {
            UberFsm_.Reset(
                    new TFsm(NExpCommon::ExpHeaders()            // 0
                      | TRestrictedHeadersFsm::Instance()   // 1
            ));
        }

        ServiceCounter_ = Control->SharedStatsManager().MakeCounter(BuildCounterName(ServiceName_)).AllowDuplicate().Build();
        TotalCounter_ = Control->SharedStatsManager().MakeCounter("exp-total-limited_headers").AllowDuplicate().Build();
    }


private:
    START_PARSE {
        PARSE_EVENTS;

        // Module that is queried for experiment headers.
        if (key == "uaas") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(UaasModule_);
            return;
        }

        // While this file exists, uaas is not queried (see RefineRequest()).
        ON_KEY("file_switch", KillSwitchFile_) {
            return;
        }

        // Header through which the configured service name is passed downstream.
        ON_KEY("service_name_to_backend_header", ServiceNameHeaderToBackend_) {
            if (!CheckHeaderName(ServiceNameHeaderToBackend_)) {
                ythrow TConfigParseError{} << "\"service_name_to_backend_header\" value is not a valid http header";
            } else if (!CheckRestrictedHeaderName(ServiceNameHeaderToBackend_)) {
                ythrow TConfigParseError{} << "\"service_name_to_backend_header\" value " << ServiceNameHeaderToBackend_.Quote()
                    << " contains one of the restricted headers " << RestrictedHeadersListString();
            }
            return;
        }

        // Header through which the configured service name is passed to uaas.
        ON_KEY("service_name_header", ServiceNameHeader_) {
            if (!CheckHeaderName(ServiceNameHeader_)) {
                ythrow TConfigParseError{} << "\"service_name_header\" value is not a valid http header";
            } else if (!CheckRestrictedHeaderName(ServiceNameHeader_)) {
                ythrow TConfigParseError{} << "\"service_name_header\" value " << ServiceNameHeader_.Quote()
                    << " contains one of the restricted headers " << RestrictedHeadersListString();
            }
            return;
        }

        ON_KEY("service_name", ServiceName_) {
            return;
        }

        // Case-insensitive regexp selecting which uaas response headers are copied.
        TString expHeaders;
        ON_KEY("exp_headers", expHeaders) {
            ExpHeaders_.Reset(new TFsm(expHeaders, TFsm::TOptions().SetCaseInsensitive(true)));
            ExpHeadersFsm_ = ExpHeaders_.Get();
            return;
        }

        ON_KEY("trusted", MayTrust_) {
            return;
        }

        ON_KEY("processing_time_header", ProcessingTimeHeader_) {
            return;
        }

        ON_KEY("headers_size_limit", ExtraHeadersSizeLimit_) {
            return;
        }

        // Any other key is treated as the downstream submodule.
        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    } END_PARSE

    // Binds the shared counters to this worker and, if configured, sets up
    // the kill-switch file checker (polled with a 1-second period).
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>();
        tls->ServiceCounter_ = TSharedCounter(ServiceCounter_, process->WorkerId());
        tls->TotalCounter_ = TSharedCounter(TotalCounter_, process->WorkerId());
        if (!!KillSwitchFile_) {
            tls->KillSwitchChecker = process->SharedFiles()->FileChecker(KillSwitchFile_, TDuration::Seconds(1));
        }
        return tls;
    }

    // Rewrites the experiment headers of the incoming request in place.
    // - Kill switch active: client experiment headers are dropped; uaas is not asked.
    // - Trusted request ("trusted" option and client already sent experiment
    //   headers): the request is left as is.
    // - Otherwise: client experiment headers are dropped, and uaas is queried.
    //   Then its experiment headers and staff header are copied into the request.
    //   uaas failures are logged and counted but never fail the request.
    // uaasLog is created whenever the kill switch is off.
    // Always returns success.
    TError RefineRequest(const TConnDescr& descr, TTls& tls, THolder<TExtraAccessLogEntry>& uaasLog) const noexcept {
        if (!tls.KillSwitchFileExists()) {
            uaasLog = MakeHolder<TExtraAccessLogEntry>(descr, "uaas");
            if (IsTrusted(descr)) {
                ++tls.TrustedRequests;
                descr.ExtraAccessLog << " trusted";
            } else {
                bool succeed = false;

                Y_TRY(TError, error) {
                    Y_ASSERT(Submodule_);
                    Y_REQUIRE(descr.Request, yexception{} << "no parsed request for uaas");

                    descr.Request->Headers().Delete(NExpCommon::ExpHeaders());

                    // see USEREXP-7668
                    const auto method = descr.Request->RequestLine().Method;
                    if (method == EMethod::OPTIONS || method == EMethod::CONNECT) {
                        succeed = true;
                        descr.ExtraAccessLog << " method_not_supported";
                        return {};
                    }

                    TRequest request = ConstructRequest(descr);

                    TRequester requester(*UaasModule_, descr);
                    TResponse response;
                    Y_PROPAGATE_ERROR(requester.Request(std::move(request), response));

                    descr.Request->Headers().Add(NExpCommon::STAFF_HEADER_NAME,
                        response.Headers().GetValuesMove(NExpCommon::STAFF_HEADER_NAME)
                    );

                    succeed = CopyUaasHeaders(response.Headers(), descr.Request->Headers());
                    if (succeed) {
                        descr.ExtraAccessLog << " uaas_answered";
                    } else {
                        ++tls.UaasNoHeader;
                    }
                    return {};
                } Y_CATCH {
                    ++tls.UaasNoAnswer;
                    LOG_ERROR(TLOG_ERR, descr, "uaas failed: " << GetErrorMessage(error));
                }

                if (!succeed) {
                    descr.ExtraAccessLog << " uaas failed";
                    DeleteExpHeaders(descr);
                }
            }
        } else {
            // TODO: maybe do not delete them for trusted requests
            DeleteExpHeaders(descr);
        }

        // The request may be absent (the uaas path above tolerates that), so guard
        // the dereference like every other access in this module.
        if (ServiceName_ && ServiceNameHeaderToBackend_ && descr.Request) {
            descr.Request->Headers().Add(ServiceNameHeaderToBackend_, ServiceName_);
        }

        return {};
    }

    // Refines the request via RefineRequest() and then runs the submodule.
    // When "headers_size_limit" is positive, header growth caused by the refinement
    // is checked. The request line is excluded.
    // Over-limit requests are counted and logged but still forwarded.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        {
            // uaasLog (created by RefineRequest when the kill switch is off) lives
            // only inside this scope, i.e. it is destroyed before the submodule runs.
            THolder<TExtraAccessLogEntry> uaasLog;

            bool checkHeadersSize = ExtraHeadersSizeLimit_ > 0 && descr.Request;
            size_t before = checkHeadersSize ? descr.Request->EncodedSize() - descr.Request->RequestLine().EncodedSize() : 0;

            TDuration processingTime = TDuration::Max();
            Y_PROPAGATE_ERROR(MeasureProcessingTime([&]() { return RefineRequest(descr, tls, uaasLog); }, &processingTime));
            if (ProcessingTimeHeader_ && descr.Request) {
                descr.Request->Headers().Add("X-Yandex-Balancer-ExpGetter-ProcessingTime",
                                             ToString(processingTime.MicroSeconds()));
            }

            if (checkHeadersSize) {
                size_t after = descr.Request->EncodedSize() - descr.Request->RequestLine().EncodedSize();
                if (after > before + ExtraHeadersSizeLimit_) {
                    tls.ServiceCounter_.Inc();
                    tls.TotalCounter_.Inc();
                    if (uaasLog) {
                        descr.ExtraAccessLog << " limited_headers " << (after-before) << "/" << ExtraHeadersSizeLimit_;
                    }
                }
            }
        }

        return Submodule_->Run(descr);
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    // A request is trusted only if "trusted" is enabled and it already carries
    // experiment headers.
    bool IsTrusted(const TConnDescr& descr) const noexcept {
        return MayTrust_ && HasNecessaryHeaders(descr);
    }

    // True if the request exists and has at least one header matching the
    // common experiment header set.
    bool HasNecessaryHeaders(const TConnDescr& descr) const noexcept {
        if (Y_UNLIKELY(!descr.Request)) {
            return false;
        }

        for (const auto& hdr: descr.Request->Headers()) {
            if (Match(NExpCommon::ExpHeaders(), hdr.first.AsStringBuf())) {
                return true;
            }
        }

        return false;
    }

    // Removes every common experiment header from the request, if there is one.
    static void DeleteExpHeaders(const TConnDescr& descr) noexcept {
        if (Y_LIKELY(descr.Request)) {
            descr.Request->Headers().Delete(NExpCommon::ExpHeaders());
        }
    }

    // Moves every header of `from` matching *ExpHeadersFsm_ into `to`.
    // Returns true if at least one header was copied.
    // Note: the values in `from` are moved out.
    bool CopyUaasHeaders(THeaders& from, THeaders& to) const noexcept {
        bool found = false;
        for (auto& hdr: from) {
            if (NSrvKernel::Match(*ExpHeadersFsm_, hdr.first.AsStringBuf())) {
                to.Add(hdr.first.AsString(), THeaders::MakeOwned(std::move(hdr.second)));
                found = true;
            }
        }

        return found;
    }

    // Builds the uaas subrequest from a copy of the client request:
    // - drops experiment headers and restricted headers
    //   (Content-Length, Transfer-Encoding, Connection);
    // - if "service_name" is set, forces the service name header to the
    //   configured value, keeping only the first occurrence and adding it if absent.
    [[nodiscard]] TRequest ConstructRequest(const TConnDescr& descr) const noexcept {
        TRequest request = *descr.Request;

        // TODO: copy selectively
        bool metService{ false };
        for (auto it = request.Headers().begin(); it != request.Headers().end();) {
            auto& header = *it;
            TMatcher matcher{ *UberFsm_ };

            if (Match(matcher, header.first.AsStringBuf()).Final()) {
                switch (*matcher.MatchedRegexps().first) {
                    case 0: // NExpCommon::ExpHeaders
                    case 1: // TRestrictedHeadersFsm
                        request.Headers().erase(it++);
                        continue;
                    case 2: // ServiceHeader
                        if (ServiceName_) {
                            if (!metService) {
                                for (auto& hdrValue : header.second) {
                                    auto temp = TStringStorage(ServiceName_);
                                    hdrValue.Swap(temp);
                                }
                                metService = true;
                            } else {
                                request.Headers().erase(it++);
                                continue;
                            }
                        }
                        break;
                }
            }
            ++it;
        }

        if (!metService && ServiceName_) {
            request.Headers().Add(ServiceNameHeader_, ServiceName_);
        }
        return request;
    }

private:
    THolder<IModule> UaasModule_;
    // Classifier for ConstructRequest(), built in the constructor.
    THolder<TFsm> UberFsm_;
    // Owns the matcher compiled from "exp_headers", if that option is given.
    THolder<TFsm> ExpHeaders_;
    // Matcher for uaas response headers: ExpHeaders_ or the common default set.
    const TFsm* ExpHeadersFsm_{ nullptr };

    TString KillSwitchFile_;

    TString ServiceNameHeaderToBackend_;
    TString ServiceNameHeader_;
    TString ServiceName_;

    bool MayTrust_ = false;
    bool ProcessingTimeHeader_ = false;

    // Maximum number of header bytes the refinement may add before the
    // "limited_headers" counters are bumped; 0 disables the check.
    size_t ExtraHeadersSizeLimit_ = 8192;
    TSharedCounter ServiceCounter_;
    TSharedCounter TotalCounter_;
};

IModuleHandle* NModExpGetter::Handle() {
    return TModule::Handle();
}
