#include "sshkey.h"

#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <contrib/libs/openssl/include/openssl/bn.h>
#include <contrib/libs/openssl/include/openssl/err.h>
#include <contrib/libs/openssl/include/openssl/evp.h>
#include <contrib/libs/openssl/include/openssl/pem.h>

#include <library/cpp/openssl/holders/evp.h>

#include <array>
#include <iomanip>
#include <limits>
#include <sstream>
#include <vector>

namespace NPassport::NUtils {
    // Length-prefixed key type name as it appears on the wire: a 4-byte big-endian
    // length (7) followed by the seven characters "ssh-rsa".
    // GetSshAgentSign() requires the signature blob to start with exactly these bytes.
    static constexpr std::array<char, 11> AGENT_RSA_PREFIX{'\x00', '\x00', '\x00', '\x07', 's', 's', 'h', '-', 'r', 's', 'a'};

    // Builds the key from its textual form: Parse() decodes the "ssh-rsa" token and
    // base64 blob into the RSA numbers, and InitKey() turns them into the EVP key.
    // Any parsing or OpenSSL failure propagates as an exception from those helpers.
    TSshPublicKey::TSshPublicKey(const TStringBuf formatedKey) {
        TNums parsedNumbers = Parse(formatedKey);
        InitKey(std::move(parsedNumbers));
    }

    bool TSshPublicKey::Verify(const TStringBuf sign,
                               const TStringBuf rawString,
                               EMode mode,
                               TString& err) const {
        NOpenSSL::TEvpMdCtx md_ctx;
        if (EVP_DigestInit_ex(md_ctx, EVP_sha1(), nullptr) != 1) {
            err.assign("EVP_DigestInit_ex error: ").append(ERR_reason_error_string(ERR_get_error()));
            return false;
        }

        EVP_PKEY_CTX* pctx = nullptr;
        if (EVP_DigestVerifyInit(md_ctx, &pctx, EVP_sha1(), nullptr, Pkey.get()) != 1) {
            err.assign("EVP_DigestVerifyInit error: ").append(ERR_reason_error_string(ERR_get_error()));
            return false;
        }

        if (EMode::RSA_PSS == mode) {
            if (EVP_PKEY_CTX_ctrl_str(pctx, "rsa_padding_mode", "pss") != 1) {
                err.assign("EVP_PKEY_CTX_ctrl_str error: ").append(ERR_reason_error_string(ERR_get_error()));
                return false;
            }
        }

        if (EVP_DigestVerifyUpdate(md_ctx, rawString.data(), rawString.size()) != 1) {
            err.assign("EVP_DigestVerifyUpdate error: ").append(ERR_reason_error_string(ERR_get_error()));
            return false;
        }

        return EVP_DigestVerifyFinal(md_ctx, (unsigned char*)sign.data(), sign.size()) == 1;
    }

    // Extracts the raw RSA signature bytes from a signature blob laid out as
    //   AGENT_RSA_PREFIX (length 7 + "ssh-rsa") | 4-byte big-endian length L | L bytes.
    // Returns an empty string if the blob is shorter than 16 bytes (prefix, length field
    // and at least one signature byte), does not start with AGENT_RSA_PREFIX, or if L does
    // not exactly match the number of bytes that follow the length field.
    TString TSshPublicKey::GetSshAgentSign(const TStringBuf sign) {
        if (sign.size() < 16) {
            return {};
        }

        const TStringBuf expectedPrefix(AGENT_RSA_PREFIX.data(), AGENT_RSA_PREFIX.size());
        if (sign.substr(0, AGENT_RSA_PREFIX.size()) != expectedPrefix) {
            return {};
        }

        // Decode the big-endian length of the signature that follows the prefix.
        const size_t LEN_SIZE = 4;
        const size_t lenOffset = AGENT_RSA_PREFIX.size();
        ui32 len = 0;
        for (size_t idx = 0; idx < LEN_SIZE; ++idx) {
            // Go through unsigned char: plain char may be signed, and a byte >= 0x80
            // would otherwise sign-extend and set all the upper bits of `len`.
            len = (len << 8) | static_cast<ui32>(static_cast<unsigned char>(sign[lenOffset + idx]));
        }

        if (lenOffset + LEN_SIZE + len != sign.size()) {
            return {};
        }

        // Strip the prefix and the length field; what remains is the signature itself.
        return TString(sign.substr(lenOffset + LEN_SIZE));
    }

    // Key type token accepted by Parse(): only RSA public keys ("ssh-rsa") are supported.
    static const TString PREFIX = "ssh-rsa";
    // Parses a textual RSA public key ("ssh-rsa <base64-blob> ...", newlines ignored)
    // into its exponent E and modulus N.
    //
    // The decoded blob is a sequence of length-prefixed fields, each a 4-byte big-endian
    // length followed by that many bytes: the key type name ("ssh-rsa"), the exponent,
    // then the modulus. Every length is checked against the bytes that remain, so a
    // truncated or hostile blob is rejected instead of being read past its end.
    //
    // Throws TUnsupportedException if the first token is not "ssh-rsa", and
    // TMalformedException if the input does not follow the layout above.
    TSshPublicKey::TNums TSshPublicKey::Parse(const TStringBuf formKey) {
        TString formatedKey = ReplaceAny(formKey, "\n", "");
        std::vector<TStringBuf> tokens = ToVector<TStringBuf>(formatedKey, ' ', 3);

        Y_ENSURE_EX(tokens.size() >= 2,
                    TMalformedException() << "SshPublicKey ill-formed: tokens: " << formatedKey);
        Y_ENSURE_EX(tokens[0] == TStringBuf(PREFIX),
                    TUnsupportedException());

        const TString decoded = Base64ToBin(tokens[1]);
        Y_ENSURE_EX(!decoded.empty(),
                    TMalformedException() << "SshPublicKey: encoded part is invalid base64");

        // Offset of the next unread byte; invariant: cursor <= decoded.size(),
        // so `decoded.size() - cursor` below never wraps around.
        size_t cursor = 0;

        // Reads the 4-byte length of the next field, moves `cursor` past it and checks
        // that the field body fits entirely into what is left of `decoded`.
        const auto readFieldLength = [&](const char* field) -> ui32 {
            Y_ENSURE_EX(decoded.size() - cursor >= sizeof(ui32),
                        TMalformedException() << "SshPublicKey ill-formed: no length for " << field);
            const ui32 len = ReadUint32(decoded.data() + cursor);
            cursor += sizeof(ui32);
            Y_ENSURE_EX(decoded.size() - cursor >= len,
                        TMalformedException() << "SshPublicKey ill-formed: truncated " << field);
            return len;
        };

        const ui32 signature_len = readFieldLength("key type");
        Y_ENSURE_EX(signature_len == PREFIX.length(),
                    TMalformedException()
                        << "SshPublicKey ill-formed: encoded prefix len doesn't match to open prefix: "
                        << formatedKey);

        const TString signature(decoded.data() + cursor, signature_len);
        Y_ENSURE_EX(signature.compare(0, PREFIX.length(), PREFIX) == 0,
                    TMalformedException()
                        << "SshPublicKey ill-formed: encoded prefix doesn't match to open prefix");
        cursor += signature_len;

        TNums nums;

        const ui32 exponent_len = readFieldLength("exponent");
        nums.E = ReadBn(decoded.data() + cursor, exponent_len);
        cursor += exponent_len;

        const ui32 mod_len = readFieldLength("modulus");
        nums.N = ReadBn(decoded.data() + cursor, mod_len);

        return nums;
    }

    // Decodes the 4 bytes at `data` as a big-endian unsigned 32-bit integer.
    // The caller must guarantee that at least 4 bytes are readable at `data`.
    ui32 TSshPublicKey::ReadUint32(const char* data) {
        const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
        // Widen each byte to ui32 before shifting: after integral promotion a byte is a
        // signed int, and `byte << 24` overflows it (undefined behaviour) for bytes >= 0x80.
        return (static_cast<ui32>(bytes[0]) << 24) |
               (static_cast<ui32>(bytes[1]) << 16) |
               (static_cast<ui32>(bytes[2]) << 8) |
               static_cast<ui32>(bytes[3]);
    }

    // Converts the `len` bytes at `data` into a BIGNUM, reading them as an unsigned
    // big-endian magnitude (no sign handling). The result is owned by the returned TNum.
    // Throws (via Y_ENSURE) if the field is empty, too long for OpenSSL's int-sized length,
    // or if the BIGNUM cannot be allocated.
    TSshPublicKey::TNums::TNum TSshPublicKey::ReadBn(const char* data, unsigned long len) {
        // An empty field is rejected rather than decoded as the value zero.
        // BN_bin2bn() takes its length as int, so longer inputs cannot be represented.
        Y_ENSURE(len > 0 && len <= static_cast<unsigned long>(std::numeric_limits<int>::max()),
                 "SshPublicKey ill-formed: value is not BN");

        BIGNUM* tmp = BN_bin2bn(reinterpret_cast<const unsigned char*>(data),
                                static_cast<int>(len),
                                nullptr);
        Y_ENSURE(tmp, "SshPublicKey ill-formed: value is not BN");

        return TNums::TNum(tmp, BN_free);
    }

    // Builds an RSA public key from the modulus N and exponent E in `nums` and stores it
    // in Pkey. Ownership of the BIGNUMs passes to the RSA object, and ownership of the
    // RSA object passes to Pkey. Throws (via Y_ENSURE) if any OpenSSL step fails.
    void TSshPublicKey::InitKey(TSshPublicKey::TNums&& nums) {
        std::unique_ptr<RSA, decltype(&RSA_free)> rsa(RSA_new(), &RSA_free);
        Y_ENSURE(rsa, "SshPublicKey: RSA_new error");

        // RSA_set0_key() takes ownership of N and E only when it succeeds, so they are
        // released from `nums` after the call; if it fails, `nums` still frees them.
        Y_ENSURE(RSA_set0_key(rsa.get(), nums.N.get(), nums.E.get(), nullptr) == 1,
                 "SshPublicKey: RSA_set0_key error");
        nums.N.release();
        nums.E.release();

        Pkey.reset(EVP_PKEY_new());
        Y_ENSURE(Pkey, "SshPublicKey: EVP_PKEY_new error");

        const bool assigned = EVP_PKEY_assign_RSA(Pkey.get(), rsa.get()) == 1;
        // ERR_reason_error_string() may return NULL, so fall back to a placeholder text.
        const char* reason = assigned ? nullptr : ERR_reason_error_string(ERR_get_error());
        Y_ENSURE(assigned,
                 "SshPublicKey: EVP_PKEY_assign_RSA error: "
                     << (reason ? reason : "unknown error"));
        rsa.release(); // Pkey now owns the RSA struct
    }
}
