#include "module.h"
#include "enums.h"

#include <balancer/modules/rewrite/rewrite.cfgproto.pb.h>

#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/regexp/regexp_re2.h>
#include <balancer/kernel/custom_io/stream.h>

#include <library/cpp/uri/uri.h>

using namespace NSrvKernel;
using namespace NModRewrite;

// "rewrite" module: applies the configured regexp rewrite actions to a copy of
// the incoming request (request-line parts and/or named header values) and then
// runs its single submodule on that rewritten copy.
MODULE_BASE(rewrite, TModuleWithSubModule) {
private:
    // One configured rewrite rule: the action config plus its compiled regexp.
    // Actions that set `header_name` rewrite that header's values; all other
    // actions rewrite the request-line part selected by `split`.
    struct TAction {
        // Validates `config` and compiles its regexp.
        // Throws TConfigParseError when the regexp is empty, no rewrite pattern is
        // set, or (for header actions) `split` is also set or the header name is
        // malformed or restricted.
        explicit TAction(const TActionConfig& config)
            : Config(config)
        {
            Y_ENSURE_EX(Config.regexp(),
                TConfigParseError() << "nonempty regexp must be set");

            Regexp_ = MakeHolder<TRegexp>(
                Config.regexp(),
                TRegexp::TOpts()
                    .SetCaseInsensitive(Config.case_insensitive())
                    .SetLiteral(Config.literal())
            );

            Y_ENSURE_EX(Config.has_rewrite(),
                TConfigParseError() << "no rewrite pattern configured");

            if (Config.has_header_name()) {
                Y_ENSURE_EX(!Config.has_split(),
                    TConfigParseError() << "part splitting is only available for request line");
                Y_ENSURE_EX(CheckHeaderName(Config.header_name()),
                    TConfigParseError() << "bad \"header_name\" in \"rewrite\": " << Config.header_name().Quote());
                Y_ENSURE_EX(CheckRestrictedHeaderName(Config.header_name()),
                            TConfigParseError() << "\"header_name\" value " << Config.header_name().Quote()
                            << " contains one of the restricted headers " << RestrictedHeadersListString());
            }
        }

        // Returns `value` rewritten by this action's regexp and `rewrite` pattern.
        // '%' is passed to TRegexp::Rewrite as the rewrite escape character, the
        // config's `global` flag is forwarded, and `spec` resolves custom keys.
        TString Rewrite(TStringBuf value, const TRegexp::IRewriteSpec& spec) const noexcept {
            TString tmp;
            Regexp_->Rewrite(Config.rewrite(), '%', value, &tmp, Config.global(), spec);
            return tmp;
        }

    public:
        TActionConfig Config;
        THolder<TRegexp> Regexp_;
    };

public:
    // Parses the module config, loads the submodule, and splits the configured
    // actions into request-line actions and header actions.
    // Throws TConfigParseError if no submodule or no actions are configured,
    // or if any single action is invalid (see TAction).
    TModule(const TModuleParams& moduleParams)
        : TModuleBase(moduleParams)
    {
        Config_ = ParseProtoConfig<TModuleConfig>(
            // Presumably invoked for keys not described by TModuleConfig; each one
            // is loaded as the submodule, and Reset means the last loaded one wins.
            [&](const TString& key, NConfig::IConfig::IValue* value) {
                Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            }
        );

        Y_ENSURE_EX(Submodule_,
            TConfigParseError() << "no submodule configured");
        Y_ENSURE_EX(Config_.actions(),
            TConfigParseError() << "empty action list");

        for (const auto& action : Config_.actions()) {
            if (action.has_header_name()) {
                HeadersActions_.emplace_back(action);
            } else {
                RequestActions_.emplace_back(action);
            }
        }

        // Group header actions by header name, compared case-insensitively.
        // StableSort keeps the configured order of actions that target the same header.
        StableSort(HeadersActions_, [](const TAction& l, const TAction& r) {
            return AsciiCompareIgnoreCase(l.Config.header_name(), r.Config.header_name()) < 0;
        });
    }

private:
    // Resolves the custom rewrite keys "{url}", "{host}" and "{scheme}" from the
    // ORIGINAL request held by the connection descriptor. The values therefore
    // do not reflect rewrites already applied to the copy. Any other key is
    // substituted by itself.
    class TRewriteSpec final : public TRegexp::IRewriteSpec {
    public:
        explicit TRewriteSpec(const TConnDescr& descr) noexcept
            : Descr_(descr)
        {}

    private:
        void Custom(TStringBuf key, TString& value) const noexcept override {
            value.clear();

            if (key == "{url}"sv) {
                value = Descr_.Request->RequestLine().GetURL();
            } else if (key == "{host}"sv) {
                value = Descr_.Request->Headers().GetFirstValue("host");
            } else if (key == "{scheme}"sv) {
                if (Descr_.Properties->UserConnIsSsl) {
                    value = TStringBuf("https");
                } else {
                    value = TStringBuf("http");
                }
            } else {
                value = key;
            }
        }

    private:
        const TConnDescr& Descr_;
    };

private:
    // Rewrites a copy of the request and runs the submodule on it. Request-line
    // actions run first, in config order, followed by header actions.
    // Returns the SetURL error, or an HTTP 400 parse error when the path cannot
    // be parsed for NormalizedPath. In both cases it first records a summary in
    // the extra access log.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        TRequest newRequest = *descr.Request;
        TRewriteSpec rewriteSpec{ descr };

        for (const TAction& action: RequestActions_) {
            switch (action.Config.split()) {
                case ERequestPart::Url: {
                    if (TError error = newRequest.RequestLine().SetURL(
                        action.Rewrite(newRequest.RequestLine().GetURL(), rewriteSpec)
                    )) {
                        descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "url rewrite error");
                        return error;
                    }
                } break;
                case ERequestPart::Path:
                    newRequest.RequestLine().Path = TStringStorage(
                        action.Rewrite(newRequest.RequestLine().Path.AsStringBuf(), rewriteSpec));
                    break;
                case ERequestPart::Cgi:
                    newRequest.RequestLine().CGI = TStringStorage(
                        action.Rewrite(newRequest.RequestLine().CGI.AsStringBuf(), rewriteSpec));
                    break;
                case ERequestPart::NormalizedPath: {
                    // Braces keep these non-trivial locals scoped to this case. Without
                    // them, any label added after it would jump past their initialization.
                    NUri::TUri uri;
                    const NUri::TState::EParsed result = uri.Parse(newRequest.RequestLine().Path.AsStringBuf());
                    if (result != NUri::TState::ParsedOK) {
                        descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "uri parse error");
                        return Y_MAKE_ERROR(THttpParseError(HTTP_BAD_REQUEST) << "failed to parse path");
                    }
                    // Declared before `update` so it outlives it.
                    // NOTE(review): TUriUpdate may keep a view of the value passed to Set()
                    // until it is destroyed.
                    const TString rewrite = action.Rewrite(uri.GetField(NUri::TField::FieldPath), rewriteSpec);
                    {
                        // Finish the update before reading the URI back out.
                        NUri::TUriUpdate update(uri);
                        update.Set(NUri::TField::FieldPath, rewrite);
                    }
                    TString path;
                    uri.Print(path);
                    newRequest.RequestLine().Path = TStringStorage(path);
                } break;
            }
        }

        // Rewrite every value of each configured header that is present.
        auto& headers = newRequest.Headers();
        for (auto& action : HeadersActions_) {
            if (auto it = headers.FindValues(action.Config.header_name()); it != headers.end()) {
                for (auto& val : it->second) {
                    val = TStringStorage(
                        action.Rewrite(val.AsStringBuf(), rewriteSpec));
                }
            }
        }

        TConnDescr newDescr = descr.Copy(&newRequest);
        return Submodule_->Run(newDescr);
    }

private:
    TModuleConfig Config_;

    // Actions without `header_name`, in config order.
    TVector<TAction> RequestActions_;
    // Actions with `header_name`, stably sorted by header name (case-insensitive).
    TVector<TAction> HeadersActions_;
};

// Public entry point of the rewrite module: exposes the module's handle,
// obtained from the TModule class generated by MODULE_BASE above.
IModuleHandle* NModRewrite::Handle() {
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
