#include "pcre_wrapper.h"

#include <util/memory/tempbuf.h>

#include <mail/so/libs/talkative_config/config.h>
#include <contrib/libs/pcre/pcre.h>
#include <util/generic/maybe.h>
#include <util/stream/printf.h>
#include <util/string/printf.h>
#include <library/cpp/charset/recyr.hh>

// Deleter specialization for compiled PCRE patterns: they are released with
// pcre_free(). Presumably picked up by the THolder<pcre> instances used below.
template <>
void TDelete::Destroy<pcre>(pcre* compiledPattern) noexcept {
    pcre_free(compiledPattern);
}
// Deleter specialization for PCRE study data: it is released with
// pcre_free_study(). Presumably picked up by the THolder<pcre_extra>
// instances used below.
template <>
void TDelete::Destroy<pcre_extra>(pcre_extra* studyData) noexcept {
    pcre_free_study(studyData);
}

namespace NRegexp {

    /// Builds settings from a config section.
    ///
    /// Recognized keys (both optional, both read as size_t):
    ///   - "match-limit"            -> MatchLimit
    ///   - "match-limit-recursion"  -> MatchLimitRecursion
    /// A key that is absent leaves the corresponding field at the value a
    /// default-constructed TSettings has.
    TSettings TSettings::FromConfig(const NConfig::TDict& section) {
        TSettings result;

        const NConfig::TConfig* const matchLimit = MapFindPtr(section, "match-limit");
        if (matchLimit != nullptr) {
            result.MatchLimit = NTalkativeConfig::As<size_t>(*matchLimit);
        }

        const NConfig::TConfig* const recursionLimit = MapFindPtr(section, "match-limit-recursion");
        if (recursionLimit != nullptr) {
            result.MatchLimitRecursion = NTalkativeConfig::As<size_t>(*recursionLimit);
        }

        return result;
    }

    class TPcreTables {
    public:
        static const unsigned char * maketables(ECharset charset)
        {
            const auto* cp = CodePageByCharset(charset);
            unsigned char *yield, *p;
            int i;
#define cbit_space     0      /* [:space:] or \s */
#define cbit_xdigit   32      /* [:xdigit:] */
#define cbit_digit    64      /* [:digit:] or \d */
#define cbit_upper    96      /* [:upper:] */
#define cbit_lower   128      /* [:lower:] */
#define cbit_word    160      /* [:word:] or \w */
#define cbit_graph   192      /* [:graph:] */
#define cbit_print   224      /* [:print:] */
#define cbit_punct   256      /* [:punct:] */
#define cbit_cntrl   288      /* [:cntrl:] */
#define cbit_length  320      /* Length of the cbits table */

#define ctype_space   0x01
#define ctype_letter  0x02
#define ctype_digit   0x04
#define ctype_xdigit  0x08
#define ctype_word    0x10   /* alphanumeric or '_' */
#define ctype_meta    0x80   /* regexp meta char or zero (end pattern) */
#define lcc_offset      0
#define fcc_offset    256
#define cbits_offset  512
#define ctypes_offset (cbits_offset + cbit_length)
#define tables_length (ctypes_offset + 256)

#ifndef DFTABLES
            yield = (unsigned char*)pcre_malloc(tables_length);
#else
            yield = (unsigned char*)malloc(tables_length);
#endif

            if (yield == NULL) return NULL;
            p = yield;

/* First comes the lower casing table */

            for (i = 0; i < 256; i++) *p++ = cp->ToLower(i);

/* Next the case-flipping table */

            for (i = 0; i < 256; i++) *p++ = cp->IsLower(i)? cp->ToUpper(i) : cp->ToLower(i);

/* Then the character class tables. Don't try to be clever and save effort on
exclusive ones - in some locales things may be different.

Note that the table for "space" includes everything "cp->IsSpace" gives, including
VT in the default locale. This makes it work for the POSIX class [:space:].
From release 8.34 is is also correct for Perl space, because Perl added VT at
release 5.18.

Note also that it is possible for a character to be alnum or alpha without
being lower or upper, such as "male and female ordinals" (\xAA and \xBA) in the
fr_FR locale (at least under Debian Linux's locales as of 12/2005). So we must
test for alnum specially. */

            memset(p, 0, cbit_length);
            for (i = 0; i < 256; i++)
            {
                if (cp->IsDigit(i)) p[cbit_digit  + i/8] |= 1 << (i&7);
                if (cp->IsUpper(i)) p[cbit_upper  + i/8] |= 1 << (i&7);
                if (cp->IsLower(i)) p[cbit_lower  + i/8] |= 1 << (i&7);
                if (cp->IsAlnum(i)) p[cbit_word   + i/8] |= 1 << (i&7);
                if (i == '_')   p[cbit_word   + i/8] |= 1 << (i&7);
                if (cp->IsSpace(i)) p[cbit_space  + i/8] |= 1 << (i&7);
                if (cp->IsXdigit(i))p[cbit_xdigit + i/8] |= 1 << (i&7);
                if (cp->IsGraph(i)) p[cbit_graph  + i/8] |= 1 << (i&7);
                if (cp->IsPrint(i)) p[cbit_print  + i/8] |= 1 << (i&7);
                if (cp->IsPunct(i)) p[cbit_punct  + i/8] |= 1 << (i&7);
                if (cp->IsCntrl(i)) p[cbit_cntrl  + i/8] |= 1 << (i&7);
            }
            p += cbit_length;

/* Finally, the character type table. In this, we used to exclude VT from the
white space chars, because Perl didn't recognize it as such for \s and for
comments within regexes. However, Perl changed at release 5.18, so PCRE changed
at release 8.34. */

            for (i = 0; i < 256; i++)
            {
                int x = 0;
                if (cp->IsSpace(i)) x += ctype_space;
                if (cp->IsAlpha(i)) x += ctype_letter;
                if (cp->IsDigit(i)) x += ctype_digit;
                if (cp->IsXdigit(i)) x += ctype_xdigit;
                if (cp->IsAlnum(i) || i == '_') x += ctype_word;

                /* Note: strchr includes the terminating zero in the characters it considers.
                In this instance, that is ok because we want binary zero to be flagged as a
                meta-character, which in this sense is any character that terminates a run
                of data characters. */

                if (strchr("\\*+?{^.$|()[", i) != 0) x += ctype_meta;
                *p++ = x;
            }

            return yield;
        }

        const ui8 *Get() const {
            return tables;
        }

        explicit TPcreTables() : tables(pcre_maketables()) {}

        explicit TPcreTables(ECharset charset) : tables(maketables(charset)) {}

        ~TPcreTables() {
            pcre_free((void *) tables);
        }

    private:
        const ui8 *tables{};
    };

    /// Thin wrapper around a PCRE error code; exists so the code can be
    /// streamed by name through the dedicated operator<< overload.
    struct TPcreErrorCodeToString {
        int code{};

        explicit TPcreErrorCodeToString(int code)
            : code{code}
        {
        }
    };

    /// Implementation behind TPcre: compiles and studies the pattern once in
    /// the constructor and runs pcre_exec() against input text afterwards.
    class TPcre::TImpl {
    public:

        /// Checks whether the pattern matches `text`, without capturing.
        ///
        /// @return NotMatch for an uninitialized `text`, for no match, or when
        ///         the match limit is hit; Match otherwise.
        /// @throws TError for any other negative pcre_exec() result.
        [[nodiscard]] EMatchResult Match(const TStringBuf &text) const {
            if (!text.IsInited()) {
                return EMatchResult::NotMatch;
            }

            const int count = pcre_exec(re_comp.Get(), extra.Get(), text.data(), static_cast<int>(text.length()), 0,
                                        0, nullptr, 0);

            if (count < 0) {
                return OnExecFailure(count, text);
            }

            return EMatchResult::Match;
        }

        /// Matches the pattern against `text` and stores the captured substrings.
        ///
        /// On a match `matches` is cleared and filled with the whole match
        /// followed by the capture groups pcre_exec() reported. Room is
        /// provided for `maxN` groups; if the pattern captured more, only the
        /// first maxN + 1 entries are returned. A group that did not take part
        /// in the match is stored as an empty view. `matches` is left
        /// untouched when there is no match.
        ///
        /// @return NotMatch for an uninitialized `text`, for no match, or when
        ///         the match limit is hit; Match otherwise.
        /// @throws TError for any other negative pcre_exec() result.
        EMatchResult Match(const TStringBuf &text, size_t maxN, TMatches &matches) const {
            if (!text.IsInited()) {
                return EMatchResult::NotMatch;
            }

            // pcre_exec() needs three ints per reported substring: two offsets
            // plus one slot of internal workspace.
            TVector<int> tempBuf((maxN+1) * 3, -1);

            int count = pcre_exec(
                    re_comp.Get(),
                    extra.Get(),
                    text.data(),
                    static_cast<int>(text.length()),
                    0,
                    0,
                    &tempBuf[0],
                    static_cast<int>(tempBuf.size())
            );

            if (count < 0) {
                return OnExecFailure(count, text);
            }

            if(count == 0) {
                // The offsets vector was too small for all groups: report as
                // many pairs as it can hold.
                count = static_cast<int>(tempBuf.size() / 3);
            }

            matches.resize(0);
            matches.reserve(count);

            for (int i = 0; i < count; ++i) {
                const int begin = tempBuf[2 * i];
                const int end = tempBuf[2 * i + 1];

                if (begin < 0 || end < begin) {
                    // Unset group (offsets are -1) or an inverted range: emit
                    // an empty view instead of computing a pointer before the
                    // text or a wrapped-around length.
                    matches.emplace_back(text.data(), static_cast<size_t>(0));
                } else {
                    matches.emplace_back(text.data() + begin, static_cast<size_t>(end - begin));
                }
            }

            return EMatchResult::Match;
        }

        /// Compiles and studies `re`.
        ///
        /// UTF-8 patterns are compiled with PCRE_UTF8 and the default PCRE
        /// tables; any other charset uses tables built from that charset.
        /// The match limits from `settings` are installed into the study data.
        ///
        /// @throws TError if compilation or study fails.
        TImpl(const TStringBuf &re, const TSettings &settings) : settings(settings) {
            const char *pcreError = nullptr;
            int errorOffset = -1;

            // pcre_compile() reads the pattern as a zero-terminated C string,
            // which a TStringBuf does not guarantee, so compile from an owned copy.
            const TString pattern(re);

            THolder<pcre> re_comp_tmp;
            if(settings.Charset == CODES_UTF8) {
                static const TPcreTables table;
                re_comp_tmp.Reset(pcre_compile(pattern.c_str(), PCRE_UTF8, &pcreError, &errorOffset, table.Get()));
            } else {
                // NOTE(review): Singleton<T>(args) presumably keeps a single
                // instance per type, so the charset of the first call would be
                // reused for every later non-UTF-8 charset. Confirm that only
                // one such charset is used per process, or cache per charset.
                re_comp_tmp.Reset(pcre_compile(pattern.c_str(), 0, &pcreError, &errorOffset, Singleton<TPcreTables>(settings.Charset)->Get()));
            }


            if(pcreError)
                ythrow TError() << pcreError << " in position " << errorOffset << "; re:" << re;

            if (!re_comp_tmp)
                ythrow TError() << "re_comp is null";

            pcreError = nullptr;
            THolder<pcre_extra> extra_tmp(pcre_study(re_comp_tmp.Get(), PCRE_STUDY_EXTRA_NEEDED | PCRE_STUDY_JIT_COMPILE, &pcreError));
            if (pcreError != nullptr)
                ythrow TError() << pcreError;

            if (!extra_tmp)
                ythrow TError() << "extra is null";

            extra_tmp->match_limit = settings.MatchLimit;
            extra_tmp->match_limit_recursion = settings.MatchLimitRecursion;
            extra_tmp->flags |= PCRE_EXTRA_MATCH_LIMIT | PCRE_EXTRA_MATCH_LIMIT_RECURSION;

            re_comp = std::move(re_comp_tmp);
            extra = std::move(extra_tmp);
            originalRe = pattern;
        }
    private:
        /// Translates a negative pcre_exec() result.
        ///
        /// "No match" and "match limit reached" are reported as NotMatch;
        /// every other code is thrown as TError with the error name, the
        /// numeric code, the pattern and the text.
        EMatchResult OnExecFailure(int code, const TStringBuf &text) const {
            switch (code) {
                case PCRE_ERROR_NOMATCH:
                case PCRE_ERROR_MATCHLIMIT:
                    return EMatchResult::NotMatch;
                default:
                    ythrow TError() << TPcreErrorCodeToString(code) << " (code " << code << ") re: " << originalRe << " text: " << text;
            }
        }

        THolder<pcre> re_comp;
        THolder<pcre_extra> extra;
        TString originalRe;
        const TSettings settings;
    };

    /// Match test without captures; delegates to the implementation object.
    EMatchResult TPcre::Match(const TStringBuf &text) const {
        const TImpl &delegate = *impl;
        return delegate.Match(text);
    }

    /// Match with captures written to `matches`; delegates to the
    /// implementation object.
    EMatchResult TPcre::Match(const TStringBuf &text, size_t maxN, TMatches &matches) const {
        const TImpl &delegate = *impl;
        return delegate.Match(text, maxN, matches);
    }

    /// Creates the implementation object for `re`; errors raised by
    /// TImpl's constructor propagate to the caller.
    TPcre::TPcre(const TStringBuf &re, const TSettings &settings)
        : impl(new TImpl(re, settings))
    {
    }

    // Defaulted out of line in this file, after TImpl has been fully defined
    // above (presumably so that the `impl` member can destroy a complete type).
    TPcre::~TPcre() = default;

    /// Writes the symbolic name of a PCRE error code (the PCRE_ERROR_ macro
    /// name without its prefix) to `stream`; codes that are not listed are
    /// written as "Unknown code".
    IOutputStream& operator << (IOutputStream & stream, const TPcreErrorCodeToString & pcreError) {
        struct TCodeName {
            int Code;
            const char* Name;
        };

        static constexpr TCodeName KnownCodes[] = {
            {PCRE_ERROR_NOMATCH, "NOMATCH"},
            {PCRE_ERROR_NULL, "NULL"},
            {PCRE_ERROR_BADOPTION, "BADOPTION"},
            {PCRE_ERROR_BADMAGIC, "BADMAGIC"},
            {PCRE_ERROR_UNKNOWN_OPCODE, "UNKNOWN_OPCODE"},
            {PCRE_ERROR_NOMEMORY, "NOMEMORY"},
            {PCRE_ERROR_NOSUBSTRING, "NOSUBSTRING"},
            {PCRE_ERROR_MATCHLIMIT, "MATCHLIMIT"},
            {PCRE_ERROR_CALLOUT, "CALLOUT"},
            {PCRE_ERROR_BADUTF8, "BADUTF8"},
            {PCRE_ERROR_BADUTF8_OFFSET, "BADUTF8_OFFSET"},
            {PCRE_ERROR_PARTIAL, "PARTIAL"},
            {PCRE_ERROR_BADPARTIAL, "BADPARTIAL"},
            {PCRE_ERROR_INTERNAL, "INTERNAL"},
            {PCRE_ERROR_BADCOUNT, "BADCOUNT"},
            {PCRE_ERROR_DFA_UITEM, "DFA_UITEM"},
            {PCRE_ERROR_DFA_UCOND, "DFA_UCOND"},
            {PCRE_ERROR_DFA_UMLIMIT, "DFA_UMLIMIT"},
            {PCRE_ERROR_DFA_WSSIZE, "DFA_WSSIZE"},
            {PCRE_ERROR_DFA_RECURSE, "DFA_RECURSE"},
            {PCRE_ERROR_RECURSIONLIMIT, "RECURSIONLIMIT"},
            {PCRE_ERROR_NULLWSLIMIT, "NULLWSLIMIT"},
            {PCRE_ERROR_BADNEWLINE, "BADNEWLINE"},
            {PCRE_ERROR_BADOFFSET, "BADOFFSET"},
            {PCRE_ERROR_SHORTUTF8, "SHORTUTF8"},
            {PCRE_ERROR_RECURSELOOP, "RECURSELOOP"},
            {PCRE_ERROR_JIT_STACKLIMIT, "JIT_STACKLIMIT"},
            {PCRE_ERROR_BADMODE, "BADMODE"},
            {PCRE_ERROR_BADENDIANNESS, "BADENDIANNESS"},
            {PCRE_ERROR_DFA_BADRESTART, "DFA_BADRESTART"},
            {PCRE_ERROR_JIT_BADOPTION, "JIT_BADOPTION"},
            {PCRE_ERROR_BADLENGTH, "BADLENGTH"},
            {PCRE_ERROR_UNSET, "UNSET"},
        };

        for (const TCodeName& entry : KnownCodes) {
            if (entry.Code == pcreError.code) {
                return stream << entry.Name;
            }
        }

        return stream << "Unknown code";
    }
}   //  namespace NRegexp
