#include "payload_encryption.h"

namespace yxiva::mobile {

const unsigned SALT_SIZE{ 16 }; // Salt length in bytes, as required by the payload encryption standard

// Builds the HKDF "info" byte string described by the payload encryption standard:
//
//   "Content-Encoding: " || what || 0x00 ||
//   "P-256" || 0x00 ||
//   uint16_be(len(client_key)) || client_key ||
//   uint16_be(len(server_key)) || server_key ||
//   0x01
//
// The trailing 0x01 is not part of the info defined by the standard; it is appended
// because the hkdf step needs it (presumably the HKDF-Expand block counter — confirm
// against ec_crypto::hkdf).
//
// @param what        NUL-terminated content-encoding label, e.g. "aesgcm" or "nonce".
// @param client_key  Client public key bytes.
// @param server_key  Server public key bytes.
// @return            The assembled info bytes.
// @throws std::runtime_error if either key is longer than 65535 bytes, because its
//         length would not fit the 16-bit length prefix.
std::vector<unsigned char> create_hkdf_info(
    const char what[],
    const std::vector<unsigned char>& client_key,
    const std::vector<unsigned char>& server_key)
{
    static constexpr char ce[] = "Content-Encoding: ";
    static constexpr char p[] = "P-256";
    static constexpr std::size_t cel = sizeof(ce) - 1;
    static constexpr std::size_t pl = sizeof(p) - 1;
    // Largest length representable by the 2-byte big-endian length prefix.
    static constexpr std::size_t max_key_size = 0xFFFF;

    // Previously the sizes were narrowed to unsigned short, which silently wrote a
    // truncated length prefix that no longer matched the appended key bytes.
    if (client_key.size() > max_key_size || server_key.size() > max_key_size)
        throw std::runtime_error("hkdf info: key too long for 16-bit length prefix");

    const auto wl = std::char_traits<char>::length(what);

    std::vector<unsigned char> info;
    // +7: two NUL separators, two 2-byte length prefixes and the trailing 0x01.
    info.reserve(cel + wl + pl + client_key.size() + server_key.size() + 7);

    // Appends a key preceded by its length as a 16-bit big-endian integer.
    const auto append_key = [&info](const std::vector<unsigned char>& key) {
        const auto sz = key.size();
        info.push_back(static_cast<unsigned char>((sz >> 8) & 0xFF));
        info.push_back(static_cast<unsigned char>(sz & 0xFF));
        info.insert(info.end(), key.cbegin(), key.cend());
    };

    info.insert(info.end(), ce, ce + cel);
    info.insert(info.end(), what, what + wl);
    info.push_back(0);
    info.insert(info.end(), p, p + pl);
    info.push_back(0);
    append_key(client_key);
    append_key(server_key);
    // Not specified for the info itself, but needed by the hkdf step.
    info.push_back(1);

    return info;
}

// Decodes URL-safe base64 text into raw bytes.
// The decoding itself is delegated to yplatform::base64_urlsafe_decode; this wrapper only
// copies its result into a byte vector, the form taken by the ec_crypto helpers below.
inline std::vector<unsigned char> base64_urlsafe_decode(const string_view& b64)
{
    auto decoded = yplatform::base64_urlsafe_decode(b64.begin(), b64.end());
    std::vector<unsigned char> bytes;
    bytes.assign(decoded.begin(), decoded.end());
    return bytes;
}

// Encrypts `payload` for a push subscriber following the payload encryption standard
// ("aesgcm" content encoding).
//
// @param payload             Plaintext message body.
// @param keys                Server key material (EC group, key pair, public key bytes).
// @param client_auth_secret  Client auth secret, URL-safe base64.
// @param client_pub_key      Client public key, URL-safe base64.
// @return                    The ciphertext together with the URL-safe base64 salt used.
// @throws std::runtime_error if either decoded client value is empty. Any exception from
//         the ec_crypto / yplatform helpers propagates to the caller unchanged.
encrypted_payload encrypt_payload(
    const string& payload,
    const server_keys& keys,
    const string_view& client_auth_secret,
    const string_view& client_pub_key)
{
    // "Content-Encoding: auth" || 0x00 || 0x01; the array's implicit terminating NUL is dropped.
    static const char auth_label[] = "Content-Encoding: auth\0\1";
    static const std::vector<unsigned char> auth_info(
        std::begin(auth_label), std::end(auth_label) - 1);

    // A new SALT_SIZE-byte salt is produced on every call.
    std::vector<unsigned char> salt = ec_crypto::rand_vector(SALT_SIZE);

    auto receiver_key = base64_urlsafe_decode(client_pub_key);
    auto receiver_auth = base64_urlsafe_decode(client_auth_secret);
    const bool missing_input = receiver_key.empty() || receiver_auth.empty();
    if (missing_input)
    {
        throw std::runtime_error("invalid client key/auth");
    }

    // Shared secret from the server key pair and the client's public key.
    auto peer = ec_crypto::evp_from_public_key(receiver_key, keys.ecgrp.get());
    auto shared_secret = ec_crypto::derive_secret(keys.keypair, peer);

    // First HKDF stage mixes in the client auth secret; its 32-byte output keys the
    // content-key and nonce derivations below.
    auto prk = ec_crypto::hkdf(receiver_auth, shared_secret, auth_info, 32);

    // Builds the info for `label` and derives `length` bytes from the salt and prk.
    auto expand = [&](const char* label, auto length) {
        auto info = create_hkdf_info(label, receiver_key, keys.public_key);
        return ec_crypto::hkdf(salt, prk, info, length);
    };
    auto content_key = expand("aesgcm", 16);
    auto nonce = expand("nonce", 12);

    auto ciphertext = ec_crypto::aes_gcm_128_with_padding(payload, content_key, nonce);
    auto encoded_salt = yplatform::base64_urlsafe_encode(salt.begin(), salt.end());

    return encrypted_payload{ std::move(ciphertext),
                              string(encoded_salt.begin(), encoded_salt.end()) };
}

}
