#include "module.h"
#include "handler.h"

#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/helpers/cast.h>
#include <balancer/kernel/helpers/common_parsers.h>
#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/module.h>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NHeadersForwarder;

// headers_forwarder: for every configured action, forwards the values of a
// request header into a response header of the submodule's reply. Optionally
// the source header is erased from the request passed to the submodule
// (erase_from_request) and/or the target header is erased from the
// submodule's response (erase_from_response).
MODULE_BASE(headers_forwarder, TModuleWithSubModule) {
private:
    // One entry of the `actions` map.
    //
    // Keys:
    //   request_header      - required; name of the request header to read.
    //   response_header     - required; name of the response header to fill.
    //   erase_from_request  - bool, default false.
    //   erase_from_response - bool, default false.
    //   weak                - bool, default false; passed through to
    //                         TSimpleHandler::Create (NOTE(review): presumably
    //                         "do not overwrite an existing header" - confirm
    //                         in handler.h).
    // erase_from_response and weak cannot both be set.
    class TAction: public IConfig::IFunc {
    public:
        explicit TAction(IConfig* config)
        {
            config->ForEach(this);

            Y_ENSURE_EX(RequestHeader, TConfigParseError() << "no request header configured");
            Y_ENSURE_EX(ResponseHeader, TConfigParseError() << "no response header configured");
            // Lowercase copies serve as lookup/handler keys; the original
            // spelling of ResponseHeader is also handed to TSimpleHandler::Create.
            RequestHeaderLowerCase = to_lower(*RequestHeader);
            ResponseHeaderLowerCase = to_lower(*ResponseHeader);
            Y_ENSURE_EX(!(EraseFromResponse && Weak),
                        TConfigParseError() << "erase_from_response and weak are mutually excluding");
        }

        START_PARSE
            ON_KEY("request_header", RequestHeader) {
                return;
            }

            ON_KEY("response_header", ResponseHeader) {
                return;
            }

            ON_KEY("erase_from_request", EraseFromRequest) {
                return;
            }

            ON_KEY("erase_from_response", EraseFromResponse) {
                return;
            }

            ON_KEY("weak", Weak) {
                return;
            }
        END_PARSE

        TMaybe<TString> RequestHeader;
        TMaybe<TString> ResponseHeader;
        TString RequestHeaderLowerCase;
        TString ResponseHeaderLowerCase;
        bool EraseFromRequest = false;
        bool EraseFromResponse = false;
        bool Weak = false;
    };

public:
    explicit TModule(const TModuleParams& moduleParams)
        : TModuleBase(moduleParams)
    {
        Config->ForEach(this);

        Y_ENSURE_EX(Submodule_, TConfigParseError() << "no submodule configured");
    }

private:
    // Module config: the `actions` key holds the action map; any other key is
    // loaded as the submodule.
    START_PARSE
        if (key == "actions") {
            ParseMap(value->AsSubConfig(), [&](const TString&, IConfig::IValue* val) {
                Actions_.emplace_back(val->AsSubConfig());
            });

            Y_ENSURE_EX(Actions_, TConfigParseError() << "empty action list");

            // Index actions by lowercase request header. Actions_ is a deque:
            // emplace_back never invalidates references to existing elements,
            // so the raw pointers stored here remain valid.
            for (const TAction& action: Actions_) {
                // Named distinctly so it does not shadow the parser's `key`.
                const TString& requestHeader = action.RequestHeaderLowerCase;
                ActionByRequestHeader_[requestHeader].push_back(&action);
                if (action.EraseFromRequest) {
                    PreHandler_.Delete(requestHeader);
                }
            }
            return;
        }

        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    END_PARSE

    // Builds the per-request response handler:
    //  - a Delete() for every action with erase_from_response;
    //  - for each request header whose lowercase name matches an action's
    //    request header, one Create() per (header value, matching action).
    // Reads headers from descr.Request as given (the caller passes the
    // original, unmodified request).
    TSimpleHandler CreatePostHandler(const TConnDescr& descr) const noexcept {
        TSimpleHandler postHandler;

        for (const TAction& action : Actions_) {
            if (action.EraseFromResponse) {
                postHandler.Delete(action.ResponseHeaderLowerCase);
            }
        }

        for (const auto& header : descr.Request->Headers()) {
            TString headerLowerCase = TString(header.first.AsStringBuf());
            headerLowerCase.to_lower();

            if (auto it = ActionByRequestHeader_.find(headerLowerCase); it != ActionByRequestHeader_.end()) {
                for (const auto& headerValue : header.second) {
                    // Convert once per value, reuse for every matching action.
                    TString value = ToString(headerValue);
                    for (const auto& action: it->second) {
                        postHandler.Create(action->ResponseHeaderLowerCase, *action->ResponseHeader, value, action->Weak);
                    }
                }
            }
        }

        return postHandler;
    }

    // With a request: run the submodule on a copy of the request with
    // PreHandler_ applied, and apply the per-request post handler to the
    // response head before forwarding it. Without a request: pass through.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        if (descr.Request) {
            auto request = *descr.Request;
            PreHandler_.ApplyRequest(request);

            // Built from descr (the unmodified request), so values of headers
            // removed by erase_from_request are still forwarded.
            auto postHandler = CreatePostHandler(descr);
            // The lambdas capture locals by reference; assumes the submodule
            // does not retain `output` after Run() returns - confirm.
            auto output = MakeHttpOutput([&](TResponse&& response, const bool forceClose, TInstant deadline) {
                postHandler.ApplyResponse(response);
                return descr.Output->SendHead(std::move(response), forceClose, deadline);
            }, [&](TChunkList lst, TInstant deadline) {
                return descr.Output->Send(std::move(lst), deadline);
            }, [&](THeaders&& trailers, TInstant deadline) {
                return descr.Output->SendTrailers(std::move(trailers), deadline);
            });

            auto newDescr = descr.CopyOut(output);
            newDescr.Request = &request;
            Y_PROPAGATE_ERROR(Submodule_->Run(newDescr));
        } else {
            Y_PROPAGATE_ERROR(Submodule_->Run(descr));
        }
        return {};
    }

private:
    TDeque<TAction> Actions_;
    THashMap<TString, TVector<const TAction*> > ActionByRequestHeader_; // non-owning
    TSimpleHandler PreHandler_;
};

IModuleHandle* NModHeadersForwarder::Handle() {
    // Entry point used to register this module: hands out the handle of the
    // TModule class declared via MODULE_BASE(headers_forwarder, ...).
    IModuleHandle* const moduleHandle = Nheaders_forwarder::TModule::Handle();
    return moduleHandle;
}
