#include "module.h"

#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/regexp/regexp_pire.h>
#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/matcher/matcher.h>

#include <library/cpp/regex/pire/pire.h>

#include <util/generic/ptr.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>

#include <utility>


using namespace NConfig;
using namespace NSrvKernel;
using namespace NModResponseHeadersIf;


/**
 * response_headers_if: conditionally edits the headers of the response head
 * produced by the wrapped submodule before forwarding it downstream.
 *
 * Trigger condition (exactly one must be configured):
 *   - "if_has_header": case-insensitive regexp over response header names,
 *     checked via THeaders::GetFirstValue;
 *   - "matcher": a section describing a response matcher.
 *
 * Actions (at least one of "create_header" / "delete_header" is required),
 * applied in this order when the condition holds:
 *   1. "erase_if_has_header": delete the headers matched by "if_has_header";
 *   2. "delete_header": delete headers whose names match this regexp;
 *   3. "create_header": append every configured name/value pair.
 *
 * Every other config key is loaded as the submodule; one is mandatory.
 */
MODULE_BASE(response_headers_if, TModuleWithSubModule) {
public:
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        // Exactly one trigger condition must be configured.
        if (!HeaderNameFsm_ && !Matcher_) {
            ythrow TConfigParseError() << "\"if_has_header\" or \"matcher\" parameter is required";
        }

        if (HeaderNameFsm_ && Matcher_) {
            ythrow TConfigParseError() << "only one of \"if_has_header\" and \"matcher\" should be defined";
        }

        // A trigger without any create/delete action would be a no-op.
        if (Headers_.empty() && !DeleteFsm_) {
            ythrow TConfigParseError() << "\"create_header\" or \"delete_header\" parameter is required";
        }

        if (!Submodule_) {
            ythrow TConfigParseError() << "no submodule configured";
        }
    }

private:
    // Validates a "create_header" entry and remembers it for Process().
    // Throws TConfigParseError if the name is not a valid header name or is
    // one of the restricted headers this module must not touch.
    void AddHeader(const TString& name, const TString& value) {
        if (!CheckHeaderName(name)) {
            ythrow TConfigParseError() << "bad header name: " << name.Quote();
        }
        if (!CheckRestrictedHeaderName(name)) {
            ythrow TConfigParseError() << "header " << name.Quote()
                << " contains one of the restricted headers " << RestrictedHeadersListString() << "\n";
        }

        // Construct the pair in place instead of building a temporary with make_pair.
        Headers_.emplace_back(name, value);
    }

    // Applies the configured header edits to `response` if the trigger
    // condition holds; otherwise leaves the response untouched.
    // Deletions run before additions, so headers added from "create_header"
    // are never removed by "erase_if_has_header" / "delete_header".
    void Process(TResponse& response) const noexcept {
        THeaders& headers = response.Headers();

        // NOTE(review): the header trigger relies on GetFirstValue's truthiness;
        // confirm whether a matching header with an empty value counts as present.
        const bool triggered =
            (HeaderNameFsm_ && headers.GetFirstValue(*HeaderNameFsm_))
            || (Matcher_ && Matcher_->Match(response));
        if (!triggered) {
            return;
        }

        if (HeaderNameFsm_ && EraseMatchingHeader_) {
            headers.Delete(*HeaderNameFsm_);
        }

        if (DeleteFsm_) {
            headers.Delete(*DeleteFsm_);
        }

        for (const auto& header : Headers_) {
            headers.Add(header.first, header.second);
        }
    }

    START_PARSE {
        TString arg;
        ON_KEY("if_has_header", arg) {
            HeaderNameFsm_.Reset(new TFsm(arg, TFsm::TOptions().SetCaseInsensitive(true)));
            return;
        }

        ON_KEY("erase_if_has_header", EraseMatchingHeader_) {
            return;
        }

        if (key == "matcher") {
            ParseMap(value->AsSubConfig(), [this](const auto& key, auto* value) {
                auto matcher = ConstructResponseMatcher(nullptr, key, value->AsSubConfig());
                if (!matcher) {
                    ythrow TConfigParseError() << "cannot configure matcher '" << key << "'";
                }
                // NOTE(review): presumably called once per entry, so a later
                // matcher entry silently replaces an earlier one.
                Matcher_ = std::move(matcher);
            });
            return;
        }

        if (key == "create_header") {
            ParseMap(value->AsSubConfig(), [this](const auto& key, auto* value) {
                AddHeader(key, value->AsString());
            });
            return;
        }

        if (key == "delete_header") {
            // SetSurround(false): presumably the pattern must match the whole
            // header name rather than a substring — TODO confirm TFsm semantics.
            DeleteFsm_.Reset(new TFsm(value->AsString(), TFsm::TOptions().SetCaseInsensitive(true).SetSurround(false)));
            if (!CheckRestrictedHeaderRegexp(*DeleteFsm_)) {
                ythrow TConfigParseError() << "delete regexp " << value->AsString().Quote()
                    << " contains one of the restricted headers " << RestrictedHeadersListString() << "\n";
            }
            return;
        }

        {
            // Any other key is treated as the wrapped submodule.
            Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            return;
        }
    } END_PARSE

    // Runs the submodule with an output wrapper: the response head goes
    // through Process() before being sent; body chunks and trailers are
    // forwarded unchanged. The reference captures are only used while
    // Submodule_->Run() executes.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        auto output = MakeHttpOutput([&](TResponse&& response, const bool forceClose, TInstant deadline) {
            Process(response);
            return descr.Output->SendHead(std::move(response), forceClose, deadline);
        }, [&](TChunkList lst, TInstant deadline) {
            return descr.Output->Send(std::move(lst), deadline);
        }, [&](THeaders&& trailers, TInstant deadline) {
            return descr.Output->SendTrailers(std::move(trailers), deadline);
        });
        return Submodule_->Run(descr.CopyOut(output));
    }

private:
    // Case-insensitive regexp over header names ("if_has_header"); may be null.
    THolder<TFsm> HeaderNameFsm_;
    // Response matcher ("matcher"); may be null. Mutually exclusive with HeaderNameFsm_.
    THolder<IResponseMatcher> Matcher_;
    // Regexp of header names to delete ("delete_header"); may be null.
    THolder<TFsm> DeleteFsm_;
    // Name/value pairs appended on trigger ("create_header"), in config order.
    TVector<std::pair<TString, TString>> Headers_;
    // Whether headers matched by HeaderNameFsm_ are deleted on trigger.
    bool EraseMatchingHeader_{ false };
};

// Returns the module handle of TModule, the class generated by the
// MODULE_BASE(response_headers_if, ...) definition in this file.
IModuleHandle* NModResponseHeadersIf::Handle() {
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
