#include <balancer/serval/contrib/cone/mun.h>
#include <balancer/serval/core/config.h>

#include <library/cpp/regex/pire/regexp.h>

namespace {
    // Matches a peer IP address against a set of `address/mask` networks read from the config.
    class TAddrMatcher {
    public:
        // `match` is either a single `address/mask` scalar or a (possibly nested) sequence of them.
        explicit TAddrMatcher(const YAML::Node& match) {
            Add(match);
        }

        // Appends the network(s) described by `node`; sequences are flattened recursively.
        // Anything that is not a valid `address/mask` scalar fails the config via CHECK_NODE.
        void Add(const YAML::Node& node) {
            if (node.IsSequence()) {
                for (const auto& subnode : node) {
                    Add(subnode);
                }
                return;
            }
            CHECK_NODE(node, node.IsScalar(), "must be a network address");
            TStringBuf addr = node.Scalar();
            TStringBuf mask;
            CHECK_NODE(node, addr.TrySplit('/', addr, mask), "must be address/mask; to match one IP, use /32 or /128");
            auto parsedAddr = NSv::IP::Parse(addr);
            // Validate the address before reading its family: the mask is parsed relative to it,
            // and the family of a failed parse is not meaningful.
            CHECK_NODE(node, parsedAddr, "invalid address or unsupported address family");
            auto parsedMask = NSv::IP::ParseMask(mask, parsedAddr.Data.Base.sa_family);
            CHECK_NODE(node, parsedMask, "invalid mask or unsupported address family");
            // Store the network already reduced by its mask, so Match() only has to mask the
            // candidate address (assumes Raw(mask) keeps just the bits selected by `mask`).
            Patterns.emplace_back(parsedAddr.Raw(*parsedMask), *parsedMask);
        }

        // True if `addr` falls into any of the configured networks.
        bool Match(const NSv::IP& addr) const {
            for (const auto& [expect, mask] : Patterns) {
                if (addr.Raw(mask) == expect) {
                    return true;
                }
            }
            return false;
        }

    private:
        // (masked network, mask) pairs, in config order.
        TVector<std::tuple<NSv::IP::TRaw, NSv::IP::TRaw>> Patterns;
    };
}

// Compiles the pattern of one `match` branch into a Pire FSM.
//   i - case-insensitive matching;
//   g - SetSurround: the pattern need not cover the whole value (so `!gp` would be redundant);
//   p - additionally accept an optional `/...` suffix after the pattern.
// The pattern is matched as UTF-8. Config validation errors are raised directly; exceptions
// from parsing/compiling the regex are reported against `regex` as well.
static NRegExp::TFsm Compile(const YAML::Node& regex, bool i, bool g, bool p) {
    // Validate the config before entering the try block, so these errors are not caught and
    // re-wrapped by the handler below.
    CHECK_NODE(regex, !g || !p, "!gp is equivalent to !g");
    CHECK_NODE(regex, regex.IsScalar(), "should be a regexp");

    struct TFsmParser : NRegExp::TFsmParser<NPire::TNonrelocScanner> {
        // Need TFsmParser's protected constructor to transform the regex before compiling.
        TFsmParser(const TScanner& compiled)
            : NRegExp::TFsmParser<TScanner>(compiled)
        {}
    };

    try {
        auto options = NRegExp::TFsm::TOptions().SetCharset(CODES_UTF8).SetCaseInsensitive(i).SetSurround(g);
        auto parsed = NRegExp::TFsmBase::Parse(regex.Scalar(), options, false);
        if (p) {
            parsed += NRegExp::TFsmBase::Parse("(/.*)?", options, false);
        }
        return TFsmParser(parsed.Canonize().Compile<NPire::TNonrelocScanner>());
    } catch (const std::exception& e) {
        FAIL_NODE(regex, e.what());
    }
}

// Runs `value` through the combined FSM and sets `matches[k]` for every alternative `k` the
// matcher reports as accepted. Flags are only ever set, never cleared, so repeated calls (one per
// header value) OR their results together. Assumes MatchedRegexps() reports indices in the
// order the alternatives were joined with `|`, which is how `matches` is indexed by the caller.
static void AddMatches(const NRegExp::TFsm& regex, TStringBuf value, TVector<bool>& matches) {
    // TODO capture groups and pass them on somehow?
    NRegExp::TMatcher matcher(regex);
    matcher.Match(value);
    const auto accepted = matcher.MatchedRegexps();
    for (auto idx = accepted.first; idx != accepted.second; ++idx) {
        matches[*idx] = true;
    }
}

// Builds the `match` action from its config map.
//
// The first map entry's value names what to match on: ":path", ":method", ":source" (the peer
// address), or any other request header name (the entry's key is not inspected here). Every
// following entry is a branch `pattern: action`. Flags come from the pattern's YAML tag with its
// first character dropped (e.g. `!ig`):
//   i / g / p - regex options, see Compile(); not allowed for ":source";
//   n         - negate the branch: it fires when the pattern does NOT match.
// For ":source" a pattern is an `address/mask` (or a list of them) instead of a regex.
//
// At request time every firing branch runs its action in config order; the action returns false
// as soon as one of those branch actions returns false, and true otherwise. For ordinary headers,
// a branch matches if any value of that header matches. A request without a head yields false.
// NOTE(review): for ":source", if the peer address is unavailable no branch runs at all (not even
// negated ones), unlike the header case where negated branches fire on no match — confirm intended.
static NSv::TAction Match(const YAML::Node& args, NSv::TAuxData& aux) {
    CHECK_NODE(args, args.IsMap(), "`match` requires an argument");
    auto it = args.begin();
    // An empty map (`match: {}`) has no header entry; never dereference its end iterator.
    CHECK_NODE(args, it != args.end(), "`match` requires an argument");
    CHECK_NODE(it->second, it->second.IsScalar(), "`match` argument must be a header name");
    auto header = it++->second.Scalar();
    bool isPath = (header == ":path");
    bool isMethod = (header == ":method");
    bool isSource = (header == ":source");
    // TODO matching query arguments

    CHECK_NODE(args, it != args.end(), "`match` must list at least one branch");
    TVector<NSv::TAction> actions(Reserve(args.size() - 1));
    TVector<bool> negated(Reserve(args.size() - 1));
    TVector<TAddrMatcher> addrs;
    NRegExp::TFsm joined = NRegExp::TFsm::False();
    for (; it != args.end(); it++) {
        bool i = false, g = false, p = false, n = false;
        for (char flag : TStringBuf(it->first.Tag()).Skip(1)) switch (flag) {
            case 'i': i = true; break;
            case 'g': g = true; break;
            case 'p': p = true; break;
            case 'n': n = true; break;
            default: FAIL_NODE(it->first, "unsupported flag " << flag);
        }
        CHECK_NODE(it->first, !isSource || (!i && !g && !p), "only !n is allowed for :source");
        if (isSource) {
            addrs.emplace_back(it->first);
        } else {
            // The first branch replaces the False() seed instead of OR-ing with it, so that
            // branch k ends up as alternative k of the joined FSM (AddMatches relies on this).
            joined = actions ? joined | Compile(it->first, i, g, p) : Compile(it->first, i, g, p);
        }
        actions.push_back(aux.Action(it->second));
        negated.push_back(n);
    }

    return [=, addrs = std::move(addrs), regex = std::move(joined)](NSv::IStreamPtr& req) {
        if (isSource) {
            if (auto addr = req->Peer()) {
                for (size_t i = 0; i < actions.size(); i++) {
                    if ((addrs[i].Match(addr) ^ negated[i]) && !actions[i](req)) {
                        return false;
                    }
                }
            }
            return true;
        }
        auto rqh = req->Head();
        if (!rqh) {
            return false;
        }
        TVector<bool> matches(actions.size());
        if (isPath) {
            AddMatches(regex, rqh->Path(), matches);
        } else if (isMethod) {
            AddMatches(regex, rqh->Method, matches);
        } else {
            for (auto range = rqh->equal_range(header); range.first != range.second; range.first++) {
                AddMatches(regex, range.first->second, matches);
            }
        }
        for (size_t i = 0; i < actions.size(); i++) {
            // Fallthrough to the next matched action if the response isn't complete yet.
            if ((matches[i] ^ negated[i]) && !actions[i](req)) {
                return false;
            }
        }
        return true;
    };
}

// Expose Match as the "match" action; SV_DEFINE_ACTION presumably registers it by name so that
// configs can refer to it — confirm against the macro's definition.
SV_DEFINE_ACTION("match", Match);
