#include <balancer/serval/core/config.h>
#include <balancer/serval/mod/log_headers/log_headers.ev.pb.h>

namespace {
    // Name of this action: the YAML key whose mapping holds its options, and the
    // identifier passed to SV_DEFINE_ACTION at the bottom of this file.
    // (Namespace-scope const inside an anonymous namespace already has internal linkage.)
    const TStringBuf NAME = "log_headers";
}

// Builds the `log_headers` action from its YAML config.
//
// Expected config (both keys optional; any other key under `log_headers` is ignored):
//   log_headers:
//     request_headers: [name, ...]    # each must be a YAML sequence
//     response_headers: [name, ...]   # each must be a YAML sequence
//
// When the returned action runs on a request:
//   * if the request has no head, it returns false without logging anything
//     (NOTE(review): presumably this signals "not handled" to the action chain — confirm);
//   * every head entry matching a configured request header name (all values, via
//     equal_range) is pushed to the stream log as NSv::NEv::TRequestHeader;
//   * if response headers are configured, `req` is replaced by a proxy stream that,
//     when the response head is written, pushes every matching entry as
//     NSv::NEv::TResponseHeader and then forwards the head unchanged.
//
// NOTE(review): header lookup uses NSv::THead's own key comparison; whether it is
// case-insensitive is not visible here — confirm config names match the stored key form.
static NSv::TAction LogHeaders(const YAML::Node& node, NSv::TAuxData&) {
    using THeaderList = TVector<TString>;
    // Shared, immutable list: the action closure and every stream it wraps co-own it,
    // so a stream never refers into a closure that may be copied or destroyed first.
    using THeaderListPtr = std::shared_ptr<const THeaderList>;

    // Proxy stream that logs the configured response headers before forwarding the head.
    class TLogHeadersStream : public NSv::TStreamProxy {
    public:
        TLogHeadersStream(NSv::IStreamPtr s, THeaderListPtr responseHeaders)
            : NSv::TStreamProxy(std::move(s))
            , ResponseHeaders_(std::move(responseHeaders))
        {}

        bool WriteHead(NSv::THead& head) noexcept override {
            for (const auto& header : *ResponseHeaders_) {
                auto range = head.equal_range(header);
                for (auto ptr = range.first; ptr != range.second; ++ptr) {
                    Log().Push<NSv::NEv::TResponseHeader>(TString(ptr->first), TString(ptr->second));
                }
            }
            return NSv::TStreamProxy::WriteHead(head);
        }

    private:
        // Owned (shared), not a reference: this stream outlives the action invocation.
        THeaderListPtr ResponseHeaders_;
    };

    // Appends each scalar of an already-validated YAML sequence to `out`.
    auto appendSequence = [](const YAML::Node& seq, THeaderList& out) {
        for (const auto& header : seq) {
            out.push_back(header.as<TString>());
        }
    };

    THeaderList requestHeaders;
    THeaderList responseHeaders;
    // Look the section up once instead of on every loop iteration.
    const YAML::Node section = node[NAME];
    for (auto it = section.begin(); it != section.end(); ++it) {
        const TStringBuf key = it->first.as<TStringBuf>();
        if (key == "request_headers") {
            CHECK_NODE(*it, it->second.IsSequence(), "request_headers must be a sequence");
            appendSequence(it->second, requestHeaders);
        } else if (key == "response_headers") {
            CHECK_NODE(*it, it->second.IsSequence(), "response_headers must be a sequence");
            appendSequence(it->second, responseHeaders);
        }
    }

    THeaderListPtr sharedResponseHeaders = std::make_shared<const THeaderList>(std::move(responseHeaders));

    return [
        requestHeaders = std::move(requestHeaders),
        responseHeaders = std::move(sharedResponseHeaders)
    ] (NSv::IStreamPtr& req) {
        auto* head = req->Head();
        if (!head) {
            return false;
        }

        for (const auto& header : requestHeaders) {
            auto range = head->equal_range(header);
            for (auto ptr = range.first; ptr != range.second; ++ptr) {
                req->Log().Push<NSv::NEv::TRequestHeader>(TString(ptr->first), TString(ptr->second));
            }
        }

        // With nothing to log on the response side the proxy would only forward
        // WriteHead, so skip the per-request allocation and extra indirection.
        if (!responseHeaders->empty()) {
            req = std::make_shared<TLogHeadersStream>(std::move(req), responseHeaders);
        }
        return true;
    };
}

// Binds the action name NAME ("log_headers") to the LogHeaders factory above.
// NOTE(review): presumably this registers the factory with the config loader's
// action registry at static-initialization time — confirm against SV_DEFINE_ACTION.
SV_DEFINE_ACTION(NAME, LogHeaders);
