#include "storage.h"

#include <security/ant-secret/internal/regex_storage/storage.fbs.h>

#include <security/libs/cpp/log/log.h>
#include <library/cpp/resource/resource.h>
#include <util/generic/string.h>
#include <util/string/join.h>
#include <util/system/cpu_id.h>

namespace NSSInternal {
    namespace {
        // Resource name of the regex bundle built for AVX2-capable CPUs; FindResourceByCpu
        // tries it first when the current platform is AVX2 or AVX512.
        constexpr TStringBuf kRegexResourceAvx2("/bundle-regex-avx2");
        // Resource name of the portable regex bundle, used as the fallback on any CPU.
        constexpr TStringBuf kRegexResourceAny("/bundle-regex-any");

        // Builds the cache key for a regex set: every pattern, then every flag value,
        // then every id, each sequence rendered by JoinSeq and concatenated with no
        // separator anywhere.
        // NOTE(review): this encoding is ambiguous, e.g. patterns {"ab"} + flags {12}
        // and patterns {"ab1"} + flags {2} yield the same key, so two distinct regex sets
        // could share one cached DB. The key format is also what DumpDb() persists and
        // initDb() looks up, so a fix (length prefixes / separators) must ship together
        // with regenerating the preserialized bundles.
        inline const TString
        makeKey(const TVector<const char*>& patterns,
                const TVector<unsigned int>& flags,
                const TVector<unsigned int>& ids) {
            const TString noSeparator;
            TString key = JoinSeq(noSeparator, patterns);
            key += JoinSeq(noSeparator, flags);
            key += JoinSeq(noSeparator, ids);
            return key;
        }

        // Best-effort lookup of an embedded resource by name.
        // Returns the resource contents, or an empty string if NResource::Find throws
        // (presumably because the resource is not linked into this binary). Every
        // exception is deliberately swallowed: callers only test the result for emptiness.
        TString FindResource(const TStringBuf name) {
            try {
                return NResource::Find(name);
            } catch (...) {
                return TString();
            }
        }

        // Picks the regex bundle best suited to the running CPU.
        // On AVX2/AVX512 platforms the AVX2 bundle is returned if it is present
        // (non-empty). Otherwise the portable bundle is returned, which is itself
        // empty when that resource is absent too.
        TString FindResourceByCpu() {
            const auto cpu = NHyperscan::CurrentPlatform();
            const bool avx2Capable = cpu == NHyperscan::EPlatform::AVX2
                || cpu == NHyperscan::EPlatform::AVX512;

            if (avx2Capable) {
                if (TString avx2Bundle = FindResource(kRegexResourceAvx2)) {
                    return avx2Bundle;
                }
            }

            return FindResource(kRegexResourceAny);
        }

    }

    // Process-wide storage instance. It is a function-local static, so it is
    // constructed (and initDb() runs) on the first call; C++11 makes that
    // initialization itself thread-safe.
    TRegexStorage& TRegexStorage::Instance() {
        static TRegexStorage storage;
        return storage;
    }

    // Returns a holder for the regex DB built from (patterns, flags, ids).
    //
    // The DB is looked up in `cache` by makeKey(). On a miss, a warning naming `name` is
    // logged, and the DB is built in place from the inputs and the storage's `platform`
    // and then cached. Lookup and insertion both happen while `lock` is held.
    //
    // Throws TSystemError if the cache insertion reports failure (not expected: the key
    // was just confirmed absent under the same lock).
    TRegexHolder
    TRegexStorage::GetOrCompile(const TString& name,
                                const TVector<const char*>& patterns,
                                const TVector<unsigned int>& flags,
                                const TVector<unsigned int>& ids) {
        const TString cacheKey = makeKey(patterns, flags, ids);
        with_lock (lock) {
            if (const auto found = cache.find(cacheKey); found != cache.end()) {
                return TRegexHolder(found->second);
            }

            NSecurityHelpers::LogWarn("Serialized regex DB not found, recompiling", "name", name);
            const auto inserted = cache.emplace(
                std::piecewise_construct,
                std::forward_as_tuple(cacheKey),
                std::forward_as_tuple(patterns, flags, ids, platform));

            if (Y_UNLIKELY(!inserted.second)) {
                ythrow TSystemError() << "failed to cache regex";
            }
            return TRegexHolder(inserted.first->second);
        }
    }

    // Records the current CPU platform (used for later compilations) and tries to
    // preload serialized regex DBs from the embedded bundle. A failed preload is not
    // fatal: it only logs a warning, and GetOrCompile compiles on demand instead.
    TRegexStorage::TRegexStorage()
        : platform(NHyperscan::CurrentPlatform())
    {
        const bool preloaded = initDb();
        if (Y_UNLIKELY(!preloaded)) {
            NSecurityHelpers::LogWarn(
                "Serialized versions of regexes can't be initialized, do you run preserialize-regex?");
        }
    }

    bool NSSInternal::TRegexStorage::initDb() {
        TString data = FindResourceByCpu();
        if (!data) {
            return false;
        }

        auto rs = NRegexStorage::GetRegexes(reinterpret_cast<const ui8*>(data.data()));
        for (size_t i = 0; i < rs->Regs()->size(); ++i) {
            auto h = rs->Regs()->Get(i);
            cache.emplace(
                TStringBuf(reinterpret_cast<const char*>(h->Key()->data()), h->Key()->size()),
                TStringBuf(reinterpret_cast<const char*>(h->Db()->data()), h->Db()->size()));
        }

        return true;
    }

    // Serializes every cached regex DB (preloaded or compiled) into a flatbuffer rooted
    // at NRegexStorage::Regexes. Each cache item becomes one Regex entry holding its
    // makeKey() key and its NHyperscan::Serialize() output. This is the bundle format
    // initDb() reads back (presumably embedded as a resource by preserialize-regex;
    // confirm).
    //
    // `cache` is iterated while holding `lock`, the same lock GetOrCompile() holds while
    // inserting, so a dump that runs concurrently with compilation cannot race with a
    // container modification.
    const TString TRegexStorage::DumpDb() {
        flatbuffers::FlatBufferBuilder builder(1024);
        TVector<flatbuffers::Offset<NRegexStorage::Regex>> regexes;
        with_lock (lock) {
            regexes.reserve(cache.size());
            for (auto&& [key, item] : cache) {
                auto serialized = NHyperscan::Serialize(item.db);
                auto fbKey = builder.CreateVector(reinterpret_cast<const i8*>(key.data()), key.size());
                auto fbDb = builder.CreateVector(reinterpret_cast<const i8*>(serialized.data()), serialized.size());
                regexes.push_back(NRegexStorage::CreateRegex(builder, fbKey, fbDb));
            }
        }
        auto h = NRegexStorage::CreateRegexes(builder, builder.CreateVector(regexes));
        builder.Finish(h);
        return TString(reinterpret_cast<char*>(builder.GetBufferPointer()), builder.GetSize());
    }

}
