#include "key.h"

#include <util/generic/yexception.h>

namespace {
    constexpr std::string_view BeginPublicKey = "-----BEGIN PUBLIC KEY-----";
    constexpr std::string_view EndPublicKey = "-----END PUBLIC KEY-----";
    // Characters stripped from both ends of the input before it is inspected.
    constexpr std::string_view PemWhitespace = " \t\r\n";

    // True if `text` begins with `prefix`.
    bool HasPrefix(std::string_view text, std::string_view prefix) {
        return text.substr(0, prefix.size()) == prefix;
    }

    // True if `text` ends with `suffix`.
    bool HasSuffix(std::string_view text, std::string_view suffix) {
        return text.size() >= suffix.size() && text.substr(text.size() - suffix.size()) == suffix;
    }

    // Returns `text` without leading/trailing spaces, tabs, CRs and LFs.
    std::string_view TrimPemWhitespace(std::string_view text) {
        const auto first = text.find_first_not_of(PemWhitespace);
        if (first == std::string_view::npos) {
            return {};
        }
        const auto last = text.find_last_not_of(PemWhitespace);
        return text.substr(first, last - first + 1);
    }

    /// Normalizes a public key into PEM form.
    ///
    /// Accepts either a bare base64 body or a (possibly partial) PEM document and
    /// adds whichever of the "-----BEGIN PUBLIC KEY-----" / "-----END PUBLIC KEY-----"
    /// marker lines are missing, each on its own line.
    ///
    /// Surrounding whitespace is trimmed first. Without trimming, a complete PEM with a
    /// trailing newline (the usual on-disk form) would not end with the END marker and
    /// would receive a second one. A leading newline before BEGIN would likewise cause
    /// BEGIN to be duplicated.
    ///
    /// @param pem  key material, with or without PEM armor.
    /// @return     the key wrapped in BEGIN/END PUBLIC KEY markers, with no trailing newline.
    std::string PreparePEM(const std::string& pem) {
        const std::string_view body = TrimPemWhitespace(pem);

        std::string result;
        result.reserve(BeginPublicKey.size() + body.size() + EndPublicKey.size() + 2);

        if (!HasPrefix(body, BeginPublicKey)) {
            result.append(BeginPublicKey);
            result.push_back('\n');
        }
        result.append(body);

        if (!HasSuffix(body, EndPublicKey)) {
            // Only an empty body can leave the result ending in '\n' here; avoid a blank line.
            if (result.empty() || result.back() != '\n') {
                result.push_back('\n');
            }
            result.append(EndPublicKey);
        }
        return result;
    }
}

/// Builds a JWK from its textual description.
///
/// An empty algorithm name leaves `Algorithm` in its default-constructed state.
/// "RS256" installs an RS256 verifier over the description's PEM (normalized by
/// PreparePEM). Any other algorithm name is rejected with a yexception.
NDrive::TJwk::TJwk(const TJwkDescription& description)
    : KeyId(description.KeyId)
{
    const auto& algorithmName = description.Algorithm;
    if (algorithmName.empty()) {
        return;
    }
    if (algorithmName != "RS256") {
        throw yexception() << "unsupported jwk algorithm: " << algorithmName;
    }
    Algorithm.emplace<jwt::algorithm::rs256>(PreparePEM(description.PEM));
}

/// Verifies `signature` over the JWS signing input `data` using whichever
/// algorithm alternative `key.Algorithm` currently holds.
///
/// File-local helper: `static` gives it internal linkage. Otherwise this global
/// `ValidateToken` symbol could clash at link time with an identically-named
/// function in another translation unit.
/// NOTE(review): failure is presumably reported by the jwt-cpp two-argument
/// `verify` overload throwing; confirm against the vendored jwt-cpp version.
static void ValidateToken(const std::string& data, const std::string& signature, const NDrive::TJwk& key) {
    std::visit([&](const auto& alg) {
        alg.verify(data, signature);
    }, key.Algorithm);
}

/// Validates `token` against a single key.
///
/// Throws (via Y_ENSURE) if the token's key id differs from `key.KeyId`.
/// Otherwise verifies the signature over "<header_b64>.<payload_b64>".
void NDrive::ValidateToken(const jwt::decoded_jwt& token, const TJwk& key) {
    const auto tokenKeyId = token.get_key_id();
    Y_ENSURE(key.KeyId == tokenKeyId, "unsupported key_id: " << tokenKeyId);

    const auto signingInput = token.get_header_base64() + "." + token.get_payload_base64();
    ::ValidateToken(signingInput, token.get_signature(), key);
}

/// Validates `token` against the first key in `keys` whose KeyId matches the
/// token's key id.
///
/// Throws yexception if `keys` is empty or if no key matches the token's key id.
/// Signature verification itself is delegated to the file-local helper.
void NDrive::ValidateToken(const jwt::decoded_jwt& token, const TJwks& keys) {
    if (keys.empty()) {
        throw yexception() << "no keys provided";
    }

    // The key id is loop-invariant: fetch it once rather than on every iteration.
    const auto tokenKeyId = token.get_key_id();
    for (const auto& key : keys) {
        if (key.KeyId == tokenKeyId) {
            // Build the signing input only once a candidate key has been found.
            const auto data = token.get_header_base64() + "." + token.get_payload_base64();
            ::ValidateToken(data, token.get_signature(), key);
            return;
        }
    }
    throw yexception() << "no key found for key_id " << tokenKeyId;
}
