#include "redirects.h"
#include "url_parts.h"

#include <balancer/kernel/module/module_face.h>

#include <util/generic/overloaded.h>

#include <util/string/subst.h>
#include <util/string/strip.h>

#include <string>

namespace NModRedirects {
    using namespace NSrvKernel;

    namespace NImpl {

        // Builds a rewrite rule from its config entry.
        //
        // Every `$` in the configured rewrite string is replaced with `\` before it is stored
        // in Pattern; Apply() later hands Pattern to re2 as the rewrite argument.
        // NOTE(review): presumably this lets configs use `$1`-style group references while re2
        // expects `\1` — confirm against the config documentation.
        TRewrite::TRewrite(const TRewriteConfig& cfg)
            : Parts(cfg.url())
            , Regexp(cfg.regexp())
            , Pattern(SubstGlobalCopy(cfg.rewrite(), "$", "\\"))
            , Global(cfg.global())
        {}

        // Runs this rewrite over `loc` and returns the resulting URL.
        //
        // The URL is split, the parts selected by Parts are merged into a scratch string,
        // the regexp replacement is applied to that string, and the outcome is merged back
        // into the original parts.
        // NOTE(review): how ReplaceUrlParts combines the two part sets is defined in
        // url_parts.h and is not visible here.
        TString TRewrite::Apply(TStringBuf loc) const noexcept {
            auto original = SplitUrl(loc);
            TString subject = MergeUrl(ReplaceUrlParts({}, original, Parts));

            // Global chooses between replacing every match and only the first one.
            if (!Global) {
                re2::RE2::Replace(&subject, Regexp, Pattern);
            } else {
                re2::RE2::GlobalReplace(&subject, Regexp, Pattern);
            }

            return MergeUrl(ReplaceUrlParts(original, SplitUrl(subject), Parts));
        }

        // Renders the final location for a matched destination and wraps it into a result.
        //
        // dst   - matched destination: its Location template is expanded, DstRewrites are
        //         applied in order and Action decides the kind of result;
        // path  - the request path piece to substitute into path placeholders;
        // query - the request query to substitute into query placeholders.
        //
        // Returns TForward or TRedirect according to dst.Action, or an empty TResult when
        // no action is set.
        TResult GenResult(const TDst& dst, TStringBuf path, TStringBuf query) noexcept {
            // Separators are re-added by the S_Path / A_Query / Q_Query placeholders, so the
            // substituted pieces must not carry their own. Legacy mode strips both ends.
            if (dst.LegacyRStrip) {
                path = StripString(path, TEqualsStripAdapter('/'));
                query = StripString(query, TEqualsStripAdapter('&'));
            } else {
                path = StripStringLeft(path, TEqualsStripAdapter('/'));
                query = StripStringLeft(query, TEqualsStripAdapter('&'));
            }

            // ToReserve already accounts for the literal tokens and the added separators.
            TString loc(Reserve(dst.ToReserve + path.size() + query.size()));

            for (auto&& tok : dst.Location) {
                std::visit(TOverloaded{
                    [&](const TString& s) {
                        loc.append(s);
                    },
                    [&](EPlace p) {
                        switch (p) {
                        case EPlace::Path:
                            loc.append(path);
                            break;
                        case EPlace::S_Path:
                            // The separator is emitted only together with a non-empty path.
                            if (path) {
                                loc.append('/').append(path);
                            }
                            break;
                        case EPlace::Query:
                            // Bare query placeholder: substitute the query, not the path.
                            loc.append(query);
                            break;
                        case EPlace::A_Query:
                            if (query) {
                                loc.append('&').append(query);
                            }
                            break;
                        case EPlace::Q_Query:
                            if (query) {
                                loc.append('?').append(query);
                            }
                            break;
                        }
                    },
                }, tok);
            }

            // Each rewrite sees the output of the previous one.
            for (auto&& rwr : dst.DstRewrites) {
                loc = rwr.Apply(loc);
            }

            return std::visit(TOverloaded{
                [&](TForwardAction fwd) {
                    return TResult(TForward{
                        .Location=std::move(loc),
                        .Dst=fwd.Dst
                    });
                },
                [&](TRedirectAction rdr) {
                    return TResult(TRedirect{
                        .Location=std::move(loc),
                        .Code=rdr.Code
                    });
                },
                [](std::monostate) {
                    // Should never happen
                    return TResult();
                }
            }, dst.Action);
        }

        // Parses a destination string from the config into a token template.
        //
        // The template is a sequence of literal strings and EPlace placeholders. Both the
        // `{path}` / `{query}` spelling and the `$request_uri` / `$args` spelling are accepted;
        // the latter is rewritten into the former first.
        //
        // Throws TConfigParseError (through Y_ENSURE_EX) when a placeholder is repeated or
        // sits in a part of the URL where it is not supported.
        TTemplate GenLocation(const TStringBuf rawDst) {
            static constexpr TStringBuf pPlace = "{path}";
            static constexpr TStringBuf qPlace = "{query}";
            static const re2::RE2 requestUri("[$]request_uri\\b");
            static const re2::RE2 args("[$]args\\b");

            TString dst(rawDst);

            // Normalize the `$variable` spelling to the `{placeholder}` one.
            re2::RE2::GlobalReplace(&dst, requestUri, {pPlace.data(), pPlace.size()});
            re2::RE2::GlobalReplace(&dst, args, {qPlace.data(), qPlace.size()});

            TTemplate tokens;
            // Adds a literal token: empty pieces are skipped and a piece that follows another
            // literal is glued onto it, so literals never sit next to each other.
            auto append = [&](TStringBuf token) {
                if (!token) {
                    return;
                }
                if (tokens && std::holds_alternative<TString>(tokens.back())) {
                    std::get<TString>(tokens.back()).append(token);
                } else {
                    tokens.emplace_back(ToString(token));
                }
            };

            // Positions of the first path/query placeholder, the first '?' and the first '#'.
            // The *End values fall back to 0 when the placeholder is absent.
            auto ppPos = dst.find(pPlace);
            auto ppEnd = ppPos != TString::npos ? ppPos + pPlace.size() : 0;
            auto qpPos = dst.find(qPlace);
            auto qpEnd = qpPos != TString::npos ? qpPos + qPlace.size() : 0;
            auto qPos = dst.find('?');
            auto fPos = dst.find('#');

            // A '?' located after '#' is part of the fragment, not a query separator.
            if (fPos < qPos) {
                qPos = TString::npos;
            }

            // Validation. The order of the checks decides which message a bad config gets.
            // First occurrence must equal the last one, i.e. at most one of each placeholder.
            Y_ENSURE_EX(ppPos == dst.rfind(pPlace),
                TConfigParseError() << "Multiple path rewrites are not supported: " << TString(rawDst).Quote());
            Y_ENSURE_EX(qpPos == dst.rfind(qPlace),
                TConfigParseError() << "Multiple query rewrites are not supported: " << TString(rawDst).Quote());
            // The path placeholder has to precede '#', '?' and the query placeholder.
            Y_ENSURE_EX(ppPos == TString::npos || ppPos < std::min({fPos, qPos, qpPos}),
                TConfigParseError() << "Inserting path into query or fragment is not supported: " << TString(rawDst).Quote());
            // With a '?' present, the query placeholder (if any) has to come after it.
            Y_ENSURE_EX(qPos == TString::npos || qPos < qpPos,
                TConfigParseError() << "Inserting query into path is not supported: " << TString(rawDst).Quote());
            Y_ENSURE_EX(qpPos == TString::npos || qpPos < fPos,
                TConfigParseError() << "Inserting query into fragment is not supported: " << TString(rawDst).Quote());
            // Without a '?', the query placeholder must end the string or sit right before '#'.
            Y_ENSURE_EX(qpPos == TString::npos || qPos != TString::npos || qpEnd == std::min({dst.size(), fPos}),
                TConfigParseError() << "Query placeholder without ? restricted to end: " << TString(rawDst).Quote());

            // Cut the fragment off; it is appended back as the last literal.
            // NOTE(review): presumably LegacySubstr yields an empty string when fPos is npos — confirm.
            TString fragment = LegacySubstr(dst, fPos);
            dst = dst.substr(0, fPos);

            if (ppPos != TString::npos) {
                append(LegacySubstr(dst, 0, ppPos));
                // S_Path is used when the text before the placeholder does not already end
                // with '/'; a placeholder at position 0 gets plain Path.
                if (ppPos > 0 && dst[ppPos - 1] != '/') {
                    tokens.emplace_back(EPlace::S_Path);
                } else {
                    tokens.emplace_back(EPlace::Path);
                }
            }

            if (qpPos == TString::npos) {
                // No query placeholder: everything after the path placeholder is literal.
                append(LegacySubstr(dst, ppEnd));
            } else {
                append(LegacySubstr(dst, ppEnd, qpPos - ppEnd));
                // Pick the query flavour by what precedes the placeholder: no '?' at all ->
                // Q_Query, preceded by something other than '&' or '?' -> A_Query,
                // otherwise plain Query.
                if (qPos == TString::npos) {
                    tokens.emplace_back(EPlace::Q_Query);
                } else if (qpPos > 0 && !IsIn(TStringBuf("&?"), dst[qpPos - 1])) {
                    tokens.emplace_back(EPlace::A_Query);
                } else {
                    tokens.emplace_back(EPlace::Query);
                }
                append(LegacySubstr(dst, qpEnd));
            }

            append(fragment);
            return tokens;
        }

        // Normalizes a configured redirect source into the key stored in the trie.
        //
        // src          - source URL from the config; must start with `//`, `http://` or `https://`;
        // expectsWcard - when true and the path has no `/*`, a trailing `/*` is added.
        //
        // The key is host + path (trailing slashes removed, empty path becomes "/") + optional
        // wildcard + query (including its leading '?').
        // Throws TConfigParseError (through Y_ENSURE_EX) on a malformed source.
        TString GenSrc(TStringBuf src, bool expectsWcard) {
            const TString original(src);

            // SkipPrefix also drops the matched scheme from src.
            const bool schemeStripped =
                src.SkipPrefix("//") || src.SkipPrefix("http://") || src.SkipPrefix("https://");
            Y_ENSURE_EX(schemeStripped,
                TConfigParseError() << "Redirect src must start with scheme: " << original.Quote());
            Y_ENSURE_EX(src.find('#') == TStringBuf::npos,
                TConfigParseError() << "Redirect src cannot contain #: " << original.Quote());

            // Split "host/path?query" into its three pieces.
            const auto queryStart = src.find('?');
            const TStringBuf query = src.SubStr(queryStart);
            const TStringBuf hostAndPath = src.SubStr(0, queryStart);

            const auto pathStart = hostAndPath.find('/');
            const TStringBuf host = hostAndPath.SubStr(0, pathStart);
            TStringBuf path = hostAndPath.SubStr(pathStart);

            // A wildcard is accepted only as the final path component and only without a query.
            const auto wcard = path.find("/*");
            const bool wcardIsValid =
                wcard == TStringBuf::npos || (wcard == path.size() - 2 && query.empty());
            Y_ENSURE_EX(wcardIsValid,
                TConfigParseError() << "redirect src can only have /* in the end: " << original.Quote());

            Y_ENSURE_EX(path.find("//") == TStringBuf::npos,
                TConfigParseError() << "redirect src cannot have // in path: " << original.Quote());

            path = StripStringRight(path, TEqualsStripAdapter('/'));
            if (path.empty()) {
                path = "/";
            }

            TString fixed(host);
            fixed.append(path);

            const bool mustAddWcard = expectsWcard && wcard == TStringBuf::npos;
            if (mustAddWcard) {
                if (!fixed.EndsWith('/')) {
                    fixed.append('/');
                }
                fixed.append('*');
            }

            if (!query.empty()) {
                fixed.append(query);
            }
            return fixed;
        }
    }

    void TRedirects::AddRedirect(const TString& src, const TRedirectConfig& cfg) {
        const TStringBuf dst = cfg.dst();
        Y_ENSURE_EX(dst.StartsWith("https://") || dst.StartsWith("http://") || dst.StartsWith("//"),
            TConfigParseError() << "Redirect dst must start with scheme: " << cfg.dst().Quote());

        AddRoute(src, cfg.dst(), cfg.dst_rewrites(), NImpl::TRedirectAction{
            .Code = (ui16)cfg.code()
        }, cfg.legacy_rstrip());
    }

    void TRedirects::AddForward(const TString& src, const TForwardConfig& cfg, const IModule& fwd) {
        const TStringBuf dst = cfg.dst();
        Y_ENSURE_EX(!dst.StartsWith("https://"),
            TConfigParseError() << "Forward dst does not support https: " << cfg.dst().Quote());
        Y_ENSURE_EX(dst.StartsWith("http://") || dst.StartsWith("//"),
            TConfigParseError() << "Forward dst must start with scheme: " << cfg.dst().Quote());

        AddRoute(src, cfg.dst(), cfg.dst_rewrites(), NImpl::TForwardAction{
            .Dst = &fwd
        }, cfg.legacy_rstrip());
    }

    void TRedirects::AddRoute(
        const TString& srcLoc,
        const TString& dstLoc,
        const TVector<TRewriteConfig>& rwr,
        const NImpl::TAction& act,
        bool legacyRStrip
    ) {
        using namespace NImpl;

        TDst dst;
        dst.Action = act;
        dst.Location = GenLocation(dstLoc);
        dst.DstRewrites = TDeque<TRewrite>(rwr.begin(), rwr.end());
        dst.LegacyRStrip = legacyRStrip;

        bool expectsWcard = false;
        for (auto&& tok : dst.Location) {
            std::visit(TOverloaded{
                [&](const TString& s) {
                    dst.ToReserve += s.size();
                },
                [&](EPlace p) {
                    switch (p) {
                    case EPlace::Path:
                        expectsWcard = true;
                        break;
                    case EPlace::S_Path:
                        expectsWcard = true;
                        [[fallthrough]];
                    case EPlace::A_Query:
                    case EPlace::Q_Query:
                        dst.ToReserve += 1;
                        break;
                    default:
                        break;
                    }
                }
            }, tok);
        }

        auto src = GenSrc(srcLoc, expectsWcard);

        Y_ENSURE_EX(Builder_.Add(src, Dsts_.size()),
            TConfigParseError() << "Duplicate src found: " << src.Quote());

        Dsts_.emplace_back(std::move(dst));
    }

    void TRedirects::Compile() {
        TBufferOutput bout;
        Builder_.Save(bout);
        Srcs_ = TTrie(TBlob::FromBuffer(bout.Buffer()));
    }

    // Looks up the rule matching a request and renders its result.
    //
    // host  - request host; compared after ASCII lowercasing;
    // path  - request path; leading slashes are ignored;
    // query - request query; leading '?' and '&' are ignored.
    //
    // Returns the result produced by GenResult for the best matching rule, or an empty
    // TResult when nothing matches.
    TResult TRedirects::Location(const TStringBuf host, TStringBuf path, TStringBuf query) const noexcept {
        constexpr TStringBuf pathSep = "/";
        constexpr TStringBuf pathWcard = "/*";
        constexpr TStringBuf querySep = "?";

        // NOTE(review): resize_uninitialized is not a standard std::string member —
        // presumably an extension of the project's standard library; confirm.
        std::string lowerHost;
        lowerHost.resize_uninitialized(host.size());
        for (size_t i = 0; i < host.size(); ++i) {
            lowerHost[i] = AsciiToLower(host[i]);
        }

        // Continue the search among the keys that follow the host.
        auto suffs = Srcs_.FindTails(lowerHost);

        path = StripStringLeft(path, TEqualsStripAdapter('/'));
        query = StripStringLeft(StripStringLeft(query, TEqualsStripAdapter('?')), TEqualsStripAdapter('&'));

        // Last level at which a "/*" continuation was seen, and the path that was still
        // unconsumed at that level.
        TTrie wcardSuff;
        TStringBuf wcardPath;

        // Walk the path one component at a time, descending through "/" + component for as
        // long as the trie has such a continuation.
        // NOTE(review): termination and the handling of an empty component depend on how
        // FindTails treats empty keys — confirm against the trie implementation.
        while (true) {
            if (auto next = suffs.FindTails(pathWcard); !next.IsEmpty()) {
                wcardSuff = next;
                wcardPath = path;
            }

            TStringBuf dir, nextPath;
            path.Split('/', dir, nextPath);
            // Implements nginx merge_slashes
            nextPath = StripStringLeft(nextPath, TEqualsStripAdapter('/'));

            if (auto nextSuffs = suffs.FindTails(pathSep).FindTails(dir); !nextSuffs.IsEmpty()) {
                suffs = nextSuffs;
                path = nextPath;
            } else {
                break;
            }
        }

        // The priorities of lookups:
        // 1. path + query
        // 2. path
        // 3. pathWcard
        // 4. nothing

        // Exact matches are possible only when the whole path has been consumed.
        if (!path) {
            if (query) {
                // Trailing '&' is ignored for matching; the untouched query is still what
                // gets passed on to GenResult.
                auto q = StripStringRight(query, TEqualsStripAdapter('&'));
                // trying the full path + query first
                if (auto act = FindDst(suffs.FindTails(querySep), q)) {
                    return GenResult(*act, {}, query);
                }
            }
            // nothing? ok, then just the full path
            for (auto p : {TStringBuf("/"), TStringBuf()}) {
                if (auto act = FindDst(suffs, p)) {
                    return GenResult(*act, p, query);
                }
            }
        }

        // still nothing? ok, wcard must handle this
        if (auto act = FindDst(wcardSuff, {})) {
            return GenResult(*act, wcardPath, query);
        }

        // only possible if wcard was not available
        return {};
    }

    // Returns the destination stored under `suffix` in `suffixes`, or nullptr when the
    // trie has no such key. The trie value is an index into Dsts_.
    const NImpl::TDst* TRedirects::FindDst(TRedirects::TTrie suffixes, TStringBuf suffix) const noexcept {
        ui64 index = -1;
        if (!suffixes.Find(suffix, &index)) {
            return nullptr;
        }
        return &Dsts_[index];
    }
}

// Prints a redirect as `{location="...";code=N}`.
template <>
void Out<NModRedirects::TRedirect>(IOutputStream& out, const NModRedirects::TRedirect& redir) {
    out << "{location=" << redir.Location.Quote();
    out << ";code=" << redir.Code;
    out << "}";
}

// Prints a forward as its location plus the name of the target module; the module's
// own configuration is abbreviated as `{...}`.
// NOTE(review): the emitted text opens three braces but closes only two — confirm whether
// anything depends on the exact format before balancing it.
template <>
void Out<NModRedirects::TForward>(IOutputStream& out, const NModRedirects::TForward& fwd) {
    out << "{location=" << fwd.Location.Quote();
    out << ";dst={" << fwd.Dst->GetHandle()->Name();
    out << "={...}}";
}

// Prints a lookup result: `(empty)` when there is no match, otherwise the held
// redirect or forward.
template <>
void Out<NModRedirects::TResult>(IOutputStream& out, const NModRedirects::TResult& res) {
    const auto printEmpty = [&out](std::monostate) {
        out << "(empty)";
    };
    const auto printValue = [&out](auto&& held) {
        out << held;
    };
    std::visit(TOverloaded{printEmpty, printValue}, res);
}

// Prints a location template by writing its tokens one after another, without separators.
template <>
void Out<NModRedirects::NImpl::TTemplate>(IOutputStream& out, const NModRedirects::NImpl::TTemplate& tmpl) {
    for (auto it = tmpl.begin(), last = tmpl.end(); it != last; ++it) {
        out << *it;
    }
}

// Prints a single template token by writing whichever alternative it currently holds.
template <>
void Out<NModRedirects::NImpl::TTemplate::value_type>(IOutputStream& out, const NModRedirects::NImpl::TTemplate::value_type& tok) {
    const auto printHeld = [&out](auto&& held) {
        out << held;
    };
    std::visit(printHeld, tok);
}
