#include <yxiva/core/ec_crypto.h>

#include <cstdio>
#include <functional>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <openssl/bn.h>
#include <openssl/err.h>
#include <openssl/hmac.h>
#include <openssl/obj_mac.h>
#include <openssl/pem.h>

namespace yxiva { namespace ec_crypto {

auto get_order(const EC_GROUP* group)
{
    return std::unique_ptr<const BIGNUM, std::function<void(const BIGNUM*)>>(
        EC_GROUP_get0_order(group), [](auto) {});
}

// Owning handle for a C stdio stream; the deleter closes the stream.
using file_ptr = std::unique_ptr<FILE, void (*)(FILE*)>;

/// Opens `filename` with fopen(3) using `mode` and returns an owning handle.
/// The handle is empty (null) when fopen fails; callers must check it.
file_ptr safe_fopen(const char* filename, const char* mode)
{
    FILE* const stream = fopen(filename, mode);
    void (*const closer)(FILE*) = [](FILE* f) {
        if (f != nullptr) fclose(f);
    };
    return file_ptr(stream, closer);
}

/// Loads a private key from the PEM file at `filename`.
///
/// No password callback or user data is passed to PEM_read_PrivateKey.
///
/// @throws std::runtime_error if the file cannot be opened or no private key
///         can be parsed from it.
evp_pkey_ptr read_pem(const std::string& filename)
{
    const char* const path = filename.c_str();
    file_ptr pem_file = safe_fopen(path, "r");
    if (!pem_file)
    {
        throw std::runtime_error("failed to open pem file");
    }

    // The stream stays open only for the duration of parsing.
    auto private_key = make_ptr(PEM_read_PrivateKey(pem_file.get(), nullptr, nullptr, nullptr));
    if (!private_key)
    {
        throw std::runtime_error("failed to read key from pem file");
    }
    return private_key;
}

/// Parses a PEM-encoded private key held in memory.
///
/// @param buf  PEM text. Its size must fit in an `int`, because that is the
///             length type of BIO_write.
/// @return The parsed key.
/// @throws std::runtime_error if the buffer is too large, the memory BIO cannot
///         be created or filled, or no private key can be parsed.
evp_pkey_ptr read_pem_buf(const std::string& buf)
{
    // BIO_write takes an int length: refuse buffers it cannot represent instead
    // of silently truncating them.
    if (buf.size() > static_cast<size_t>(std::numeric_limits<int>::max()))
        throw std::runtime_error("pem buffer is too large");

    auto bio = make_ptr(BIO_new(BIO_s_mem()));
    if (!bio) throw std::runtime_error("failed to create openssl BIO from memory");

    const int buf_len = static_cast<int>(buf.size());
    // An empty buffer writes 0 bytes, which matches buf_len. The parse below
    // then reports the failure.
    if (BIO_write(bio.get(), buf.data(), buf_len) != buf_len)
        throw std::runtime_error("failed to write pem data to openssl BIO");

    auto key = make_ptr(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
    if (!key) throw std::runtime_error("failed to read key from buffer");

    return key;
}

/// Generates a fresh EC key pair on the "prime256v1" curve and wraps it in an
/// EVP_PKEY.
///
/// @throws std::runtime_error if any OpenSSL step fails.
evp_pkey_ptr generate_crypto_keys()
{
    const int curve_nid = OBJ_txt2nid("prime256v1");
    if (curve_nid == 0)
    {
        throw std::runtime_error("error getting nid");
    }

    ec_key_ptr ec_key = make_ptr(EC_KEY_new_by_curve_name(curve_nid));
    if (!ec_key)
    {
        throw std::runtime_error("error creating ECC key");
    }

    const int generated = EC_KEY_generate_key(ec_key.get());
    if (generated == 0)
    {
        throw std::runtime_error("error generating the ECC key");
    }

    evp_pkey_ptr evp_key = make_ptr(EVP_PKEY_new());
    if (!evp_key)
    {
        throw std::runtime_error("error creating EVP key ptr");
    }

    const int assigned = EVP_PKEY_assign_EC_KEY(evp_key.get(), ec_key.get());
    if (assigned == 0)
    {
        throw std::runtime_error("error writing EVP");
    }
    // A successful assign transfers ownership of the EC_KEY to evp_key, so
    // the local owner must let go without freeing it.
    ec_key.release();
    return evp_key;
}

/// Serializes the public point of `key` in uncompressed octet form.
///
/// @param key  EC key that must carry both a group and a public point.
/// @return The encoded point.
/// @throws std::runtime_error if the key is null, has no group or public
///         point, or the encoding fails.
std::vector<unsigned char> get_public_key(const ec_key_ptr& key)
{
    if (!key) throw std::runtime_error("ec key is null");

    const EC_POINT* ecp = EC_KEY_get0_public_key(key.get());
    const EC_GROUP* ecgrp = EC_KEY_get0_group(key.get());
    // A key holding only a private scalar has no public point. Passing NULL on
    // to EC_POINT_point2oct would dereference it.
    if (!ecp || !ecgrp) throw std::runtime_error("ec key has no public key or group");

    // First call with a null buffer only queries the required size.
    const size_t pub_key_sz =
        EC_POINT_point2oct(ecgrp, ecp, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, nullptr);
    if (pub_key_sz == 0) throw std::runtime_error("invalid public key size returned");

    std::vector<unsigned char> pub_key_bin(pub_key_sz);
    if (EC_POINT_point2oct(
            ecgrp,
            ecp,
            POINT_CONVERSION_UNCOMPRESSED,
            pub_key_bin.data(),
            pub_key_bin.size(),
            nullptr) != pub_key_sz)
        throw std::runtime_error("invalid public key data returned");
    return pub_key_bin;
}

/// Builds an EVP_PKEY holding only the public key encoded in `pk`.
///
/// The point is decoded relative to `ecgrp`, and the key is created on that
/// same group. Previously the key was always created on prime256v1, whatever
/// `ecgrp` was. Behaviour for prime256v1 groups is unchanged.
///
/// @param pk     Octet-encoded EC point.
/// @param ecgrp  Group the point belongs to.
/// @throws std::runtime_error on decoding or any OpenSSL failure.
evp_pkey_ptr evp_from_public_key(const std::vector<unsigned char>& pk, const EC_GROUP* ecgrp)
{
    ec_point_ptr ecpoint = make_ptr(EC_POINT_new(ecgrp));
    if (!ecpoint) throw std::runtime_error("failed to create point");
    if (!EC_POINT_oct2point(ecgrp, ecpoint.get(), pk.data(), pk.size(), nullptr))
        throw std::runtime_error("failed conversion oct2point");

    // Create the key on the very group the point was decoded against, so the
    // point and key cannot disagree about the curve.
    ec_key_ptr ec_key = make_ptr(EC_KEY_new());
    if (!ec_key) throw std::runtime_error("failed to create key");
    if (!EC_KEY_set_group(ec_key.get(), ecgrp))
        throw std::runtime_error("failed to set key group");
    if (!EC_KEY_set_public_key(ec_key.get(), ecpoint.get()))
        throw std::runtime_error("failed to set public key");

    evp_pkey_ptr pkey = make_ptr(EVP_PKEY_new());
    if (!pkey) throw std::runtime_error("error creating EVP key ptr");
    if (!EVP_PKEY_assign_EC_KEY(pkey.get(), ec_key.get()))
        throw std::runtime_error("error writing EVP");
    // On success pkey owns the EC_KEY; drop the local owner without freeing.
    ec_key.release();
    return pkey;
}

/// Derives a shared secret from `server_key` and the peer `client_key` via
/// EVP_PKEY_derive.
///
/// @return The shared secret, sized to exactly the bytes OpenSSL wrote.
/// @throws std::runtime_error on any OpenSSL failure.
std::vector<unsigned char> derive_secret(
    const evp_pkey_ptr& server_key,
    const evp_pkey_ptr& client_key)
{
    evp_pkey_ctx_ptr ctx = make_ptr(EVP_PKEY_CTX_new(server_key.get(), nullptr));
    if (!ctx) throw std::runtime_error("error creating EVP context");
    if (EVP_PKEY_derive_init(ctx.get()) <= 0)
        throw std::runtime_error("error initiating derivation");
    if (EVP_PKEY_derive_set_peer(ctx.get(), client_key.get()) <= 0)
        throw std::runtime_error("error setting derivation peer");

    // First call with a null buffer only reports the maximum secret size.
    size_t len = 0;
    if (EVP_PKEY_derive(ctx.get(), nullptr, &len) <= 0)
        throw std::runtime_error("error getting secret size");
    std::vector<unsigned char> secret(len);
    if (EVP_PKEY_derive(ctx.get(), secret.data(), &len) <= 0)
        throw std::runtime_error("error deriving key");
    // The second call overwrites len with the number of bytes actually
    // written, which may be smaller than the size query returned.
    secret.resize(len);
    return secret;
}

std::vector<unsigned char> hkdf(
    const std::vector<unsigned char>& salt,
    const std::vector<unsigned char>& key,
    const std::vector<unsigned char>& info,
    size_t length)
{
    auto sha256 = EVP_sha256();
    auto hash_size = EVP_MD_size(sha256);

    std::vector<unsigned char> k2(hash_size);
    std::vector<unsigned char> k_out(hash_size);
    unsigned int len = hash_size;
    if (!HMAC(
            sha256,
            salt.data(),
            static_cast<int>(salt.size()),
            key.data(),
            key.size(),
            k2.data(),
            &len))
        throw std::runtime_error("error in hkdf first hmac sha256");
    if (len != k2.size()) k2.resize(len);

    len = hash_size;
    if (!HMAC(
            sha256,
            k2.data(),
            static_cast<int>(k2.size()),
            info.data(),
            info.size(),
            k_out.data(),
            &len))
        throw std::runtime_error("error in hkdf second hmac sha256");
    k_out.resize(length);
    return k_out;
}

std::string aes_gcm_128_with_padding(
    const std::string& data,
    const std::vector<unsigned char>& key,
    const std::vector<unsigned char>& iv,
    unsigned short padding)
{
    if (padding != 0) throw std::runtime_error("non-zero padding is not supported in this version");
    unsigned char padding_bytes[2] = { 0, 0 };
    evp_cipher_ctx_ptr ctx = make_ptr(EVP_CIPHER_CTX_new());
    if (!ctx) throw std::runtime_error("failed to create cipher context");
    if (!EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), nullptr, key.data(), iv.data()))
        throw std::runtime_error("failed to initialize encryption");
    auto block_size = EVP_CIPHER_CTX_block_size(ctx.get());
    const int taglen = 16; // Expected tag length is 16
    int outlen = static_cast<int>(data.size()) + 2 * block_size - 1 + taglen + padding + 2;
    std::string encrypted(outlen, '\0');
    auto encrypted_data = reinterpret_cast<unsigned char*>(&encrypted[0]);

    if (!EVP_EncryptUpdate(
            ctx.get(), encrypted_data, &outlen, padding_bytes, sizeof(padding_bytes)))
    {
        throw std::runtime_error("failed to encrypt data");
    }
    int total_len = outlen;
    if (!EVP_EncryptUpdate(
            ctx.get(),
            encrypted_data + total_len,
            &outlen,
            reinterpret_cast<const unsigned char*>(data.data()),
            static_cast<int>(data.size())))
    {
        throw std::runtime_error("failed to encrypt data");
    }
    total_len += outlen;
    outlen = static_cast<int>(encrypted.size()) - outlen;
    if (!EVP_EncryptFinal_ex(ctx.get(), encrypted_data + total_len, &outlen))
    {
        throw std::runtime_error("failed to finalize encryption");
    }
    total_len += outlen;
    if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, taglen, encrypted_data + total_len))
    {
        throw std::runtime_error("failed to write cipher tag");
    }
    total_len += taglen;
    encrypted.resize(total_len);
    return encrypted;
}

std::vector<unsigned char> sign(const void* data, size_t len, const evp_pkey_ptr& key)
{
    auto hash = EVP_sha256();

    evp_md_ctx_ptr md_ctx = make_ptr(EVP_MD_CTX_new());
    if (!md_ctx)
    {
        throw std::runtime_error("error creating EVP_MD_CTX");
    }

    if (!EVP_DigestInit(md_ctx.get(), hash)) throw std::runtime_error("failed to init hash");
    if (!EVP_DigestUpdate(md_ctx.get(), data, len))
        throw std::runtime_error("failed to compute hash");

    auto hash_size = EVP_MD_size(hash);
    std::vector<unsigned char> msg_digest(hash_size);
    if (!EVP_DigestFinal(md_ctx.get(), msg_digest.data(), nullptr))
        throw std::runtime_error("failed to retrieve hash");

    auto ec_key = evp_to_ec(key);
    // Expected size of signature depends on EC group order,
    // as r and s are modulo of order and can't be larger.
    const EC_GROUP* ecgrp = EC_KEY_get0_group(ec_key.get());
    auto order = get_order(ecgrp);
    auto component_size = BN_num_bytes(order.get());

    auto sig = make_ptr(
        ECDSA_do_sign(msg_digest.data(), static_cast<int>(msg_digest.size()), ec_key.get()));
    if (!sig) throw std::runtime_error("failed to sign");

    const BIGNUM* r;
    const BIGNUM* s;
    ECDSA_SIG_get0(sig.get(), &r, &s);
    std::vector<unsigned char> sign(2 * component_size);
    if (BN_bn2binpad(r, sign.data(), component_size) == -1)
        throw std::runtime_error("failed to write sign component r to buffer");
    if (BN_bn2binpad(s, sign.data() + component_size, component_size) == -1)
        throw std::runtime_error("failed to write sign component s to buffer");

    return sign;
}

bool verify_sign(
    const void* data,
    size_t len,
    const std::vector<unsigned char>& sig_raw,
    const evp_pkey_ptr& key)
{
    auto hash = EVP_sha256();
    evp_md_ctx_ptr md_ctx = make_ptr(EVP_MD_CTX_new());
    if (!md_ctx)
    {
        throw std::runtime_error("error creating EVP_MD_CTX");
    }

    if (!EVP_DigestInit(md_ctx.get(), hash)) throw std::runtime_error("failed to init hash");
    if (!EVP_DigestUpdate(md_ctx.get(), data, len))
        throw std::runtime_error("failed to compute hash");

    auto hash_size = EVP_MD_size(hash);
    std::vector<unsigned char> msg_digest(hash_size);
    if (!EVP_DigestFinal(md_ctx.get(), msg_digest.data(), nullptr))
        throw std::runtime_error("failed to retrieve hash");

    auto ec_key = ec_crypto::evp_to_ec(key);

    auto sig = ec_crypto::make_ptr(ECDSA_SIG_new());
    if (!sig) throw std::runtime_error("failed to create sign object");

    auto r_size = sig_raw.size() / 2;
    auto s_size = sig_raw.size() - r_size;
    ECDSA_SIG_set0(
        sig.get(),
        BN_bin2bn(sig_raw.data(), r_size, nullptr),
        BN_bin2bn(sig_raw.data() + r_size, s_size, nullptr));

    int res = ECDSA_do_verify(msg_digest.data(), msg_digest.size(), sig.get(), ec_key.get());
    if (res == 1) return true;
    if (res == -1) throw std::runtime_error("error in function ECDSA_do_verify");
    return false;
}

}}
