#include <passport/infra/libs/cpp/utils/file.h>
#include <passport/infra/libs/cpp/utils/crypto/hash.h>
#include <passport/infra/libs/cpp/utils/crypto/rsa.h>

#include <contrib/libs/libsodium/include/sodium/core.h>
#include <contrib/libs/libsodium/include/sodium/crypto_sign_ed25519.h>
#include <contrib/libs/openssl/include/openssl/bio.h>
#include <contrib/libs/openssl/include/openssl/err.h>
#include <contrib/libs/openssl/include/openssl/pem.h>
#include <contrib/libs/openssl/include/openssl/x509.h>

#include <library/cpp/openssl/holders/evp.h>
#include <library/cpp/openssl/init/init.h>

#include <util/datetime/base.h>
#include <util/stream/format.h>

#include <atomic>
#include <chrono>
#include <functional>
#include <future>
#include <limits>
#include <memory>
#include <thread>
#include <vector>

/// Scope guard for OpenSSL's per-thread error queue.
///
/// The queue is emptied both when the guard is created and when it is
/// destroyed. Errors read inside the guarded scope therefore come only from
/// that scope, and nothing it left behind is seen by later code.
struct TErrCleaner {
    TErrCleaner() {
        Clear();
    }

    ~TErrCleaner() {
        Clear();
    }

private:
    static void Clear() {
        ERR_clear_error();
    }
};

/// Pops the earliest error code from the calling thread's OpenSSL error queue
/// and returns its human-readable reason string.
///
/// ERR_reason_error_string() returns nullptr when no string is registered for
/// the code. For example, the queue may be empty, in which case the code is 0.
/// In that case a fixed placeholder is returned instead. Callers therefore
/// never build a TStringBuf from a null pointer, and the surrounding error
/// message still says something useful.
TStringBuf GetOpensslError() {
    const unsigned long code = ERR_get_error();
    const char* reason = ERR_reason_error_string(code);
    return reason != nullptr ? TStringBuf(reason) : TStringBuf("unknown openssl error");
}

/// Abstract signing key exercised by the benchmark.
class IPrivateKey {
public:
    virtual ~IPrivateKey() = default;

    /// Short label used in the benchmark report lines.
    virtual TString Name() const = 0;

    /// Produces a signature over `data`.
    virtual TString Sign(TStringBuf data) const = 0;
};

/// Abstract verification key exercised by the benchmark.
class IPublicKey {
public:
    virtual ~IPublicKey() = default;

    /// Short label used in the benchmark report lines.
    virtual TString Name() const = 0;

    /// Returns true when `signature` is a valid signature of `data`.
    virtual bool Verify(TStringBuf data, TStringBuf signature) const = 0;
};

class TRsaPrivateKey: public IPrivateKey {
public:
    TRsaPrivateKey(TStringBuf pem)
        : Key_(NPassport::NUtils::TRsaPrivateEvp::FromPem(pem))
    {
    }

    TString Name() const override {
        return "rsa_priv";
    }

    TString Sign(TStringBuf data) const override {
        return Key_.SignWithSha256(data);
    }

private:
    NPassport::NUtils::TRsaPrivateEvp Key_;
};

class TRsaPublicKey: public IPublicKey {
public:
    TRsaPublicKey(TStringBuf pem)
        : Key_(NPassport::NUtils::TRsaPublicEvp::FromPem(pem))
    {
    }

    TString Name() const override {
        return "rsa_pub";
    }

    bool Verify(TStringBuf data, TStringBuf signature) const override {
        return Key_.VerifyWithSha256(data, signature).IsSuccess;
    }

private:
    NPassport::NUtils::TRsaPublicEvp Key_;
};

/// Parses a PEM-encoded key and checks its algorithm.
///
/// @param pem       PEM text of the key.
/// @param expected  Required EVP_PKEY type, e.g. EVP_PKEY_ED25519 or EVP_PKEY_EC.
/// @param func      OpenSSL PEM reader such as PEM_read_bio_PrivateKey or
///                  PEM_read_bio_PUBKEY. It is called as
///                  func(bio, nullptr, nullptr, nullptr).
/// @return          The parsed key.
/// @throws std::bad_alloc if the memory BIO cannot be created.
/// @throws yexception if the PEM is too large, cannot be parsed, or holds a
///         key of a type other than `expected`.
template <typename Func>
static NPassport::NUtils::TEvpKey ParseKey(TStringBuf pem, int expected, Func func) {
    TErrCleaner cleaner;

    // BIO_new_mem_buf() takes an `int` length. Refuse input that would be
    // silently truncated or turned negative by the narrowing conversion.
    Y_ENSURE(pem.size() <= static_cast<size_t>(std::numeric_limits<int>::max()),
             "PEM is too large: " << pem.size() << " bytes");

    std::unique_ptr<BIO, decltype(&BIO_free)> bio(
        BIO_new_mem_buf(pem.data(), static_cast<int>(pem.size())),
        BIO_free);
    if (!bio) {
        throw std::bad_alloc();
    }

    NPassport::NUtils::TEvpKey res;
    // The passphrase callback and userdata are null. With a null callback,
    // OpenSSL uses its default passphrase prompt for encrypted keys.
    res.Pkey.reset(func(bio.get(), nullptr, nullptr, nullptr));
    if (!res.Pkey) {
        ythrow yexception() << "Failed : " << GetOpensslError();
    }

    // OBJ_nid2sn() returns nullptr for NIDs it does not know. Never stream a
    // null `const char*` into the exception message.
    const auto nidName = [](int nid) -> TStringBuf {
        const char* sn = OBJ_nid2sn(nid);
        return sn != nullptr ? TStringBuf(sn) : TStringBuf("unknown");
    };

    const int actual = EVP_PKEY_id(res.Pkey.get());
    if (expected != actual) {
        ythrow yexception()
            << "Wrong type of key:"
            << " expected==" << expected << "(" << nidName(expected) << "),"
            << " actual==" << actual << "(" << nidName(actual) << ")";
    }

    return res;
}

/// Ed25519 private key loaded from PEM. Signing is done through OpenSSL's
/// one-shot EVP_DigestSign API.
class TEddsaPrivateKey: public IPrivateKey {
public:
    /// @throws yexception if `pem` is not a parsable Ed25519 private key.
    TEddsaPrivateKey(TStringBuf pem)
        : Key_(ParseKey(pem, EVP_PKEY_ED25519, PEM_read_bio_PrivateKey))
    {
    }

    TString Name() const override {
        return "eddsa_priv";
    }

    /// Signs `data` and returns a signature of exactly the length OpenSSL wrote.
    /// @throws yexception on any OpenSSL failure.
    TString Sign(TStringBuf data) const override {
        NOpenSSL::TEvpMdCtx md_ctx;
        const auto* tbs = reinterpret_cast<const unsigned char*>(data.data());

        // No digest type is passed: Ed25519 is used through the one-shot API.
        Y_ENSURE(1 == EVP_DigestSignInit(md_ctx, nullptr, nullptr, nullptr, Key_.Pkey.get()),
                 "failed to sign: " << GetOpensslError());

        // Calculate the required size for the signature by passing a nullptr buffer.
        size_t sigLen = 0;
        Y_ENSURE(1 == EVP_DigestSign(md_ctx, nullptr, &sigLen, tbs, data.size()),
                 "failed to sign: " << GetOpensslError());

        TString res;
        res.resize(sigLen);

        Y_ENSURE(1 == EVP_DigestSign(md_ctx, reinterpret_cast<unsigned char*>(res.begin()), &sigLen, tbs, data.size()),
                 "failed to sign: " << GetOpensslError());
        // On success, sigLen holds the number of bytes actually written, and
        // the first call only reported an upper bound. Trim to the real size.
        res.resize(sigLen);

        return res;
    }

private:
    NPassport::NUtils::TEvpKey Key_;
};

class TEddsaPublicKey: public IPublicKey {
public:
    TEddsaPublicKey(TStringBuf pem)
        : Key_(ParseKey(pem, EVP_PKEY_ED25519, PEM_read_bio_PUBKEY))
    {
    }

    TString Name() const override {
        return "eddsa_pub";
    }

    bool Verify(TStringBuf data, TStringBuf signature) const override {
        NOpenSSL::TEvpMdCtx md_ctx;

        Y_ENSURE(1 == EVP_DigestVerifyInit(md_ctx, nullptr, nullptr, nullptr, Key_.Pkey.get()),
                 "failed to verify: " << GetOpensslError());

        const int res = EVP_DigestVerify(md_ctx, (unsigned char*)signature.data(), signature.size(), (unsigned char*)data.data(), data.size());
        Y_ENSURE(1 == res || 0 == res,
                 "failed to verify: " << GetOpensslError());
        return 1 == res;
    }

private:
    NPassport::NUtils::TEvpKey Key_;
};

/// EC private key loaded from PEM. Signing uses the low-level ECDSA_sign().
class TEcdsaPrivateKey: public IPrivateKey {
public:
    TEcdsaPrivateKey(TStringBuf pem)
        : Key_(ParseKey(pem, EVP_PKEY_EC, PEM_read_bio_PrivateKey))
    {
    }

    TString Name() const override {
        return "ecdsa_priv";
    }

    /// Returns the signature produced by ECDSA_sign().
    /// @throws yexception on OpenSSL failure.
    TString Sign(TStringBuf data) const override {
        // NOTE(review): ECDSA_sign() treats `data` as an already-computed digest
        // and nothing is hashed here, unlike TRsaPrivateKey, which calls
        // SignWithSha256. OpenSSL truncates a "digest" longer than the group
        // order, so the larger DATA_SIZES are not directly comparable with the
        // other algorithms.
        auto* ecKey = EVP_PKEY_get0_EC_KEY(Key_.Pkey.get());

        TString signature;
        signature.resize(ECDSA_size(ecKey));  // upper bound on the signature length

        unsigned int written = 0;
        const int rc = ECDSA_sign(
            0,
            reinterpret_cast<const unsigned char*>(data.data()),
            data.size(),
            reinterpret_cast<unsigned char*>(signature.begin()),
            &written,
            ecKey);
        Y_ENSURE(rc == 1, "Failed to sign data: " << GetOpensslError());

        // Keep only the bytes ECDSA_sign() actually produced.
        signature.resize(written);
        return signature;
    }

private:
    NPassport::NUtils::TEvpKey Key_;
};

/// EC public key loaded from PEM. Verification uses the low-level ECDSA_verify().
class TEcdsaPublicKey: public IPublicKey {
public:
    TEcdsaPublicKey(TStringBuf pem)
        : Key_(ParseKey(pem, EVP_PKEY_EC, PEM_read_bio_PUBKEY))
    {
    }

    TString Name() const override {
        return "ecdsa_pub";
    }

    /// @return true for a valid signature, false for an invalid one.
    /// @throws yexception when ECDSA_verify() reports an error (neither 1 nor 0).
    bool Verify(TStringBuf data, TStringBuf signature) const override {
        // As in TEcdsaPrivateKey::Sign, `data` is passed as the digest unchanged.
        auto* ecKey = EVP_PKEY_get0_EC_KEY(Key_.Pkey.get());
        const auto* dgst = reinterpret_cast<const unsigned char*>(data.data());
        const auto* sig = reinterpret_cast<const unsigned char*>(signature.data());

        const int rc = ECDSA_verify(0, dgst, data.size(), sig, signature.size(), ecKey);

        const bool valid = (rc == 1);
        Y_ENSURE(valid || rc == 0,
                 "failed to verify: " << rc << ": " << GetOpensslError());
        return valid;
    }

private:
    NPassport::NUtils::TEvpKey Key_;
};

/// Ed25519 private key signed with libsodium's detached API.
class TEddsaSodiumPrivateKey: public IPrivateKey {
public:
    /// @param key  Raw libsodium secret key of exactly
    ///             crypto_sign_ed25519_SECRETKEYBYTES bytes.
    /// @throws yexception if `key` has any other length.
    TEddsaSodiumPrivateKey(TStringBuf key)
        : Key_(key)
    {
        // crypto_sign_ed25519_detached() takes no key length and always reads
        // SECRETKEYBYTES from the key pointer, so a shorter buffer would be over-read.
        Y_ENSURE(Key_.size() == crypto_sign_ed25519_SECRETKEYBYTES,
                 "bad ed25519 secret key size: " << Key_.size());
    }

    TString Name() const override {
        return "eddsa_priv_sod";
    }

    /// Returns the crypto_sign_ed25519_BYTES-long detached signature of `data`.
    /// @throws yexception if libsodium reports a failure.
    TString Sign(TStringBuf data) const override {
        TString res;
        res.resize(crypto_sign_ed25519_BYTES);

        Y_ENSURE(0 == crypto_sign_ed25519_detached(
                          reinterpret_cast<unsigned char*>(res.begin()),
                          nullptr,
                          reinterpret_cast<const unsigned char*>(data.data()), data.size(),
                          reinterpret_cast<const unsigned char*>(Key_.data())),
                 "failed to sign with libsodium");

        return res;
    }

private:
    TString Key_;
};

class TEddsaSodiumPublicKey: public IPublicKey {
public:
    TEddsaSodiumPublicKey(TStringBuf key)
        : Key_(key)
    {
    }

    TString Name() const override {
        return "eddsa_pub_sod";
    }

    bool Verify(TStringBuf data, TStringBuf signature) const override {
        return 0 == crypto_sign_ed25519_verify_detached(
                        (unsigned char*)signature.data(),
                        (unsigned char*)data.data(),
                        data.size(),
                        (unsigned char*)Key_.data());
    }

private:
    TString Key_;
};

//////////////////////////////////////////////////

/// Outcome of one benchmark worker, or the sum over several workers.
struct TRes {
    TDuration Us;   // wall-clock time spent; a TDuration despite the "Us" name
    ui32 Count = 0; // number of operations completed in that time

    /// Adds another result into this one by summing durations and counts.
    constexpr TRes& operator+=(const TRes& rhs) noexcept {
        Count += rhs.Count;
        Us += rhs.Us;
        return *this;
    }
};

// Shared stop flag polled by the worker loops in Sign() and Verify().
// Test() lowers it before launching the workers and raises it when the
// measurement window ends.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::atomic<bool> stop{false};

/// Worker loop for the signing benchmark. Signs `data` with `key` repeatedly
/// until the global `stop` flag is raised.
/// @return elapsed wall-clock time and the number of signatures produced.
TRes Sign(const IPrivateKey& key,
          TStringBuf data) {
    const TInstant begin = TInstant::Now();

    ui32 signatures = 0;
    for (; !stop.load(std::memory_order_acquire); ++signatures) {
        key.Sign(data);  // result intentionally discarded: only the cost matters
    }

    return {TInstant::Now() - begin, signatures};
}

/// Worker loop for the verification benchmark. Verifies `signature` over
/// `data` with `key` repeatedly until the global `stop` flag is raised.
/// @return elapsed wall-clock time and the number of verifications performed.
/// @throws yexception if a verification ever returns false.
TRes Verify(const IPublicKey& key,
            TStringBuf data,
            TStringBuf signature) {
    const TInstant begin = TInstant::Now();

    ui32 checks = 0;
    for (; !stop.load(std::memory_order_acquire); ++checks) {
        // The signature was produced by the matching private key, so it must verify.
        Y_ENSURE(key.Verify(data, signature));
    }

    return {TInstant::Now() - begin, checks};
}

namespace {
    // Row number printed at the start of every report line by Test().
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
    int IDX = 0;
} // namespace

/// Runs `fun` concurrently on `threads` workers for about 5 seconds and prints
/// the mean time per operation.
///
/// Each worker gets its own copy of `fun` on a std::launch::async thread. It
/// must loop until the global `stop` flag is raised and return its elapsed time
/// and operation count (see Sign() and Verify()).
///
/// @param threads   number of concurrent workers.
/// @param dataSize  payload size; only used in the report line.
/// @param name      key name; only used in the report line.
/// @param fun       worker callable returning TRes.
/// @throws whatever a worker throws, rethrown by std::future::get(), or the
///         exception from a failed std::async launch.
template <typename Func>
void Test(ui32 threads, ui32 dataSize, TStringBuf name, Func fun) {
    stop.store(false);

    std::vector<std::future<TRes>> vec;
    vec.reserve(threads);
    try {
        for (ui32 i = 0; i < threads; ++i) {
            vec.push_back(std::async(std::launch::async, fun));
        }
    } catch (...) {
        // Workers that already started spin until `stop` is raised, and the
        // destructor of a std::async future waits for its task to finish.
        // Without this, a failed launch would hang forever while `vec` unwinds.
        stop.store(true, std::memory_order_release);
        throw;
    }
    std::this_thread::sleep_for(std::chrono::seconds(5));
    // Release pairs with the acquire loads in the worker loops.
    stop.store(true, std::memory_order_release);

    TRes total;
    for (auto& a : vec) {
        total += a.get();
    }

    Cout
        << "Idx=" << RightPad(++IDX, 5)
        << " " << RightPad(name, 12)
        << " threads=" << RightPad(threads, 3)
        << " data_size=" << RightPad(dataSize, 6)
        << " per_hit=";
    if (total.Count > 0) {
        // Convert to floating point before dividing by the count, so averages
        // below one microsecond per hit are not truncated by integer division.
        Cout << total.Us.MicroSeconds() / 1000.0 / total.Count << " ms";
    } else {
        // No operation completed in the window; avoid dividing by zero.
        Cout << "n/a";
    }
    Cout << Endl;
}

// Payload sizes, in bytes, benchmarked by TestSign() and TestVerify().
static constexpr ui32 DATA_SIZES[] = {8, 16, 32, 64, 128, 512, 1024};

/// Benchmarks `key.Sign()` with `threads` concurrent workers for every payload
/// size in DATA_SIZES. The payload is a run of 'a' characters of that size.
void TestSign(ui32 threads, const IPrivateKey& key) {
    for (const ui32 size : DATA_SIZES) {
        const TString payload(size, 'a');
        // References stay valid: Test() joins all workers before returning.
        const auto worker = [&key, &payload]() -> TRes {
            return Sign(key, payload);
        };
        Test(threads, size, key.Name(), worker);
    }
}

/// Benchmarks `key.Verify()` with `threads` concurrent workers for every
/// payload size in DATA_SIZES. Each payload is signed once with `privKey` up
/// front, and the workers then verify that same signature over and over.
void TestVerify(ui32 threads, const IPrivateKey& privKey, const IPublicKey& key) {
    for (const ui32 size : DATA_SIZES) {
        const TString payload(size, 'a');
        const TString signature = privKey.Sign(payload);
        // References stay valid: Test() joins all workers before returning.
        const auto worker = [&key, &payload, &signature]() -> TRes {
            return Verify(key, payload, signature);
        };
        Test(threads, size, key.Name(), worker);
    }
}

// NOLINTNEXTLINE(bugprone-exception-escape)
int main(int, char**) {
    InitOpenSSL();

    using namespace NPassport::NUtils;
    TRsaPrivateKey rsaPrKey(ReadFile("./rsa.private"));
    TRsaPublicKey rsaPubKey(ReadFile("./rsa.public"));
    TEddsaPrivateKey edPrKey(ReadFile("./eddsa.private"));
    TEddsaPublicKey edPubKey(ReadFile("./eddsa.public"));
    TEcdsaPrivateKey ecPrKey(ReadFile("./ecdsa.private"));
    TEcdsaPublicKey ecPubKey(ReadFile("./ecdsa.public"));

    unsigned char pk[crypto_sign_ed25519_PUBLICKEYBYTES];
    unsigned char sk[crypto_sign_ed25519_SECRETKEYBYTES];
    crypto_sign_ed25519_keypair(pk, sk);
    TEddsaSodiumPrivateKey edSodPrKey(TStringBuf((char*)sk, sizeof(sk)));
    TEddsaSodiumPublicKey edSodPubKey(TStringBuf((char*)pk, sizeof(pk)));

    ui32 thr[] = {1, 2, 4, 8, 16};

    Cout << "Signing" << Endl << Endl;
    for (ui32 threads : thr) {
        TestSign(threads, rsaPrKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestSign(threads, edPrKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestSign(threads, ecPrKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestSign(threads, edSodPrKey);
    }
    Cout << Endl;

    IDX = 0;
    Cout << "Verifying" << Endl << Endl;
    for (ui32 threads : thr) {
        TestVerify(threads, rsaPrKey, rsaPubKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestVerify(threads, edPrKey, edPubKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestVerify(threads, ecPrKey, ecPubKey);
    }
    Cout << Endl;
    for (ui32 threads : thr) {
        TestVerify(threads, edSodPrKey, edSodPubKey);
    }
    Cout << Endl;

    return 0;
}
