#include "cryptography.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#include <openssl/pem.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
#include <openssl/sha.h>

#include <util/generic/scope.h>

#include <cstddef>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>

using namespace quasar;

namespace {

    struct EVP_CIPHER_CTX_RaiiHelper {
        EVP_CIPHER_CTX* ctx;

        EVP_CIPHER_CTX_RaiiHelper() {
            ctx = EVP_CIPHER_CTX_new();
        }

        ~EVP_CIPHER_CTX_RaiiHelper() {
            EVP_CIPHER_CTX_free(ctx);
        }
    };

    struct HMAC_CTX_RaiiHelper {
        HMAC_CTX* ctx;

        HMAC_CTX_RaiiHelper() {
            ctx = HMAC_CTX_new();
        }

        ~HMAC_CTX_RaiiHelper() {
            HMAC_CTX_free(ctx);
        }
    };

    std::string getErrorMessage() {
        ERR_load_crypto_strings();
        char buf[256];
        return ERR_error_string(ERR_get_error(), buf);
    }

    void checkOpensslError(bool rc) {
        if (!rc) {
            throw std::runtime_error("openssl routine failed: " + getErrorMessage());
        }
    }

    std::unique_ptr<RSA, Cryptography::RSADeleter> rsaFromPrivatePEM(std::string_view privateKey) {
        BIO* bufio = BIO_new_mem_buf(static_cast<const void*>(privateKey.data()), privateKey.size());

        RSA* result = PEM_read_bio_RSAPrivateKey(bufio, nullptr, nullptr, nullptr);
        if (!result) {
            BIO_free(bufio);
            throw std::runtime_error("Cannot read encryption key: " + getErrorMessage());
        }

        BIO_free(bufio);

        return std::unique_ptr<RSA, Cryptography::RSADeleter>(result);
    }

    std::unique_ptr<RSA, Cryptography::RSADeleter> rsaFromPublicPEM(std::string_view encryptionKey) {
        BIO* bufio = BIO_new_mem_buf(static_cast<const void*>(encryptionKey.data()), encryptionKey.size());

        RSA* result = PEM_read_bio_RSA_PUBKEY(bufio, nullptr, nullptr, nullptr);
        if (!result) {
            BIO_free(bufio);
            throw std::runtime_error("Cannot read encryption key: " + getErrorMessage());
        }

        BIO_free(bufio);

        return std::unique_ptr<RSA, Cryptography::RSADeleter>(result);
    }

    std::string rsaToPublicPEM(const std::unique_ptr<RSA, Cryptography::RSADeleter>& rsa) {
        BIO* mem = BIO_new(BIO_s_mem());

        Y_DEFER {
            BIO_free(mem);
        };

        int result = PEM_write_bio_RSA_PUBKEY(mem, rsa.get());
        if (!result) {
            throw std::runtime_error("Cannot write public key: " + getErrorMessage());
        }

        char* pem;
        long size = BIO_get_mem_data(mem, &pem);
        if (size < 0) {
            throw std::runtime_error("Cannot write public key: " + getErrorMessage());
        }

        return std::string(pem, size);
    }
} // namespace

/**
 * Maps an OpenSSL RSA padding constant to the corresponding Cryptography::Padding value.
 * Aborts via Y_FAIL for any constant this class does not support.
 */
Cryptography::Padding Cryptography::parsePadding(int value) {
    if (value == RSA_PKCS1_PADDING) {
        return Padding::RSA_PKCS1;
    }
    if (value == RSA_PKCS1_OAEP_PADDING) {
        return Padding::RSA_PKCS1_OAEP;
    }
    Y_FAIL("Undefined padding");
}

/**
 * Maps a Cryptography::Padding value to the OpenSSL RSA padding constant.
 * Aborts via Y_FAIL for a value outside the enum, which can only come from a bad cast.
 * Without this, control would fall off the end of a non-void function, which is undefined behaviour.
 */
int Cryptography::paddingValue(Padding padding) {
    switch (padding) {
        case Padding::RSA_PKCS1:
            return RSA_PKCS1_PADDING;
        case Padding::RSA_PKCS1_OAEP:
            return RSA_PKCS1_OAEP_PADDING;
    }
    Y_FAIL("Undefined padding");
}

/**
 * Builds a KeyPair from two PEM files.
 * The public key file is read first, then the private key file. Contents are stored verbatim.
 */
Cryptography::KeyPair Cryptography::KeyPair::fromFiles(const std::string& publicKeyPath, const std::string& privateKeyPath) {
    KeyPair keys{getFileContent(publicKeyPath), getFileContent(privateKeyPath)};
    return keys;
}

/**
 * Loads a PEM private key from `privateKeyPath` and derives the matching PEM public key from it.
 * @throws std::runtime_error if the private key cannot be parsed or the public key cannot be serialized.
 */
Cryptography::KeyPair Cryptography::KeyPair::fromPrivateKeyFile(const std::string& privateKeyPath) {
    std::string privateKey = getFileContent(privateKeyPath);
    std::string publicKey = rsaToPublicPEM(rsaFromPrivatePEM(privateKey));

    // Both strings are locals about to go out of scope, so move the PEM text instead of copying it.
    return {
        std::move(publicKey),
        std::move(privateKey),
    };
}

/**
 * Releases an RSA object owned by a std::unique_ptr.
 * RSA_free() already ignores nullptr; the explicit guard only makes that contract visible.
 */
void Cryptography::RSADeleter::operator()(RSA* rsa) {
    if (rsa != nullptr) {
        RSA_free(rsa);
    }
}

/** Reads a PEM public key from `fileName` and installs it via setPublicKey(). */
void Cryptography::loadPublicKey(const std::string& fileName) {
    const auto pem = getFileContent(fileName);
    setPublicKey(pem);
}

/** Reads a PEM private key from `fileName` and installs it via setPrivateKey(). */
void Cryptography::loadPrivateKey(const std::string& fileName) {
    const auto pem = getFileContent(fileName);
    setPrivateKey(pem);
}

void Cryptography::setKeyPair(const KeyPair& keys) {
    setPublicKey(keys.publicKey);
    setPrivateKey(keys.privateKey);
}

/**
 * Parses `encryptionKey` as a PEM public key and makes it the encryption/verification key.
 * Parsing happens first, so a malformed PEM throws and leaves the current key in place.
 */
void Cryptography::setPublicKey(std::string_view encryptionKey) {
    auto parsed = rsaFromPublicPEM(encryptionKey);
    publicRSA_.swap(parsed);
    // `parsed` now holds the previous key and releases it on scope exit.
}

/**
 * Parses `privateKey` as a PEM private key and makes it the decryption/signing key.
 * Parsing happens first, so a malformed PEM throws and leaves the current key in place.
 */
void Cryptography::setPrivateKey(std::string_view privateKey) {
    auto parsed = rsaFromPrivatePEM(privateKey);
    privateRSA_.swap(parsed);
    // `parsed` now holds the previous key and releases it on scope exit.
}

/**
 * RSA-encrypts `data` with the public key.
 * @param data    plaintext; must fit the padding's size limit for the key (OpenSSL enforces it)
 * @param padding RSA padding scheme
 * @return ciphertext of exactly the number of bytes OpenSSL produced
 * @throws std::runtime_error if no public key is set, the input is too large, or OpenSSL fails.
 */
std::string Cryptography::encrypt(std::string_view data, Padding padding) const {
    if (!publicRSA_) {
        throw std::runtime_error("Cannot encrypt data. Encryption key is not set");
    }
    // RSA_public_encrypt() takes the input length as int; reject lengths the cast would truncate.
    if (data.size() > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("Encryption error: input is too large");
    }

    // RSA output is at most RSA_size() bytes.
    std::string result(static_cast<std::size_t>(RSA_size(publicRSA_.get())), '\0');
    const int resultLength = RSA_public_encrypt(static_cast<int>(data.size()),
                                                reinterpret_cast<const unsigned char*>(data.data()),
                                                reinterpret_cast<unsigned char*>(result.data()),
                                                publicRSA_.get(),
                                                Cryptography::paddingValue(padding));

    if (resultLength < 0) {
        throw std::runtime_error("Encryption error: " + getErrorMessage());
    }

    // Trim to the number of bytes actually written.
    result.resize(static_cast<std::size_t>(resultLength));

    return result;
}

std::string Cryptography::decrypt(std::string_view data, Padding padding) const {
    if (!privateRSA_) {
        throw std::runtime_error("Cannot decrypt data. Decryption key is not set");
    }

    std::string result;
    result.resize(data.length());
    int resultLength = RSA_private_decrypt(data.size(),
                                           reinterpret_cast<const unsigned char*>(data.data()),
                                           reinterpret_cast<unsigned char*>(result.data()),
                                           privateRSA_.get(),
                                           Cryptography::paddingValue(padding));

    if (resultLength < 0) {
        throw std::runtime_error("Decryption error: " + getErrorMessage());
    }

    result.resize(resultLength);

    return result;
}

std::string Cryptography::decryptAES(std::string_view iv, std::string_view key, std::string_view data) {
    EVP_CIPHER_CTX_RaiiHelper ctx;
    checkOpensslError(EVP_DecryptInit_ex(ctx.ctx, EVP_aes_128_cbc(), nullptr,
                                         reinterpret_cast<const unsigned char*>(key.data()),
                                         reinterpret_cast<const unsigned char*>(iv.data())));
    char plaintext[128];
    int plaintextLength = 0;
    int len;
    checkOpensslError(EVP_DecryptUpdate(ctx.ctx,
                                        reinterpret_cast<unsigned char*>(plaintext),
                                        &len,
                                        reinterpret_cast<const unsigned char*>(data.data()),
                                        data.length()));

    plaintextLength += len;

    checkOpensslError(EVP_DecryptFinal_ex(ctx.ctx, reinterpret_cast<unsigned char*>(plaintext + len), &len));
    plaintextLength += len;
    std::string result(plaintext, plaintextLength);
    return result;
}

std::string Cryptography::sign(std::string_view data) const {
    unsigned char sig_buf[4096];

    unsigned int sig_len = sizeof(sig_buf);

    unsigned char dataHash[SHA256_DIGEST_LENGTH];
    {
        SHA256_CTX ctx;
        SHA256_Init(&ctx);
        SHA256_Update(&ctx, data.data(), data.length());
        SHA256_Final(dataHash, &ctx);
    }

    int ret = RSA_sign(NID_sha256, dataHash, SHA256_DIGEST_LENGTH, sig_buf, &sig_len, privateRSA_.get());

    if (ret != 1) {
        throw std::runtime_error("Signature error: " + getErrorMessage());
    }

    return std::string(reinterpret_cast<const char*>(sig_buf), sig_len);
}

bool Cryptography::checkSignature(std::string_view data, std::string_view sign) const {
    unsigned char dataHash[SHA256_DIGEST_LENGTH];
    {
        SHA256_CTX ctx;
        SHA256_Init(&ctx);
        SHA256_Update(&ctx, data.data(), data.length());
        SHA256_Final(dataHash, &ctx);
    }

    int ret = RSA_verify(NID_sha256,
                         dataHash, SHA256_DIGEST_LENGTH,
                         reinterpret_cast<const unsigned char*>(sign.data()), sign.size(),
                         publicRSA_.get());
    return (ret == 1);
}

/**
 * Computes HMAC-SHA256 of `data` keyed with `key`.
 * @return the raw 32-byte MAC
 * @throws std::runtime_error if the key is too large for OpenSSL's int length or any OpenSSL call fails.
 */
std::string Cryptography::hashWithHMAC_SHA256(std::string_view data, std::string_view key) {
    if (key.size() > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("openssl routine failed: HMAC key is too large");
    }

    HMAC_CTX_RaiiHelper ctx;
    checkOpensslError(ctx.ctx != nullptr);

    // On a fresh context HMAC_Init_ex() fails when key is NULL, because it treats NULL as "reuse the
    // previous key". A default-constructed string_view has a NULL data(), so pass a valid pointer for
    // an empty key.
    static const unsigned char emptyKey = 0;
    const void* keyBytes = key.empty() ? static_cast<const void*>(&emptyKey)
                                       : static_cast<const void*>(key.data());

    unsigned int resultLength = SHA256_DIGEST_LENGTH;
    std::string result(resultLength, '\0');
    checkOpensslError(HMAC_Init_ex(ctx.ctx, keyBytes, static_cast<int>(key.size()), EVP_sha256(), nullptr));
    checkOpensslError(HMAC_Update(ctx.ctx, reinterpret_cast<const unsigned char*>(data.data()), data.size()));
    checkOpensslError(HMAC_Final(ctx.ctx, reinterpret_cast<unsigned char*>(result.data()), &resultLength));

    result.resize(resultLength);
    return result;
}

/**
 * Returns 32 cryptographically random bytes for use as AES key material.
 * @throws std::runtime_error if the OpenSSL RNG cannot produce the bytes.
 */
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
std::string Cryptography::generateAESKeyString() const {
    constexpr int AES_FULL_RANDOM_KEY_LENGTH = 32;
    std::string key(AES_FULL_RANDOM_KEY_LENGTH, '\0');
    // RAND_bytes() returns 1 on success, 0 on failure and -1 when unsupported by the current RAND
    // method. Only 1 means the buffer was filled; with a plain `!` test, a -1 would return an all-zero key.
    if (RAND_bytes(reinterpret_cast<unsigned char*>(key.data()), AES_FULL_RANDOM_KEY_LENGTH) != 1) {
        throw std::runtime_error("could not generate random string: " + getErrorMessage());
    }
    return key;
}
