#include "hyperscan.h"

#include <contrib/libs/hyperscan/src/util/cpuid_flags.h>

#include <string>

namespace NHyperscan {

namespace {

using TCompileError = THolder<hs_compile_error_t, TDeleter<decltype(&hs_free_compile_error), &hs_free_compile_error>>;

inline hs_platform_info_t platformToTarget(EPlatform platform) {
    unsigned long long cpuFeatures = 0;
    switch (platform) {
        case EPlatform::AVX2:
            cpuFeatures = HS_CPU_FEATURES_AVX2;
            break;
        case EPlatform::AVX512:
            cpuFeatures = HS_CPU_FEATURES_AVX2 | HS_CPU_FEATURES_AVX512;
            break;
        case EPlatform::ANY:
        default:
            // no needed extra features
            break;
    }

    return hs_platform_info_t{
        .tune = cpuid_tune(),
        .cpu_features = cpuFeatures,
    };
}

}  // namespace anonymous


// Returns true if scanning `text` against `db` produces at least one match.
// The whole ScanPtr pass runs; the callback only latches a flag and ignores
// the reported pattern id and match offsets.
bool MatchesPtr(
    TDatabasePtr db,
    const TScratch& scratch,
    const TStringBuf& text
) {
    bool matched = false;
    // Kept as a non-const lvalue so it binds to whatever callback parameter ScanPtr declares.
    auto onMatch = [&matched](unsigned int, unsigned long long, unsigned long long) {
        matched = true;
    };
    ScanPtr(db, scratch, text, onMatch);
    return matched;
}

// Compiles a single regular expression into a block-mode (HS_MODE_BLOCK) database.
//
// regex    - pattern text; it need not be NUL-terminated (a terminated copy is
//            made before it is handed to hyperscan).
// flags    - HS_FLAG_* bits forwarded to hs_compile.
// platform - instruction-set target, see platformToTarget.
// Returns the owning database handle.
// Throws TCompileException when hyperscan rejects the pattern.
TDatabase CompileFor(
    const TStringBuf& regex,
    unsigned int flags,
    EPlatform platform
) {
    // hs_compile reads a NUL-terminated C string, but a TStringBuf is only a
    // (pointer, length) view and may point into the middle of a larger buffer.
    // Passing regex.begin() directly would make hyperscan read past the view.
    const std::string pattern(regex.data(), regex.size());

    hs_database_t* rawDb = nullptr;
    hs_compile_error_t* rawCompileErr = nullptr;
    hs_platform_info_t target = platformToTarget(platform);
    hs_error_t status = hs_compile(
        pattern.c_str(),
        flags,
        HS_MODE_BLOCK,
        &target,
        &rawDb,
        &rawCompileErr);
    // Take ownership of both outputs right away so they are released on every path.
    TDatabase db(rawDb);
    TCompileError compileError(rawCompileErr);
    if (status != HS_SUCCESS) {
        // Guard against a missing error record rather than dereferencing null.
        const char* message = compileError.Get() ? compileError->message : "unknown error";
        ythrow TCompileException()
            << "Failed to compile regex: " << regex << ". "
            << "Error message (hyperscan): " << message;
    }
    return db;
}

// Compiles several regular expressions into one block-mode (HS_MODE_BLOCK) database.
//
// regexs   - NUL-terminated pattern strings.
// flags    - per-pattern HS_FLAG_* bits; must be the same size as `regexs`.
// ids      - per-pattern ids reported to match callbacks; must be the same size as `regexs`.
// platform - instruction-set target, see platformToTarget.
// Returns the owning database handle.
// Throws yexception when the vector sizes disagree or the pattern count does not
// fit hyperscan's unsigned int count; throws TCompileException when hyperscan
// rejects the pattern set.
TDatabase CompileMultiFor(
    const TVector<const char*>& regexs,
    const TVector<unsigned int>& flags,
    const TVector<unsigned int>& ids,
    EPlatform platform) {
    // Compare the real sizes before narrowing to unsigned int; comparing against a
    // truncated count would report a bogus mismatch for equally sized vectors.
    if (flags.size() != regexs.size()) {
        ythrow yexception()
            << "Mismatch of sizes vectors passed to CompileMulti. "
            << "size(regexs) = " << regexs.size() << ". "
            << "size(flags) = " << flags.size() << ".";
    }
    if (ids.size() != regexs.size()) {
        ythrow yexception()
            << "Mismatch of sizes vectors passed to CompileMulti. "
            << "size(regexs) = " << regexs.size() << ". "
            << "size(ids) = " << ids.size() << ".";
    }
    const auto count = static_cast<unsigned int>(regexs.size());
    if (count != regexs.size()) {
        ythrow yexception()
            << "Too many regexs passed to CompileMulti. "
            << "size(regexs) = " << regexs.size() << ".";
    }

    hs_database_t* rawDb = nullptr;
    hs_compile_error_t* rawCompileErr = nullptr;
    hs_platform_info_t target = platformToTarget(platform);
    hs_error_t status = hs_compile_multi(
        regexs.data(),
        flags.data(),
        ids.data(),
        count,
        HS_MODE_BLOCK,
        &target,
        &rawDb,
        &rawCompileErr);
    // Take ownership of both outputs right away so they are released on every path.
    TDatabase db(rawDb);
    TCompileError compileError(rawCompileErr);
    if (status != HS_SUCCESS) {
        const char* message = compileError.Get() ? compileError->message : "unknown error";
        // hyperscan reports the index of the offending pattern, or a negative value
        // when the failure is not tied to a single pattern. Bounds-check before indexing.
        const int index = compileError.Get() ? compileError->expression : -1;
        if (index >= 0 && static_cast<unsigned int>(index) < count) {
            const char* regex = regexs[index];
            ythrow TCompileException()
                << "Failed to compile regex: " << (regex ? regex : "(null)") << ". "
                << "Error message (hyperscan): " << message;
        }
        ythrow TCompileException()
            << "Failed to compile multiple regexs. "
            << "Error message (hyperscan): " << message;
    }
    return db;
}

}  // namespace NHyperscan
