#include "func.h"

#include "libretto.h"
#include "result.h"
#include "scenario.h"
#include "variables.h"
#include "utils/regex_escape.h"

#include <passport/infra/libs/cpp/utils/ipaddr.h>
#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <util/generic/string.h>

#include <regex>

namespace NPassport::NLast {
    // Empties every piece of per-request state held by the context: headers,
    // cookies, CGI parameters, path, schema, the error message and Vars.
    void TTestContext::Clear() {
        Vars.clear();
        ErrMsg.clear();
        Schema.clear();
        Path.clear();
        Cgis.clear();
        Cookies.clear();
        Headers.clear();
    }

    // Comparator names. SetComp() accepts every one of them except SUB_FUNC;
    // CheckInfo() reports the active one (SUB_FUNC for sub-function args).
    const TString TArg::SUB_FUNC("sub_func");
    const TString TArg::PLAIN_TEXT("plaintext");
    const TString TArg::REGEX("regex");
    const TString TArg::RELATION("relation");
    const TString TArg::STARTS_WITH("starts_with");
    const TString TArg::CONTAINS("contains");

    // Result-modifier names accepted by SetResMod() and reported by CheckInfo().
    const TString TArg::NONE("none");
    const TString TArg::KEY_COUNT("key_count");
    const TString TArg::ARRAY_SIZE("array_size");

    TArg::TArg(const TString& value)
        : CheckFunc_(&TArg::PlainCheck)
        , ResMod_(None)
        , NeedSubstitute_(false)
        , ExpectedValue_(value)
    {
    }

    TArg::TArg(const NLibretto::TArg& xarg)
        : CheckFunc_(&TArg::PlainCheck)
        , ResMod_(None)
        , NeedSubstitute_(false)
        , ExpectedValue_(xarg.ExpectedValue)
    {
        if (xarg.SubItem) {
            InitSubFunc(*xarg.SubItem);
        } else {
            InitNaturalArg(xarg);
        }
    }

    // Configures a non-sub-function argument: optional comparator (with its
    // relation bounds), optional result modifier, and the substitution flag.
    // SetComp/SetResMod throw on unknown names; calling SetComp first means its
    // error is the one reported when both names are invalid.
    void TArg::InitNaturalArg(const NLibretto::TArg& xarg) {
        const auto& comparator = xarg.Comparator;
        const auto& resultModificator = xarg.ResultModificator;

        if (comparator) {
            SetComp(comparator, xarg.CompMin, xarg.CompMax);
        }
        if (resultModificator) {
            SetResMod(resultModificator);
        }
        NeedSubstitute_ = xarg.DoesNeedSubstitute;
    }

    // Turns this argument into a sub-function: Match() will delegate to a
    // TDynamicResultItem built from the nested libretto item.
    void TArg::InitSubFunc(const NLibretto::TResultItem& xarg) {
        auto subItem = std::make_shared<TDynamicResultItem>(xarg);
        SubFunc_ = std::move(subItem);
    }

    // Variable reference: a literal backslash-tilde, then a backtick-quoted
    // reference with an optional encoding prefix ("%%" or "@@"). The reference is
    // either "cgi:NAME", "cookie:NAME" or "header:NAME" (optionally followed by
    // "{LEN}"), or one of the bare words "path" / "schema".
    // Capture groups: 1 = encoding prefix, 2 = whole reference, 3 = container
    // type, 4 = variable name, 5 = substring length.
    static const std::regex EXP_VAR_PATTERN("\\\\~`(%%|@@)?((?:(cgi|cookie|header):([[:alnum:]_]+)(?:\\{([0-9]+)\\})?)|path|schema)`");

    // Expands every variable reference (see EXP_VAR_PATTERN) in `src` using
    // values from `ctx`:
    //  - cgi/cookie/header references take the value from ctx.Cgis / ctx.Cookies /
    //    ctx.Headers; "{LEN}" keeps only its first LEN characters;
    //  - "path" / "schema" take ctx.Path / ctx.Schema;
    //  - prefix "%%" URL-encodes the value (NUtils::Urlencode);
    //  - prefix "@@" replaces the value with TIpAddr::ToBase64String() when it
    //    parses as an IP address, and leaves it unchanged otherwise;
    //  - if `regexSafe` is set, the final value is passed through RegexEscape.
    // Inserted text is never rescanned: scanning resumes right after it, so a
    // value that itself looks like a reference is kept literally instead of being
    // expanded again (or looping forever when it refers to itself).
    // Throws TLastError when a cgi/cookie/header reference names a variable that
    // is absent from `ctx`.
    TString SubstituteVars(const TString& src, const TTestContext& ctx, bool regexSafe) {
        // Copies the value of `name` from the container selected by `type` into
        // `out`; returns false when there is no such variable.
        const auto lookup = [&ctx](const TString& type, const TString& name, TString& out) {
            TTestContext::TValueSet::const_iterator it;
            if (type == "cgi") {
                it = ctx.Cgis.find(name);
                if (it == ctx.Cgis.end()) {
                    return false;
                }
            } else if (type == "header") {
                it = ctx.Headers.find(name);
                if (it == ctx.Headers.end()) {
                    return false;
                }
            } else {
                it = ctx.Cookies.find(name);
                if (it == ctx.Cookies.end()) {
                    return false;
                }
            }
            out = it->second;
            return true;
        };

        TString ret(src);
        size_t searchFrom = 0;
        std::smatch groups;
        while (std::regex_search(ret.cbegin() + searchFrom, ret.cend(), groups, EXP_VAR_PATTERN)) {
            TString value;
            const TString type(groups.str(3));
            if (!type.empty()) {
                const TString name(groups.str(4));
                if (!lookup(type, name, value)) {
                    throw TLastError() << "regex '" << src << "' refers to undefined variable '"
                                       << type << ":" << name << "'";
                }

                const TString len(groups.str(5));
                if (!len.empty()) { // make a substring of given length
                    const unsigned l = NUtils::ToUInt(len, "substring length");
                    value = value.substr(0, l);
                }
            } else {
                const TString word(groups.str(2));
                if (word == "path") {
                    value = ctx.Path;
                } else {
                    value = ctx.Schema;
                }
            }

            const TString encode(groups.str(1));
            if (encode == "%%") { // urlencode value before check
                value = NUtils::Urlencode(value);
            } else if (encode == "@@") { // treat value as IP and convert to b64url format
                NUtils::TIpAddr addr;
                if (addr.Parse(value)) {
                    value = addr.ToBase64String();
                }
            }

            const TString replacement(regexSafe ? RegexEscape(value) : value);
            // position() is relative to the start of the searched range, not of `ret`;
            // read it before replace() invalidates the match iterators.
            const size_t matchPos = searchFrom + static_cast<size_t>(groups.position(0));
            ret.replace(matchPos, static_cast<size_t>(groups.length(0)), replacement);
            searchFrom = matchPos + replacement.size();
        }

        return ret;
    }

    bool TArg::Match(const TString& value, const TTestContext& ctx) const {
        // Sub func
        if (SubFunc_) {
            return SubFunc_->Match(ctx, value);
        }

        // Natural arg
        if (NeedSubstitute_) {
            return (this->*CheckFunc_)(value, SubstituteVars(ExpectedValue_, ctx));
        }
        return (this->*CheckFunc_)(value, ExpectedValue_);
    }

    // Describes how this argument is checked:
    //   "res_mod=<modifier>; <comparator>[: <LastFailed_>]"
    // A relation comparator is rendered as "relation: [MIN <= ]X[ <= MAX]", where
    // an unset bound is omitted. The ": <LastFailed_>" suffix is added only when
    // substitution is enabled; LastFailed_ holds the last expanded expected value
    // that failed a check, and may be empty.
    TString TArg::CheckInfo() const {
        TString res("res_mod=");
        switch (ResMod_) {
            case None:
                res.append(NONE);
                break;
            case KeyCount:
                res.append(KEY_COUNT);
                break;
            case ArraySize:
                res.append(ARRAY_SIZE);
                break;
        }
        res.append("; ");

        if (SubFunc_) {
            res += SUB_FUNC;
        } else if (CheckFunc_ == &TArg::PlainCheck) {
            res += PLAIN_TEXT;
        } else if (CheckFunc_ == &TArg::RegexCheck) {
            res += REGEX;
        } else if (CheckFunc_ == &TArg::RelationCheck) {
            res.append("relation: ");
            if (Min_.first) {
                res.append(IntToString<10>(Min_.second)).append(" <= ");
            }
            res.append("X");
            if (Max_.first) {
                // RelationCheck rejects values greater than the upper bound: X <= MAX.
                res.append(" <= ").append(IntToString<10>(Max_.second));
            }
        } else if (CheckFunc_ == &TArg::StartsWithCheck) {
            res += STARTS_WITH;
        } else if (CheckFunc_ == &TArg::ContainsCheck) {
            res += CONTAINS;
        }

        if (NeedSubstitute_) {
            res.append(": ").append(LastFailed_);
        }

        return res;
    }

    void TArg::SetCompRegex() {
        CheckFunc_ = &TArg::RegexCheck;
    }

    void TArg::SetComp(const TString& comp, const std::optional<ui64> min, const std::optional<ui64> max) {
        if (comp == REGEX) {
            CheckFunc_ = &TArg::RegexCheck;
        } else if (comp == RELATION) {
            Y_ENSURE(min || max, "'comp=relation' requires '@min' or '@max'");
            CheckFunc_ = &TArg::RelationCheck;
            Min_.first = bool(min);
            if (Min_.first) {
                Min_.second = *min;
            }
            Max_.first = bool(max);
            if (Max_.first) {
                Max_.second = *max;
            }
        } else if (comp == PLAIN_TEXT) {
            CheckFunc_ = &TArg::PlainCheck;
        } else if (comp == STARTS_WITH) {
            CheckFunc_ = &TArg::StartsWithCheck;
        } else if (comp == CONTAINS) {
            CheckFunc_ = &TArg::ContainsCheck;
        } else {
            throw yexception() << "Unknow comparator:" << comp;
        }
    }

    // Selects the result modifier from its name (NONE, KEY_COUNT or ARRAY_SIZE).
    // Throws yexception for any other name.
    void TArg::SetResMod(const TString& resMod) {
        if (resMod == NONE) {
            ResMod_ = None;
        } else if (resMod == KEY_COUNT) {
            ResMod_ = KeyCount;
        } else if (resMod == ARRAY_SIZE) {
            ResMod_ = ArraySize;
        } else {
            throw yexception() << "Unknown result modifier: " << resMod;
        }
    }

    // True if the std::regex (default ECMAScript grammar) built from `expected`
    // matches anywhere in `value` (regex_search, so no implicit anchoring).
    // On a miss with substitution enabled, the expanded pattern is stored in
    // LastFailed_ so CheckInfo() can report it. An invalid pattern makes the
    // std::regex constructor throw std::regex_error.
    bool TArg::RegexCheck(const TString& value, const TString& expected) const {
        // NOTE(review): the pattern is recompiled on every call; caching it would
        // need a member for the non-substituted case.
        const std::regex pattern(expected.cbegin(), expected.cend());
        const bool res = std::regex_search(value.cbegin(), value.cend(), pattern);
        if (!res && NeedSubstitute_) {
            LastFailed_ = expected;
        }
        return res;
    }

    bool TArg::RelationCheck(const TString& value, const TString&) const {
        long valueInt = strtol(value.c_str(), nullptr, 10);
        if (Min_.first && Min_.second > valueInt) {
            return false;
        }
        return !(Max_.first && valueInt > Max_.second);
    }

    // Exact string equality. On a mismatch with substitution enabled, records
    // the (expanded) expected value in LastFailed_ for CheckInfo().
    bool TArg::PlainCheck(const TString& value, const TString& expected) const {
        if (value == expected) {
            return true;
        }
        if (NeedSubstitute_) {
            LastFailed_ = expected;
        }
        return false;
    }

    // True if `value` begins with `expected`. On a miss with substitution
    // enabled, records the (expanded) expected value in LastFailed_.
    bool TArg::StartsWithCheck(const TString& value, const TString& expected) const {
        const bool matched = value.compare(0, expected.size(), expected) == 0;
        if (matched || !NeedSubstitute_) {
            return matched;
        }
        LastFailed_ = expected;
        return false;
    }

    // True if `expected` occurs anywhere in `value`. On a miss with substitution
    // enabled, records the (expanded) expected value in LastFailed_.
    bool TArg::ContainsCheck(const TString& value, const TString& expected) const {
        const size_t at = value.find(expected);
        if (at == TString::npos && NeedSubstitute_) {
            LastFailed_ = expected;
        }
        return at != TString::npos;
    }
}
