#include "rsa.h"

#include "hash.h"

#include <passport/infra/libs/cpp/utils/string/coder.h>

#include <contrib/libs/openssl/include/openssl/bio.h>
#include <contrib/libs/openssl/include/openssl/err.h>
#include <contrib/libs/openssl/include/openssl/pem.h>
#include <contrib/libs/openssl/include/openssl/x509.h>

#include <library/cpp/openssl/holders/holder.h>

#include <util/generic/yexception.h>

namespace NPassport::NUtils {
    namespace {
        // Scope guard: empties OpenSSL's error queue when constructed and again
        // when destroyed, so GetOpensslError() only observes errors raised by
        // the operation running inside the guarded scope.
        struct TErrCleaner {
            TErrCleaner() {
                Flush();
            }

            ~TErrCleaner() {
                Flush();
            }

        private:
            static void Flush() {
                ERR_clear_error();
            }
        };

        // Takes the next error code from OpenSSL's error queue (ERR_get_error)
        // and returns its human-readable reason string.
        // ERR_reason_error_string() may return nullptr (e.g. no queued error);
        // that pointer is handed to TStringBuf unchanged.
        TStringBuf GetOpensslError() {
            const unsigned long errorCode = ERR_get_error();
            const char* const reason = ERR_reason_error_string(errorCode);
            return reason;
        }
    }

    // Parses a PEM-encoded public key and checks that it is an RSA key.
    //
    // base64EncodedKey: the PEM text of the key (read via PEM_read_bio_PUBKEY).
    // Returns a TRsaPublicEvp owning the parsed EVP_PKEY.
    // Throws yexception if the input is empty or too large for OpenSSL's
    // `int` length, if parsing fails, or if the key is not RSA.
    // Throws std::bad_alloc if the memory BIO cannot be created.
    TRsaPublicEvp TRsaPublicEvp::FromPem(TStringBuf base64EncodedKey) {
        TErrCleaner cleaner;

        // BIO_new_mem_buf() rejects a null buffer (an empty TStringBuf may have
        // data() == nullptr), which would otherwise be misreported below as
        // std::bad_alloc. It also takes an `int` length, so larger inputs must
        // not be silently truncated.
        Y_ENSURE(!base64EncodedKey.empty(), "PEM key is empty");
        const int keyLen = static_cast<int>(base64EncodedKey.size());
        Y_ENSURE(keyLen > 0 && static_cast<size_t>(keyLen) == base64EncodedKey.size(),
                 "PEM key is too large: " << base64EncodedKey.size() << " bytes");

        std::unique_ptr<BIO, decltype(&BIO_free)> bio(
            BIO_new_mem_buf(base64EncodedKey.data(), keyLen),
            BIO_free);
        if (!bio) {
            throw std::bad_alloc();
        }

        TRsaPublicEvp res;

        res.Pkey.reset(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
        if (!res.Pkey) {
            throw yexception() << "Failed PEM_read_bio_PUBKEY: " << GetOpensslError();
        }

        // Only RSA keys are accepted: the verification code relies on EVP_PKEY_get0_RSA().
        const int type = EVP_PKEY_id(res.Pkey.get());
        if (EVP_PKEY_RSA != type) {
            throw yexception()
                << "Wrong type of key:"
                << " expected==" << EVP_PKEY_RSA << "(" << OBJ_nid2sn(EVP_PKEY_RSA) << "),"
                << " actual==" << type << "(" << OBJ_nid2sn(type) << ")";
        }

        return res;
    }

    // Verifies an RSA signature (RSA_verify with NID_sha256) over the SHA-256
    // digest of `data`.
    //
    // Returns IsSuccess == true only if RSA_verify() returned 1.
    // Details holds the first OpenSSL error reason raised during verification,
    // or an explanatory message if the signature could not be passed to OpenSSL.
    TRsaPublicEvp::TResult TRsaPublicEvp::VerifyWithSha256(const TStringBuf data, const TStringBuf sign) const {
        TErrCleaner cleaner;

        TResult res;

        // RSA_verify() takes the signature length as `unsigned int`. A longer
        // buffer would be silently truncated, and only its prefix would be
        // checked, so reject it instead.
        const unsigned int signLen = static_cast<unsigned int>(sign.size());
        if (static_cast<size_t>(signLen) != sign.size()) {
            res.IsSuccess = false;
            res.Details = "signature is too long";
            return res;
        }

        const TString sha256 = TCrypto::Sha256(data);

        res.IsSuccess = 1 == RSA_verify(
                                 NID_sha256,
                                 reinterpret_cast<const unsigned char*>(sha256.data()),
                                 sha256.size(),
                                 reinterpret_cast<const unsigned char*>(sign.data()),
                                 signLen,
                                 EVP_PKEY_get0_RSA(Pkey.get()));
        res.Details = GetOpensslError();

        return res;
    }

    // Parses a PEM-encoded private key and checks that it is an RSA key.
    //
    // base64EncodedKey: the PEM text of the key (read via PEM_read_bio_PrivateKey,
    // with no passphrase callback).
    // Returns a TRsaPrivateEvp owning the parsed EVP_PKEY.
    // Throws yexception if the input is empty or too large for OpenSSL's
    // `int` length, if parsing fails, or if the key is not RSA.
    // Throws std::bad_alloc if the memory BIO cannot be created.
    TRsaPrivateEvp TRsaPrivateEvp::FromPem(TStringBuf base64EncodedKey) {
        TErrCleaner cleaner;

        // BIO_new_mem_buf() rejects a null buffer (an empty TStringBuf may have
        // data() == nullptr), which would otherwise be misreported below as
        // std::bad_alloc. It also takes an `int` length, so larger inputs must
        // not be silently truncated.
        Y_ENSURE(!base64EncodedKey.empty(), "PEM key is empty");
        const int keyLen = static_cast<int>(base64EncodedKey.size());
        Y_ENSURE(keyLen > 0 && static_cast<size_t>(keyLen) == base64EncodedKey.size(),
                 "PEM key is too large: " << base64EncodedKey.size() << " bytes");

        std::unique_ptr<BIO, decltype(&BIO_free)> bio(
            BIO_new_mem_buf(base64EncodedKey.data(), keyLen),
            BIO_free);
        if (!bio) {
            throw std::bad_alloc();
        }

        TRsaPrivateEvp res;

        res.Pkey.reset(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
        if (!res.Pkey) {
            throw yexception() << "Failed PEM_read_bio_PrivateKey: " << GetOpensslError();
        }

        // Only RSA keys are accepted: the signing code relies on EVP_PKEY_get0_RSA().
        const int type = EVP_PKEY_id(res.Pkey.get());
        if (EVP_PKEY_RSA != type) {
            throw yexception()
                << "Wrong type of key:"
                << " expected==" << EVP_PKEY_RSA << "(" << OBJ_nid2sn(EVP_PKEY_RSA) << "),"
                << " actual==" << type << "(" << OBJ_nid2sn(type) << ")";
        }

        return res;
    }

    // Signs the SHA-256 digest of `data` with the RSA private key
    // (RSA_sign with NID_sha256).
    //
    // Returns the raw signature bytes, trimmed to the length RSA_sign() reports.
    // Throws yexception if the RSA key is unavailable or signing fails.
    TString TRsaPrivateEvp::SignWithSha256(const TStringBuf data) const {
        TErrCleaner cleaner;
        const TString sha256 = TCrypto::Sha256(data);

        RSA* rsa = EVP_PKEY_get0_RSA(Pkey.get());
        // RSA_size() must not be called on a null handle.
        Y_ENSURE(rsa, "Failed to get RSA key: " << GetOpensslError());

        // RSA_size() is the maximum signature length, used as the output buffer size.
        TString res;
        res.resize(RSA_size(rsa));

        unsigned int actualSize = 0;
        Y_ENSURE(1 == RSA_sign(
                          NID_sha256,
                          reinterpret_cast<const unsigned char*>(sha256.data()),
                          sha256.size(),
                          reinterpret_cast<unsigned char*>(const_cast<char*>(res.data())),
                          &actualSize,
                          rsa),
                 "Failed to sign data: " << GetOpensslError());

        // RSA_sign() reports how many bytes it actually wrote. Honour that
        // count instead of assuming the whole buffer was filled.
        res.resize(actualSize);

        return res;
    }

    // Decrypts `encrypted` with the private key through the EVP_PKEY_decrypt API.
    //
    // controls: key/value pairs applied via EVP_PKEY_CTX_ctrl_str() after
    //           EVP_PKEY_decrypt_init(). Presumably these select padding, OAEP
    //           digest, etc. (e.g. "rsa_padding_mode") — TODO confirm with callers.
    // Returns the plaintext, trimmed to the length reported by OpenSSL.
    // Throws yexception on any OpenSSL failure.
    TString TRsaPrivateEvp::DecryptWithModes(TStringBuf encrypted,
                                             const TControls& controls) const {
        TErrCleaner cleaner;

        using TEvpPkeyCtx = NOpenSSL::THolder<EVP_PKEY_CTX,
                                              EVP_PKEY_CTX_new,
                                              EVP_PKEY_CTX_free,
                                              EVP_PKEY*,
                                              ENGINE*>;

        TEvpPkeyCtx ctx(Pkey.get(), nullptr);

        Y_ENSURE(1 == EVP_PKEY_decrypt_init(ctx),
                 "Failed to init decrypt: " << GetOpensslError());

        for (const auto& [key, value] : controls) {
            // EVP_PKEY_CTX_ctrl_str() signals success with any positive value,
            // not necessarily 1; 0 or negative means failure (-2: unsupported).
            Y_ENSURE(EVP_PKEY_CTX_ctrl_str(ctx, key.c_str(), value.c_str()) > 0,
                     "Failed to init set control '" << key << "': " << GetOpensslError());
        }

        const auto* const in = reinterpret_cast<const unsigned char*>(encrypted.data());

        // First call with a null output buffer only reports the required size.
        size_t plaintextSize = 0;
        Y_ENSURE(1 == EVP_PKEY_decrypt(ctx,
                                       nullptr,
                                       &plaintextSize,
                                       in,
                                       encrypted.size()),
                 "can't determine plaintext size: " << GetOpensslError());

        TString res(plaintextSize, 0);
        Y_ENSURE(1 == EVP_PKEY_decrypt(ctx,
                                       reinterpret_cast<unsigned char*>(const_cast<char*>(res.data())),
                                       &plaintextSize,
                                       in,
                                       encrypted.size()),
                 "Failed to decrypt: " << GetOpensslError());

        // The size from the first call is an upper bound; the second call
        // stores the actual plaintext length.
        res.resize(plaintextSize);

        return res;
    }
}

// for tests
template <>
void Out<NPassport::NUtils::TRsaPublicEvp::TResult>(IOutputStream& out,
                                                    const NPassport::NUtils::TRsaPublicEvp::TResult& result) {
    // Prints the verification result as "isSuccess==<flag>;details==<text>",
    // used for readable assertion messages in tests.
    out << "isSuccess==" << result.IsSuccess;
    out << ";details==" << result.Details;
}
