#include "module.h"
#include "redirects.h"
#include "url_parts.h"

#include <balancer/modules/redirects/redirects.cfgproto.pb.h>

#include <balancer/kernel/module/module.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/http/parser/response_builder.h>

#include <util/generic/overloaded.h>

#include <util/generic/variant.h>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NModRedirects;

namespace NModRedirects {
    namespace {
        // Tells whether a config key path points at the submodule slot of a
        // forward action, i.e. has the exact shape `actions` / <index> / `forward`.
        bool IsForward(const NProtoConfig::TKeyStack& keys) {
            using namespace NProtoConfig;

            if (keys.size() != 3) {
                return false;
            }

            const auto* head = std::get_if<TField>(&keys[0]);
            const auto* leaf = std::get_if<TField>(&keys[2]);
            const bool indexed = std::holds_alternative<TIdx>(keys[1]);

            return head != nullptr && head->Name == "actions"
                && indexed
                && leaf != nullptr && leaf->Name == "forward";
        }
    }
}

// `redirects` module. For every request it asks Redirects_ (built from the
// configured `actions`) what to do with the request's Host header, path and
// CGI, then acts on the returned variant:
//   * std::monostate -> run the default submodule on the unchanged request;
//   * TRedirect      -> answer directly with a response carrying the
//                       redirect code and a `location` header;
//   * TForward       -> run that action's forward submodule on a copy of the
//                       request whose Host/path/CGI were rewritten.
MODULE(redirects) {
public:
    // Parses the module config, loads the default and per-action forward
    // submodules, validates every action and compiles the rule set.
    //
    // Throws TConfigParseError when:
    //   * a nested sub-config appears at an unexpected key path;
    //   * an action has an empty `src`;
    //   * a forward action has no loaded forward submodule;
    //   * a redirect action's code is outside 300..399;
    //   * an action is neither forward nor redirect;
    //   * no default submodule was configured.
    TModule(const TModuleParams& moduleParams)
        : TModuleBase(moduleParams)
    {
        // The callback receives the key path of each nested sub-config that
        // ParseProtoConfig hands back (presumably the submodule slots —
        // verify against ParseProtoConfig).
        Config_ = ParseProtoConfig<TModuleConfig>(
            [&](const NProtoConfig::TKeyStack& keys, const TString& key, NConfig::IConfig::IValue* value) {
                if (!keys) {
                    // Presumably an empty key path, i.e. a sub-config at the
                    // top level of this module: this is the default submodule.
                    Default_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
                } else if (IsForward(keys)){
                    // Key path `actions[<idx>].forward`: keep the loaded
                    // submodule under the action index so the validation
                    // loop below can attach it to that action.
                    using namespace NProtoConfig;
                    Forwards_.emplace(
                        std::get<TIdx>(keys[1]).Idx,
                        Loader->MustLoad(key, Copy(value->AsSubConfig())).Release()
                    );
                } else {
                    ythrow TConfigParseError() << "Unknown field at " << keys;
                }
            }
        );

        // Validate each action and register it with Redirects_. `forward` is
        // checked before `redirect`, so if the proto ever allows both to be
        // set, the action is treated as a forward.
        for (auto i : xrange(Config_.actions().size())) {
            auto&& act = Config_.actions()[i];

            Y_ENSURE_EX(act.src(),
                TConfigParseError() << "Action " << i << " has empty src");

            if (act.has_forward()) {
                // The submodule itself was loaded by the parse callback above.
                auto* fwd = Forwards_.FindPtr(i);
                Y_ENSURE_EX(fwd,
                    TConfigParseError() << "Action " << i << " has empty forward");
                Redirects_.AddForward(act.src(), act.forward(), **fwd);
            } else if (act.has_redirect()) {
                // Integer division: accepts codes 300..399 only.
                Y_ENSURE_EX(act.redirect().code() / 100 == 3,
                    TConfigParseError() << "Action " << i << " expected 3xx code, got " << act.redirect().code());
                Redirects_.AddRedirect(act.src(), act.redirect());
            } else {
                ythrow TConfigParseError() << "Action " << i << " must be either forward or redirect";
            }
        }

        // Finalise the rule set once every action has been added (see
        // redirects.h for what compilation builds).
        Redirects_.Compile();

        Y_ENSURE_EX(Default_,
            TConfigParseError() << "no submodule configured");
    }

private:
    // Routes one request according to Redirects_.Location(); see the class
    // comment for the three outcomes. Every branch appends to ExtraAccessLog.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        // First `host` header value. How an absent header is represented and
        // matched is left to GetFirstValue / TRedirects::Location — TODO confirm.
        TStringBuf host = descr.Request->Headers().GetFirstValue("host");

        auto action = Redirects_.Location(
            host,
            descr.Request->RequestLine().Path.AsStringBuf(),
            descr.Request->RequestLine().CGI.AsStringBuf()
        );

        return std::visit(TOverloaded{
            // No action selected: delegate to the default submodule.
            [&](std::monostate) {
                descr.ExtraAccessLog << " default";
                return Default_->Run(descr);
            },
            // Redirect: this module produces the response itself.
            [&](const TRedirect& redir) {
                descr.ExtraAccessLog << " " << redir.Code << " to " << redir.Location;
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "redirect to " + redir.Location);
                // Send an HTTP/1.1 head with the redirect code and `location`
                // header, end the response, then read and discard whatever is
                // left of the request input.
                Y_TRY(TError, error) {
                    TResponse resp = TResponseBuilder()
                        .Code(redir.Code)
                        .Header("location", redir.Location)
                        .Version11();
                    Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(resp), false, TInstant::Max()));
                    Y_PROPAGATE_ERROR(descr.Output->SendEof(TInstant::Max()));
                    return SkipAll(descr.Input, TInstant::Max());
                } Y_CATCH {
                    // Any failure in the block above is counted as a client error.
                    descr.Properties->ConnStats.ClientError += 1;
                    descr.ExtraAccessLog << " error";
                    descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client error");
                    return error;
                }
                descr.ExtraAccessLog << " succ";
                return TError{};
            },
            // Forward: run the action's submodule on a rewritten request.
            [&](const TForward& fwd) {
                descr.ExtraAccessLog << " forward to " << fwd.Location;
                // Modify a copy instead of *descr.Request: Host, path and CGI
                // are replaced with the parts of the forward target URL.
                // NOTE(review): fwd.Location is re-split on every request; it
                // could be split once at config time.
                auto req = *descr.Request;
                auto url = NImpl::SplitUrl(fwd.Location);
                req.Headers().Delete("Host");
                req.Headers().Add("Host", TString(url.Host));
                req.RequestLine().Path = TStringStorage(TString(url.Path));
                req.RequestLine().CGI = TStringStorage(TString(url.Query));
                return fwd.Dst->Run(descr.Copy(&req));
            }
        }, action);
    }

    // Opts into extra access-log output; DoRun appends to
    // descr.ExtraAccessLog on every path.
    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    // Parsed module configuration (the `actions` list and their parameters).
    TModuleConfig Config_;
    // Submodule used when no action applies; required by the constructor.
    THolder<IModule> Default_;
    // Forward submodules, keyed by the index of the action that owns them.
    THashMap<ui32, THolder<IModule>> Forwards_;
    // Compiled rule set queried by DoRun.
    TRedirects Redirects_;
};

// Exposes the handle of the `redirects` module class declared in this file
// via MODULE(redirects) (the declaration of this function presumably lives
// in module.h).
IModuleHandle* NModRedirects::Handle() {
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
