#include "rsa.h"

#include <contrib/libs/openssl/include/openssl/aes.h>
#include <contrib/libs/openssl/include/openssl/hmac.h>
#include <contrib/libs/openssl/include/openssl/rand.h>

#include <library/cpp/openssl/holders/evp.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <util/generic/buffer.h>
#include <util/stream/buffer.h>

#include <array>
#include <stdexcept>

// Namespace-scope definition of the static TAESCipher::Registrator member, constructed with
// ECipherType::AES during static initialization of this translation unit.
// NOTE(review): judging by the TRegistrator type this registers TAESCipher with the
// IAbstractCipher factory under the AES key; the registration logic itself lives in TFactory
// (not visible here) — confirm there.
NOpenssl::IAbstractCipher::TFactory::TRegistrator<NOpenssl::TAESCipher> NOpenssl::TAESCipher::Registrator(NOpenssl::ECipherType::AES);

/// Computes the SHA-256 digest of `data`.
/// @return the raw 32-byte binary digest (not hex/base64 encoded).
TString NOpenssl::CalcSHA256(const TString& data) {
    std::array<unsigned char, SHA256_DIGEST_LENGTH> digestBytes;
    SHA256(reinterpret_cast<const unsigned char*>(data.data()), data.size(), digestBytes.data());
    return TString(reinterpret_cast<const char*>(digestBytes.data()), digestBytes.size());
}

/// Computes the SHA-1 digest of `data`.
/// Uses OpenSSL's one-shot SHA1(), like CalcSHA256 uses SHA256(), instead of the low-level
/// SHA1_Init/SHA1_Update/SHA1_Final sequence. Those calls are deprecated since OpenSSL 3.0,
/// and the previous code ignored their return values.
/// @return the raw 20-byte binary digest (not hex/base64 encoded).
TString NOpenssl::CalcSHA1(const TString& data) {
    std::array<unsigned char, SHA_DIGEST_LENGTH> digest;
    SHA1(reinterpret_cast<const unsigned char*>(data.data()), data.size(), digest.data());
    return TString(reinterpret_cast<const char*>(digest.data()), digest.size());
}

/// Computes HMAC-SHA1 of `sign` keyed with `secret` and returns it base64-encoded.
/// @param secret HMAC key bytes.
/// @param sign   message to authenticate.
/// @return base64 encoding of the raw HMAC-SHA1 output.
/// @throws std::runtime_error if OpenSSL fails to compute the HMAC.
TString NOpenssl::CalcHMACToken(const TString& secret, const TString& sign) {
    unsigned char hmacBuffer[EVP_MAX_MD_SIZE];
    // Initialized so the length is never read indeterminate, even if OpenSSL leaves it untouched.
    unsigned int hmacLength = 0;
    const unsigned char* digest = HMAC(EVP_sha1(),
                                       secret.data(), static_cast<int>(secret.size()),
                                       reinterpret_cast<const unsigned char*>(sign.data()), sign.size(),
                                       hmacBuffer, &hmacLength);
    // HMAC() returns NULL on error; previously that path read an uninitialized length.
    if (digest == nullptr) {
        throw std::runtime_error("HMAC-SHA1 computation failed");
    }
    return Base64Encode(TStringBuf(reinterpret_cast<const char*>(hmacBuffer), hmacLength));
}

/// Generates a fresh 2048-bit RSA key pair with public exponent 7.
/// @return an owning holder of the generated key.
/// @throws std::runtime_error if any OpenSSL allocation or the key generation fails.
///         Previously a failed generation silently returned a keyless RSA object.
/// NOTE(review): e = 7 is unusual; 65537 (RSA_F4) is the conventional choice. It is kept
/// unchanged here because changing key parameters needs an owner decision.
NOpenssl::TRSAPtr NOpenssl::GenerateRSA() {
    NOpenssl::TRSAPtr rsa(RSA_new());
    BIGNUM* exponent = BN_new();

    // Short-circuit so that no step runs on a null object left by an earlier failure.
    const bool generated = rsa.Get() != nullptr
        && exponent != nullptr
        && BN_set_word(exponent, 7) == 1
        && RSA_generate_key_ex(rsa.Get(), 2048, exponent, nullptr) == 1;

    // RSA_generate_key_ex does not take ownership of the exponent; BN_free(nullptr) is a no-op.
    BN_free(exponent);

    if (!generated) {
        throw std::runtime_error("RSA key generation failed");
    }
    return rsa;
}

/// Generates `len` bytes of key material from OpenSSL's CSPRNG.
/// @param len number of random bytes to produce; 0 yields an empty string.
/// @return the raw random bytes.
/// @throws std::runtime_error if RAND_bytes fails. Previously the failure was ignored and an
///         all-zero key was returned, which is a silent security failure.
TString NOpenssl::GenerateAESKey(ui32 len) {
    if (len == 0) {
        return TString();
    }
    TVector<ui8> bytes(len);
    // RAND_bytes returns 1 only on success; 0 or -1 mean no usable randomness was produced.
    if (RAND_bytes(bytes.data(), static_cast<int>(bytes.size())) != 1) {
        throw std::runtime_error("RAND_bytes failed to generate AES key material");
    }
    return TString(reinterpret_cast<const char*>(bytes.data()), bytes.size());
}

namespace {
    template <class T>
    bool AESImpl(const TString& key, const T& input, TString& output, bool decrypt) {
        TString iv(16, '\0');
        NOpenSSL::TEvpCipherCtx context;
        if (EVP_CipherInit_ex(context, EVP_aes_256_cbc(), nullptr, (unsigned char*)key.data(), (unsigned char*)iv.data(), decrypt ? 0 : 1) != 1) {
            return false;
        }

        ui32 itersCount = (input.Size() / AES_BLOCK_SIZE);
        if (input.size() % AES_BLOCK_SIZE > 0) {
            ++itersCount;
        }
        TBufferOutput result;
        for (ui32 i = 0; i < itersCount; ++i) {
            ui32 shift = i * AES_BLOCK_SIZE;
            int bytes = 0;

            auto in = reinterpret_cast<const unsigned char*>(input.data()) + shift;
            auto inl = std::min<int>(AES_BLOCK_SIZE, input.Size() - shift);

            std::array<unsigned char, 2 * AES_BLOCK_SIZE> buffer;
            buffer.fill(0);

            Y_ASSERT(inl + AES_BLOCK_SIZE - 1 < static_cast<int>(buffer.size()));
            if (EVP_CipherUpdate(context, buffer.data(), &bytes, in, inl) != 1) {
                return false;
            }
            Y_ASSERT(bytes <= static_cast<int>(buffer.size()));
            result.Write(buffer.data(), bytes);
        }
        output = TString(result.Buffer().data(), result.Buffer().size());
        output.append(AES_BLOCK_SIZE, '\0');
        int finalBytes = 0;
        if (EVP_CipherFinal_ex(context, (unsigned char*)output.data() + result.Buffer().size() , &finalBytes) != 1) {
            return false;
        }
        output.resize(result.Buffer().size() + finalBytes);
        return true;
    }
}

bool NOpenssl::AESDecrypt(const TString& key, const TString& input, TString& output) {
    return AESImpl(key, input, output, true);
}

bool NOpenssl::AESEncrypt(const TString& key, const TString& input, TString& output) {
    return AESImpl(key, input, output, false);
}

bool NOpenssl::AESEncrypt(const TString& key, const TBuffer& input, TString& output) {
    return AESImpl(key, input, output, false);
}
