#include "handler.h"
#include "module.h"

#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/module/module_requirements.h>
#include <balancer/kernel/custom_io/stream.h>
#include <balancer/kernel/net/sockops.h>

#include <util/string/cast.h>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NModHeaders;

namespace NModHeaders {

namespace {

/// Tries to interpret one config entry as a header rule and, if it is one,
/// registers it with `handler`.
///
/// Recognized keys:
///   - "delete": the value is passed to THeadersHandler::Delete;
///   - "copy", "copy_weak";
///   - "create", "append", "create_weak", "append_weak";
///   - "create_func", "append_func", "create_func_weak", "append_func_weak";
///   - "create_from_file", "create_from_file_weak".
/// Every key except "delete" expects a sub-config of `header name -> value`.
///
/// @param handler       rule set to extend.
/// @param delimiter     delimiter stored in every created header rule.
/// @param name          module name, used in warnings.
/// @param key           config key being consumed.
/// @param value         value of `key`.
/// @param requirements  optional; forwarded only for the *_func keys, where it
///                      gets P0f = true when the p0f function is requested.
/// @return true if `key` was a header rule and has been consumed.
/// @throws TConfigParseError on empty/invalid/restricted header names, on
///         empty or invalid values, and on a "delete" regexp rejected by the handler.
bool ParseHandlerRule(THeadersHandler& handler, const TString& delimiter, const TString& name,
                      const TString& key, NConfig::IConfig::IValue* value,
                      TModuleRequirements* requirements) {
    // Walks a `header name -> value` sub-config, validates each header name and
    // hands the pair to the selected action. All work happens in the constructor.
    class TParser: public IConfig::IFunc {
    public:
        using TAction = void (TParser::*)(const TString&, const TString&);

        TParser(THeadersHandler& handler, const TString& delimiter, const TString& name,
                TAction action, bool append, bool weak, NConfig::IConfig* config,
                TModuleRequirements* requirements)
            : Handler_(handler)
            , Delimiter_(delimiter)
            , Name_(name)
            , Action_(action)
            , Append_(append)
            , Weak_(weak)
            , Requirements_(requirements)
        {
            config->ForEach(this);
        }

        // Action for "copy"/"copy_weak".
        void CopyValue(const TString& key, const TString& value) {
            Y_ENSURE_EX(!value.empty(), TConfigParseError() << "empty header value");
            Y_ENSURE_EX(CheckHeaderValue(value),
                        TConfigParseError() << "bad header value: " << "\"" << value << "\"");
            Handler_.CopyValue(key, value, Weak_);
        }

        // Action for the literal-value "create"/"append" families.
        void ParseValue(const TString& key, const TString& value) {
            Y_ENSURE_EX(!value.empty(), TConfigParseError() << "empty header value");
            Y_ENSURE_EX(CheckHeaderValue(value),
                        TConfigParseError() << "bad header value: " << value.Quote());

            TCreateHeader header{key, value, {}, Delimiter_};
            Register(header);
        }

        // Action for the *_func keys. `tcp_info` is skipped with a one-time warning
        // when CanGetTcpInfo() reports it unavailable.
        void ParseFunc(const TString& key, const TString& value) {
            if (value == ToString(EHeaderFunc::TcpInfo) && !CanGetTcpInfo()) {
                PrintOnce(TStringBuilder() << "WARNING in " << Name_
                                           << ": \"tcp_info\" is unavailable, ignoring it");
                return;
            }

            TCreateHeader header = ParseHeaderFunc(key, value, Delimiter_);
            Register(header);

            const bool needsP0f = Requirements_ && value == ToString(EHeaderFunc::P0f);
            if (needsP0f) {
                Requirements_->P0f = true;
            }
        }

        // Action for "create_from_file"/"create_from_file_weak": `value` is a file name.
        void ParseFile(const TString& key, const TString& value) {
            Y_ENSURE_EX(!value.empty(), TConfigParseError() << "empty file name");
            // TODO: restrict file path?

            TCreateHeader header{key, TValueFromFile{value}, {}, Delimiter_};
            Handler_.AddFile(TValueFromFile{value});
            Register(header);
        }

    private:
        // Adds `header` in append or create mode, honoring the weak flag.
        void Register(TCreateHeader& header) {
            if (Append_) {
                Handler_.Append(header, Weak_);
            } else {
                Handler_.Create(header, Weak_);
            }
        }

        START_PARSE {
            // TODO(velavokr): simplify
            Y_ENSURE_EX(!key.empty(), TConfigParseError() << "empty header name");
            Y_ENSURE_EX(CheckHeaderName(key), TConfigParseError() << "bad header name: \"" << key << '"');
            Y_ENSURE_EX(CheckRestrictedHeaderName(key), TConfigParseError{} << "header " << key.Quote()
                << " contains one of the restricted headers " << RestrictedHeadersListString() << "\n");

            (this->*Action_)(key, value->AsString());
            return;
        } END_PARSE

    private:
        THeadersHandler& Handler_;
        const TString& Delimiter_;
        const TString& Name_;
        const TAction Action_;
        const bool Append_;
        const bool Weak_;
        TModuleRequirements* Requirements_ = nullptr;
    };

    // "delete" takes a plain string rather than a sub-config.
    if (key == "delete") {
        Y_ENSURE_EX(handler.Delete(value->AsString()),
                    TConfigParseError() << "delete regexp " << value->AsString().Quote()
                        << " contains one of the restricted headers " << RestrictedHeadersListString() << "\n");
        return true;
    }

    // Describes how one rule key maps onto a TParser configuration.
    struct TRuleKind {
        const char* Key;
        TParser::TAction Action;
        bool Append;
        bool Weak;
        bool NeedsRequirements;
    };

    static constexpr TRuleKind kinds[] = {
        {"copy",                  &TParser::CopyValue,  false, false, false},
        {"copy_weak",             &TParser::CopyValue,  false, true,  false},
        {"create",                &TParser::ParseValue, false, false, false},
        {"append",                &TParser::ParseValue, true,  false, false},
        {"create_weak",           &TParser::ParseValue, false, true,  false},
        {"append_weak",           &TParser::ParseValue, true,  true,  false},
        {"create_func",           &TParser::ParseFunc,  false, false, true},
        {"append_func",           &TParser::ParseFunc,  true,  false, true},
        {"create_func_weak",      &TParser::ParseFunc,  false, true,  true},
        {"append_func_weak",      &TParser::ParseFunc,  true,  true,  true},
        {"create_from_file",      &TParser::ParseFile,  false, false, false},
        {"create_from_file_weak", &TParser::ParseFile,  false, true,  false},
    };

    for (const TRuleKind& kind : kinds) {
        if (key == kind.Key) {
            TParser{handler, delimiter, name, kind.Action, kind.Append, kind.Weak,
                    value->AsSubConfig(), kind.NeedsRequirements ? requirements : nullptr};
            return true;
        }
    }

    return false;
}

}

class TTlsImpl: public NConfig::IConfig::IFunc {
public:
    explicit TTlsImpl(const TString& name)
        : Name_(name)
    {}

    THeadersHandler* Handler() noexcept {
        UpdateHandler();
        return Handler_.Get();
    }

    TSharedFileReReader RulesFileReader;
    TSharedFileExistsChecker DisableAltSvcFileChecker;
    bool Override_ = false;
    THashMap<TString, TSharedFileReReader> HeaderFiles;
private:
    void UpdateHandler() noexcept {
        const auto& data = RulesFileReader.Data();
        if (data.Id() != RulesData_.Id()) {
            RulesData_ = data;
            UpdateHandler(RulesData_.Data());
        }
    }

    void UpdateHandler(const TString& data) noexcept {
        try {
            TStringInput cfgin(data);
            THolder<IConfig> config = NConfig::ConfigParser(cfgin);
            // Order of params is not determined, so preparse 'delimiter' firstly
            NSrvKernel::ParseMap(config.Get(), [this](const auto &key, auto *value) {
                ON_KEY("delimiter", Delimiter_) {
                    return;
                }
            });
            Handler_ = THeadersHandler{};
            Override_ = false;
            config->ForEach(this);
        } catch(...) {
            Override_ = false;
            Handler_.Clear();
        }
    }

    START_PARSE {
        // Skip already parsed 'delimiter'
        if (key == "delimiter") {
            return;
        }

        ON_KEY("override", Override_) {
            return;
        }

        if (ParseHandlerRule(Handler_.GetRef(), Delimiter_, Name_, key, value, nullptr)) {
            return;
        }
    } END_PARSE

    TSharedFileReReader::TData RulesData_;
    TString Delimiter_;
    TMaybe<THeadersHandler> Handler_;
    const TString& Name_;
};

/// Per-worker storage of the `headers` module.
Y_TLS(headers) {
    /// Takes ownership of the state built by TImpl::InitTls.
    explicit TTls(THolder<TTlsImpl> impl)
        : Impl_{std::move(impl)}
    {
    }

    // Dynamic rules and header-file readers used by DoRun on every request.
    THolder<TTlsImpl> Impl_;
};

/// Per-worker storage of the `response_headers` module.
Y_TLS(response_headers) {
    /// Takes ownership of the state built by TImpl::InitTls.
    explicit TTls(THolder<TTlsImpl> impl)
        : Impl_{std::move(impl)}
    {
    }

    // Dynamic rules and header-file readers used when the response head is sent.
    THolder<TTlsImpl> Impl_;
};

/// Configuration shared by the `headers` and `response_headers` modules. It holds:
///   - a static rule set parsed from the module config;
///   - an optional reloadable `rules_file`;
///   - the wrapped submodule.
class TImpl: public TModuleParams, public NConfig::IConfig::IFunc {
public:
    /// Parses the module config.
    /// @param moduleParams  module construction parameters.
    /// @param submodule     slot that receives the nested module (owned by the caller).
    /// @param name          module name, used in warnings.
    /// @throws TConfigParseError on invalid rules or when no submodule is configured.
    TImpl(const TModuleParams& moduleParams, THolder<IModule>& submodule, TString name)
        : TModuleParams(moduleParams)
        , Submodule_(submodule)
        , Name_(std::move(name))
    {
        // Config keys are not visited in a guaranteed order, and every rule needs
        // the delimiter, so it is read in a dedicated first pass.
        NSrvKernel::ParseMap(moduleParams.Config, [this](const auto &key, auto *value) {
            ON_KEY("delimiter", Delimiter_) {
                return;
            }
        });

        Config->ForEach(this);

        Y_ENSURE_EX(Submodule_, TConfigParseError() << "no submodule configured");
    }

    /// Builds the per-worker state. Sets up the rules-file reader (re-read
    /// interval 1s) when `rules_file` is configured, plus the file readers
    /// required by the static rules.
    THolder<TTlsImpl> InitTls(IWorkerCtl* process) {
        auto tls = MakeHolder<TTlsImpl>(Name_);
        if (!RulesFileName_.empty()) {
            tls->RulesFileReader = process->SharedFiles()->FileReReader(RulesFileName_, TDuration::Seconds(1));
        }
        Handler_.SetupFileReaders(tls->HeaderFiles, process);
        return tls;
    }

    /// Applies the dynamic (rules-file) rules first, then the static rules.
    /// The static rules are skipped when the rules file sets `override`.
    void ApplyHandlers(TTlsImpl& impl, NSrvKernel::THeaders& headers, const NSrvKernel::TConnDescr& descr) const noexcept {
        THeadersHandler* const dynamicRules = impl.Handler();
        if (dynamicRules) {
            dynamicRules->Apply(headers, descr, impl.HeaderFiles);
        }

        const bool skipStaticRules = dynamicRules && impl.Override_;
        if (!skipStaticRules) {
            Handler_.Apply(headers, descr, impl.HeaderFiles);
        }
    }
private:
    START_PARSE {
        // Already consumed by the first pass in the constructor.
        if (key == "delimiter") {
            return;
        }

        ON_KEY("rules_file", RulesFileName_) {
            return;
        }

        // TODO(smalukav): Remove MINOTAUR-2915
        ON_KEY("fix_http3_headers", FixHttp3Headers_) {
            return;
        }

        TModuleRequirements* const requirements = Control ? &Control->GetModuleRequirements() : nullptr;
        if (ParseHandlerRule(Handler_, Delimiter_, Name_, key, value, requirements)) {
            return;
        }

        // Any other key names the nested module; a later one replaces an earlier one.
        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    } END_PARSE

private:
    // Static rules from the module config.
    THeadersHandler Handler_;
    THolder<IModule>& Submodule_;
    TString Name_;
    TString Delimiter_;
    // Path of the dynamic rules file; empty when not configured.
    TString RulesFileName_;

public:
    // When set, the headers module re-adds a fixed list of request headers before applying rules.
    bool FixHttp3Headers_ = false;
};

/// `headers` module: rewrites request headers according to the configured
/// rules, then forwards the request to the submodule.
MODULE_WITH_TLS_BASE(headers, TModuleWithSubModule) {
    TModule(const TModuleParams& moduleParams)
        : TModuleBase(moduleParams)
        , Impl_(moduleParams, Submodule_, NAME)
    {}

    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        return MakeHolder<TTls>(Impl_.InitTls(process));
    }

    TError DoRun(const TConnDescr& descr, TTls& tls) const override {
        Y_ASSERT(descr.Request);
        auto& requestHeaders = descr.Request->Headers();

        // TODO(smalukav): Remove MINOTAUR-2915
        if (Impl_.FixHttp3Headers_) {
            const TStringBuf reordered[] = {
                "X-Forwarded-Proto", "X-Forwarded-For-Y", "X-Forwarded-For", "X-Yandex-IP", "X-Source-Port-Y",
                "X-Yandex-HTTP-Version", "X-Yandex-HTTPS", "X-HTTPS-Request", "X-Yandex-Family-Search"
            };
            // For each present header, take its values, delete the header and
            // add it back with those values.
            for (const TStringBuf& headerName : reordered) {
                auto values = requestHeaders.GetValuesMoveNonOwned(headerName);
                if (!values) {
                    continue;
                }
                requestHeaders.Delete(headerName);
                requestHeaders.Add(headerName, std::move(values));
            }
        }

        Impl_.ApplyHandlers(*tls.Impl_, requestHeaders, descr);
        return Submodule_->Run(descr);
    }
private:
    TImpl Impl_;
};

/// `response_headers` module: runs the submodule and rewrites the headers of
/// the response head it sends. Body chunks and trailers pass through unchanged.
MODULE_WITH_TLS_BASE(response_headers, TModuleWithSubModule) {
    TModule(const TModuleParams& moduleParams)
        : TModuleBase(moduleParams)
        , Impl_(moduleParams, Submodule_, NAME)
    {}

    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        return MakeHolder<TTls>(Impl_.InitTls(process));
    }

    TError DoRun(const TConnDescr& descr, TTls& tls) const override {
        // Response head: apply the rules, then forward downstream.
        auto onHead = [&](TResponse&& response, const bool forceClose, TInstant deadline) {
            Impl_.ApplyHandlers(*tls.Impl_, response.Headers(), descr);
            return descr.Output->SendHead(std::move(response), forceClose, deadline);
        };
        // Body chunks: forwarded as-is.
        auto onBody = [&](TChunkList lst, TInstant deadline) {
            return descr.Output->Send(std::move(lst), deadline);
        };
        // Trailers: forwarded as-is.
        auto onTrailers = [&](THeaders&& trailers, TInstant deadline) {
            return descr.Output->SendTrailers(std::move(trailers), deadline);
        };

        auto output = MakeHttpOutput(std::move(onHead), std::move(onBody), std::move(onTrailers));
        return Submodule_->Run(descr.CopyOut(output));
    }
private:
    TImpl Impl_;
};

}

/// Returns the module handle of the `headers` module, which rewrites request
/// headers before passing the request to its submodule.
IModuleHandle* NModHeaders::Handle() {
    return Nheaders::TModule::Handle();
}

/// Returns the module handle of the `response_headers` module, which applies
/// the same rule machinery to the headers of the response sent by its submodule.
IModuleHandle* NModResponseHeaders::Handle() {
    return Nresponse_headers::TModule::Handle();
}
