#include <balancer/serval/core/config.h>

#include <library/cpp/regex/pire/regexp.h>

#include <util/generic/variant.h>

namespace {
    /// A single configured header action, as produced by Parse() and executed by Apply().
    /// Kept in an anonymous namespace: it is local to this file, and a global-scope type with
    /// such a generic name (and inline members) could collide with another TU's definition (ODR).
    struct THeaderAction {
        /// Add a header: (weak, name, value producer). When `weak` is set the header is only
        /// added if `name` is not already present; a producer returning Nothing() adds nothing.
        using TAdd = std::tuple<bool, TString, NSv::TFunction<TMaybe<TStringBuf>(NSv::IStream&)>>;
        /// Conditional block: (expected presence, name, nested actions). The nested actions run
        /// when "is `name` present" equals the stored flag.
        using THas = std::tuple<bool, TString, TVector<THeaderAction>>;
        /// Erase every header whose name the FSM matches.
        using TDel = NRegExp::TFsm;

        template <typename... Args>
        THeaderAction(Args&&... args)
            : V_(std::forward<Args>(args)...)
        {}
        // Without this non-const lvalue overload, the variadic template above would be a better
        // match than the const& copy constructor when copying a non-const THeaderAction, and it
        // would try to construct the variant from a THeaderAction.
        THeaderAction(THeaderAction&) = default;
        THeaderAction(THeaderAction&&) = default;
        THeaderAction(const THeaderAction&) = default;

        std::variant<TAdd, THas, TDel> V_;
    };
}

static TVector<THeaderAction> Parse(const YAML::Node& arg) {
    CHECK_NODE(arg, arg.IsSequence(), "must be a sequence of headers");
    TVector<THeaderAction> ret;
    for (const auto& header : arg) {
        if (header.IsScalar() && header.Tag() == "!erase") {
            try {
                auto re = NRegExp::TFsm(header.Scalar(), NRegExp::TFsm::TOptions().SetCharset(CODES_UTF8));
                if (auto prev = ret.size() ? std::get_if<NRegExp::TFsm>(&ret.back().V_) : nullptr) {
                    *prev = *prev | re;
                } else {
                    ret.emplace_back(std::move(re));
                }
            } catch (const std::exception& e) {
                FAIL_NODE(header, e.what());
            }
            continue;
        }

        CHECK_NODE(header, header.IsMap() && header.size() == 1, "must be `name: value`");

        auto k = header.begin()->first;
        auto v = header.begin()->second;

        CHECK_NODE(k, k.IsScalar(), "must be a string");
        CHECK_NODE(k, EqualToOneOf(k.Tag(), "?", "!weak", "!if-has", "!if-has-none"), "unsupported condition");

        bool weak = (k.Tag() == "!weak");
        auto name = TString(k.Scalar());

        if (k.Tag() == "!if-has") {
            ret.emplace_back(THeaderAction::THas{true, name, Parse(v)});
        } else if (k.Tag() == "!if-has-none") {
            ret.emplace_back(THeaderAction::THas{false, name, Parse(v)});
        } else if (v.Tag() == "!ip") {
            ret.emplace_back(THeaderAction::TAdd{weak, name, [](NSv::IStream& req) -> TMaybe<TStringBuf> {
                return req.Retain(req.Peer().Format());
            }});
        } else if (v.Tag() == "!port") {
            ret.emplace_back(THeaderAction::TAdd{weak, name, [](NSv::IStream& req) -> TMaybe<TStringBuf> {
                return req.Retain(ToString(req.Peer().Port()));
            }});
        } else if (v.Tag() == "!time") {
            ret.emplace_back(THeaderAction::TAdd{weak, name, [](NSv::IStream& req) -> TMaybe<TStringBuf> {
                return req.Retain(ToString(TInstant::Now().MicroSeconds()));
            }});
        } else if (v.Tag() == "!header") {
            CHECK_NODE(v, v.IsScalar(), "header name must be a string");
            ret.emplace_back(THeaderAction::TAdd{weak, name, [v = v.Scalar()](NSv::IStream& req) -> TMaybe<TStringBuf> {
                if (auto it = req.Head()->find(v); it != req.Head()->end()) {
                    return it->second;
                }
                return Nothing();
            }});
        } else if (v.Tag() == "!forward") {
            ret.emplace_back(THeaderAction::TAdd{weak, name, [name](NSv::IStream& req) -> TMaybe<TStringBuf> {
                if (auto it = req.Head()->find(name); it != req.Head()->end()) {
                    return it->second;
                }
                return Nothing();
            }});
        } else if (v.Tag() == "!cookie") {
            CHECK_NODE(v, v.IsScalar() && name != "cookie", "cannot modify cookie");
            ret.emplace_back(THeaderAction::TAdd{weak, name, [v = v.Scalar()](NSv::IStream& req) -> TMaybe<TStringBuf> {
                auto cs = ParseCookie(*req.Head());
                if (auto it = std::find_if(cs.begin(), cs.end(), [v = TStringBuf(v)](const auto& c) {
                    return c.first == v;
                }); it != cs.end()) {
                    return it->second.GetOrElse(TStringBuf());
                }
                return Nothing();
            }});
        } else {
            CHECK_NODE(v, v.Tag().size() <= 1, "unsupported function " << v.Tag());
            CHECK_NODE(v, v.IsScalar(), "must be a string"); // TODO lists?
            ret.emplace_back(THeaderAction::TAdd{weak, name, [v = v.Scalar()](NSv::IStream&) {
                return TStringBuf(v);
            }});
        }
    }
    return ret;
}

static void Apply(const std::vector<THeaderAction>& headers, NSv::IStream& req, NSv::THead& head) {
    for (const auto& it : headers) {
        std::visit([&](auto& action) {
            if constexpr (std::is_same<std::decay_t<decltype(action)>, THeaderAction::TAdd>::value) {
                const auto& [weak, k, v] = action;
                if (!weak || head.find(k) == head.end()) {
                    if (auto constructed = v(req)) {
                        head.emplace(k, *constructed);
                    }
                }
            } else if constexpr (std::is_same<std::decay_t<decltype(action)>, THeaderAction::THas>::value) {
                const auto& [value, name, then] = action;
                if ((head.find(name) != head.end()) == value) {
                    Apply(then, req, head);
                }
            } else {
                head.erase_if([&](const auto& h) {
                    auto matcher = NRegExp::TMatcher(action).Match(h.first);
                    auto accepted = matcher.MatchedRegexps();
                    return accepted.first != accepted.second;
                });
            }
        }, it.V_);
    }
}

// Builds a request-header action. The action list is parsed once, from the value of the
// first entry of `args`, when the action is built. The returned action applies that list
// to the stream's head. It returns false when the stream has no head, true otherwise.
static NSv::TAction RequestHeaders(const YAML::Node& args, NSv::TAuxData&) {
    CHECK_NODE(args, args.IsMap(), "an argument is required");
    auto parsed = Parse(args.begin()->second);
    return [headers = std::move(parsed)](NSv::IStreamPtr& req) {
        if (auto* head = req->Head()) {
            Apply(headers, *req, *head);
            return true;
        }
        return false;
    };
}

// Builds a response-header action. The action list is parsed once, from the value of the
// first entry of `args`, when the action is built. The returned action wraps the stream in
// a proxy that applies the list to each non-informational head passing through WriteHead.
static NSv::TAction ResponseHeaders(const YAML::Node& args, NSv::TAuxData&) {
    class TResponseHeaders final: public NSv::TStreamProxy {
    public:
        // `request` is a sink parameter: taken by value and moved into the proxy base, so the
        // ownership transfer is explicit at the call site.
        // `headers` is stored by reference and must outlive this stream.
        // NOTE(review): it refers to the vector captured by the action closure below; this
        // assumes the action outlives every stream it wraps — confirm.
        TResponseHeaders(NSv::IStreamPtr request, const std::vector<THeaderAction>& headers)
            : NSv::TStreamProxy(std::move(request))
            , Headers_(headers)
        {
        }

        bool WriteHead(NSv::THead& head) noexcept override {
            // Informational heads are forwarded untouched.
            if (!head.IsInformational()) {
                Apply(Headers_, *this, head);
            }
            return NSv::TStreamProxy::WriteHead(head);
        }

    private:
        const std::vector<THeaderAction>& Headers_;
    };

    CHECK_NODE(args, args.IsMap(), "an argument is required");
    return [headers = Parse(args.begin()->second)](NSv::IStreamPtr& req) {
        req = std::make_shared<TResponseHeaders>(std::move(req), headers);
        return true;
    };
}

// Bind the two factories above to their config names. SV_DEFINE_ACTION presumably registers
// the factory under the given key in the serval action registry (declared outside this file).
SV_DEFINE_ACTION("request-headers", RequestHeaders);
SV_DEFINE_ACTION("response-headers", ResponseHeaders);
