#include "domain.h"

#include <balancer/kernel/cookie/utils/utils.h>

#include <library/cpp/containers/comptrie/comptrie.h>
#include <library/cpp/regex/pire/regexp.h>
#include <library/cpp/resource/resource.h>

#include <util/generic/singleton.h>
#include <util/generic/scope.h>
#include <util/memory/blob.h>
#include <util/stream/buffer.h>
#include <util/string/join.h>
#include <util/string/split.h>
#include <util/string/strip.h>

#include <contrib/libs/libidn2/include/idn2.h>

namespace NSrvKernel {
    using NRegExp::TFsm;
    using NRegExp::TMatcher;
    using NCookie::SkipDot;

    namespace {
        // Converts `domainName` to its ASCII form with libidn2 (non-transitional
        // processing, input normalized to NFC) and returns the result as a TString.
        // Any status other than IDN2_OK fails the Y_ENSURE check below.
        TString DecodeIdn2(TString domainName) {
            char* ascii = nullptr;
            // idn2_to_ascii_8z hands back a buffer through `ascii`;
            // release it on every exit path, including the failing one.
            Y_DEFER {
                if (ascii != nullptr) {
                    idn2_free(ascii);
                }
            };

            const int flags = IDN2_NONTRANSITIONAL | IDN2_NFC_INPUT;
            const auto status = idn2_to_ascii_8z(domainName.c_str(), &ascii, flags);

            Y_ENSURE(status == IDN2_OK,
                "Failed to convert " << domainName.Quote() << " to punicode: "
                << idn2_strerror_name(status) << " (" << idn2_strerror(status) << ")");

            return TString(ascii);
        }

        // Lookup table built from the "/public_suffix_list.dat" resource.
        // Keys are suffix strings; the stored bool is true for a public suffix and
        // false for an exception ("!"-prefixed) record. Wildcard ("*"-prefixed)
        // records are stored without the '*', i.e. with a leading '.'.
        class TPubSuffs {
            using TTrie = TCompactTrie<char, bool>;
        public:
            // Parses the resource line by line and fills the trie.
            // Fails via Y_ENSURE on a record with an unexpected '*' / '!' placement,
            // and via DecodeIdn2 when a record cannot be converted by libidn2.
            TPubSuffs() {
                TTrie::TBuilder builder;

                auto res = NResource::Find("/public_suffix_list.dat");

                for (auto lineIt : StringSplitter(res).Split('\n')) {
                    auto line = StripString(lineIt.Token());
                    // Skip blank lines and "//" comment lines.
                    if (!line || line.StartsWith("//")) {
                        continue;
                    }

                    // A leading '!' marks an exception record: stored as false (private).
                    bool value = !line.SkipPrefix("!");
                    // A leading '*' marks a wildcard record; what remains must start with '.'.
                    bool wc = line.SkipPrefix("*");
                    // After stripping the prefixes no further '*' or '!' may appear anywhere.
                    Y_ENSURE(
                        (!wc || line.StartsWith('.'))
                        && line.find('*') == TStringBuf::npos
                        && line.find('!') == TStringBuf::npos,
                        "Unexpected record format: " << TString(line).Quote());

                    // Register the record both as written and in its libidn2 ASCII form.
                    builder.Add(line, value);
                    builder.Add(DecodeIdn2(TString(line)), value);
                }

                // Serialize the builder into a buffer and open the trie over that buffer.
                TBufferOutput bout;
                builder.Save(bout);
                bout.Buffer().ShrinkToFit();
                PubSuffs_ = TTrie(TBlob::FromBuffer(bout.Buffer()));
            }

            // Walks `domainLowerCase` label by label from the left until the remainder
            // matches a trie record, and returns the suffix that is one label longer
            // than the matched public suffix (or the matched exception record itself).
            //
            // Returns:
            //   - an empty TStringBuf when the input is empty after dropping one leading
            //     dot, or when no label precedes the matched suffix;
            //   - Nothing() when an empty label is met (leading "..", trailing "..",
            //     or two adjacent dots reached while walking);
            //   - otherwise a view that starts at the chosen label and runs to the end
            //     of the original input, so a trailing dot of the input is kept.
            TMaybe<TStringBuf> ShortestPrivateSuffix(const TStringBuf domainLowerCase) const noexcept {
                auto d = domainLowerCase;
                // Tolerate a single leading dot.
                d.SkipPrefix(TStringBuf("."));

                if (!d) {
                    return TStringBuf();
                }

                // A second leading dot means an empty first label.
                if (d.StartsWith('.')) {
                    return Nothing();
                }

                // Tolerate a single trailing dot; a second one means an empty last label.
                if (d.ChopSuffix(TStringBuf(".")) && d.EndsWith('.')) {
                    return Nothing();
                }

                // `next` is the suffix examined in the current iteration; `last` trails it
                // by one label, so on a match `last` is the matched suffix plus one label.
                TStringBuf last;
                TStringBuf next = d;
                while (d) {
                    bool isPub = false;

                    // Exact record for the current suffix.
                    if (PubSuffs_.Find(d, &isPub)) {
                        if (!isPub) {
                            // Exception record: the current suffix itself is private.
                            last = d;
                        }
                        break;
                    }

                    // Move to the next '.', keeping it, and try the wildcard form (".rest").
                    // NOTE(review): when no '.' is left, find() yields npos; this relies on
                    // SubStr/Skip leaving `d` empty in that case so the loop ends — confirm.
                    d = d.SubStr(d.find('.'));
                    if (PubSuffs_.Find(d, &isPub)) {
                        // Wildcard hit: the suffix examined in this iteration is public,
                        // and `last` already holds it with one extra label.
                        break;
                    }
                    d.Skip(1);
                    // Two adjacent dots inside the name: empty label.
                    if (d.StartsWith('.')) {
                        return Nothing();
                    }
                    // Only advance the trailing pair while labels remain; when the input is
                    // exhausted without a match, `last` keeps the value from the previous step.
                    if (d) {
                        last = next;
                        next = d;
                    }
                }
                // Extend the result to the end of the original input (re-including a chopped trailing dot).
                return last ? TStringBuf(last.begin(), domainLowerCase.end()) : TStringBuf();
            }

        private:
            TTrie PubSuffs_; // true -> public, false/none -> private
        };

        // Compiles a case-insensitive matcher for "any of `domains`, or any name
        // ending in '.' + one of them", with an optional trailing dot.
        // Entries are regex fragments joined with '|'. When `exclude` is non-empty,
        // names matching its alternatives are subtracted with the and-not operator.
        TFsm SubdomainsFsm(TVector<TStringBuf> domains, TVector<TStringBuf> exclude={}) {
            auto alternatives = JoinSeq("|", domains);
            if (!exclude.empty()) {
                const auto excluded = JoinSeq("|", exclude);
                alternatives = "(" + alternatives + ")&~(" + excluded + ")";
            }

            // And-not support is needed for the "&~" operator used above and inside some entries.
            const auto options = TFsm::TOptions().SetCaseInsensitive(true).SetAndNotSupport(true);
            return TFsm("(.*[.])?(" + alternatives + ")[.]?", options);
        }

        // Bundle of the domain-related tables and matchers used by the functions below.
        // Members are constructed in declaration order when the struct is instantiated.
        struct TCookieDomains {
            // Public-suffix table, built together with the matchers.
            TPubSuffs PubSuffs;
            // Syntactic check of a domain name: zero or more dot-terminated labels of
            // letters, digits, '_' or '-', then a last label that starts with a letter
            // and continues with letters, digits or '-', then an optional trailing dot.
            TFsm DomainName = TFsm(
                "([a-z0-9_\\-]+[.])*[a-z][a-z0-9\\-]*[.]?",
                TFsm::TOptions().SetCaseInsensitive(true)
            );
            // Matches the listed domains and any name ending in '.' + one of them.
            TFsm NoCookieDomains = SubdomainsFsm({
                TStringBuf("yandex[.]net"),
                TStringBuf("yandex[.]st"),
                TStringBuf("yastat[.]net"),
                TStringBuf("yastatic[.]net"),
            });
            // The TLDs which might need more stringent GDPR policies.
            // First list: suffixes that match. Second list: suffixes subtracted from the match.
            TFsm GdprTld = SubdomainsFsm({
                // All gTlds except IDNs
                // (any label of 3+ characters that does not start with "xn--")
                TStringBuf("(([a-z0-9\\-]{3,})&~(xn--.*))"),
                TStringBuf("eu"), // EU
                // ccTlds used as gTlds
                TStringBuf("tm"),
                TStringBuf("cc"),
                TStringBuf("io"),
                TStringBuf("me"),
                TStringBuf("st"),
                // ccTlds of EU countries
                TStringBuf("at"), // Austria
                TStringBuf("be"), // Belgium
                TStringBuf("bg"), // Bulgaria
                TStringBuf("hr"), // Croatia
                TStringBuf("cy"), // Republic of Cyprus
                TStringBuf("cz"), // Czech Republic
                TStringBuf("dk"), // Denmark
                TStringBuf("ee"), // Estonia
                TStringBuf("fi"), // Finland
                TStringBuf("fr"), // France
                TStringBuf("de"), // Germany
                TStringBuf("gr"), // Greece
                TStringBuf("hu"), // Hungary
                TStringBuf("ie"), // Ireland
                TStringBuf("it"), // Italy
                TStringBuf("lv"), // Latvia
                TStringBuf("lt"), // Lithuania
                TStringBuf("lu"), // Luxembourg
                TStringBuf("mt"), // Malta
                TStringBuf("nl"), // Netherlands
                TStringBuf("pl"), // Poland
                TStringBuf("pt"), // Portugal
                TStringBuf("ro"), // Romania
                TStringBuf("sk"), // Slovakia
                TStringBuf("si"), // Slovenia
                TStringBuf("es"), // Spain
                TStringBuf("se"), // Sweden
                TStringBuf("uk"), // United Kingdom
                // ccTlds of countries joining the EU
                TStringBuf("mk"), // Macedonia
                TStringBuf("rs"), // Serbia
                // EU IDNs
                TStringBuf("xn--e1a4c"),   // .ею
                TStringBuf("xn--qxa6a"),   // .ευ
                // ccIDNs of EU countries
                TStringBuf("xn--90ae"),    // .бг
                TStringBuf("xn--qxam"),    // .ελ
                // IDNs of countries joining the EU
                TStringBuf("xn--d1alf"),   // .мкд
                TStringBuf("xn--90a3ac"),  // .срб
            }, {
                // Excluding the ccIDNs of the non-EU ex-USSR countries
                TStringBuf("xn--p1ai"),    // .рф
                TStringBuf("xn--j1amh"),   // .укр
                TStringBuf("xn--90ais"),   // .бел
                TStringBuf("xn--80ao21a"), // .қаз
                TStringBuf("xn--node"),    // .გე
                TStringBuf("xn--y9a3aq"),  // .հայ
                TStringBuf("xn--l1acc"),   // .мон
                // A 3+ character gTld that is explicitly excluded from the generic rule above.
                TStringBuf("moscow"),
            });
        };
    }

    // Returns the shortest private suffix of `validDomainNameLowerCase` as computed by
    // TPubSuffs::ShortestPrivateSuffix: Nothing() for a name with an empty label, an empty
    // view when no private label precedes the public suffix, otherwise a view into the input.
    TMaybe<TStringBuf> ShortestPrivateSuffix(TStringBuf validDomainNameLowerCase) noexcept {
        // Use the table owned by TCookieDomains. It is the instance InitDomains() builds up
        // front, whereas Default<TPubSuffs>() is a separate singleton that would be built
        // lazily on the first lookup, inside this noexcept function, duplicating the trie.
        const auto& domains = Default<TCookieDomains>();
        return domains.PubSuffs.ShortestPrivateSuffix(validDomainNameLowerCase);
    }

    // Returns the longest public suffix of `validDomainLowerCase`.
    //   - Nothing() when the private-suffix lookup rejects the name;
    //   - SkipDot(input) when the lookup finds no private label, so the name is treated
    //     as public as a whole;
    //   - otherwise the part of the shortest private suffix after its first '.'.
    TMaybe<TStringBuf> LongestPublicSuffix(TStringBuf validDomainLowerCase) noexcept {
        // Use the table owned by TCookieDomains. It is the instance InitDomains() builds up
        // front, whereas Default<TPubSuffs>() is a separate singleton that would be built
        // lazily on the first lookup, inside this noexcept function, duplicating the trie.
        const auto privateSuff = Default<TCookieDomains>().PubSuffs.ShortestPrivateSuffix(validDomainLowerCase);
        if (!privateSuff) {
            return Nothing();
        }
        if (privateSuff->empty()) {
            return SkipDot(validDomainLowerCase);
        }
        // Dropping the leftmost label of the private suffix leaves the public suffix.
        return privateSuff->After('.');
    }

    // True when the whole of `domainStr` is accepted by the TCookieDomains::DomainName matcher.
    bool ValidDomainName(TStringBuf domainStr) noexcept {
        const auto& domains = Default<TCookieDomains>();
        TMatcher matcher(domains.DomainName);
        matcher.Match(domainStr);
        return matcher.Final();
    }

    // True when the whole of `domainStr` is accepted by the TCookieDomains::NoCookieDomains matcher.
    bool NoCookieDomain(TStringBuf domainStr) noexcept {
        const auto& domains = Default<TCookieDomains>();
        TMatcher matcher(domains.NoCookieDomains);
        matcher.Match(domainStr);
        return matcher.Final();
    }

    // True when the whole of `domainStr` is accepted by the TCookieDomains::GdprTld matcher.
    bool GdprDomain(TStringBuf domainStr) noexcept {
        const auto& domains = Default<TCookieDomains>();
        TMatcher matcher(domains.GdprTld);
        matcher.Match(domainStr);
        return matcher.Final();
    }

    // Replaces the longest public suffix of `domainLowerCase` with `wcard`
    // (passed through SkipDot first).
    // Returns Nothing() when the name fails ValidDomainName or has no determinable
    // public suffix. When nothing remains in front of the public suffix the result
    // is just the replacement; otherwise it is "<remaining prefix>.<replacement>".
    TMaybe<std::string> ReplacePublicSuffix(TStringBuf domainLowerCase, TStringBuf wcard) noexcept {
        const TStringBuf replacement = SkipDot(wcard);

        if (!ValidDomainName(SkipDot(domainLowerCase))) {
            return Nothing();
        }

        const auto suffix = LongestPublicSuffix(domainLowerCase);
        if (!suffix) {
            return Nothing();
        }

        // Cut the public suffix and the dot that separated it from the rest.
        TStringBuf prefix = domainLowerCase;
        prefix.ChopSuffix(*suffix);
        prefix.ChopSuffix(TStringBuf("."));

        if (prefix.empty()) {
            return std::string(replacement);
        }

        std::string result(prefix);
        result.append(TStringBuf("."));
        result.append(replacement);
        return result;
    }

    // Forces construction of the TCookieDomains singleton (public-suffix table and
    // the compiled matchers) so that later lookups find it already built.
    void InitDomains() {
        static_cast<void>(Default<TCookieDomains>());
    }
}
