#include "hash.h"

#include <contrib/libs/openssl/include/openssl/evp.h>
#include <contrib/libs/openssl/include/openssl/hmac.h>
#include <contrib/libs/openssl/include/openssl/md5.h>
#include <contrib/libs/openssl/include/openssl/rand.h>
#include <contrib/libs/openssl/include/openssl/sha.h>

#include <library/cpp/openssl/holders/evp.h>

#include <limits>
#include <memory>

namespace NPassport::NUtils {
    // Returns `size` cryptographically strong random bytes from OpenSSL's RAND_bytes().
    //
    // Returns an empty string if the generator fails, or if `size` does not fit
    // into the `int` count that RAND_bytes() accepts. Callers detect failure via
    // `empty()`, so a request for 0 bytes looks the same as a failure.
    TString TCrypto::RandBytes(size_t size) {
        // RAND_bytes() takes an int count; refuse sizes that would be truncated
        // into a shorter (or negative) request.
        if (size > static_cast<size_t>(std::numeric_limits<int>::max())) {
            return TString();
        }

        TString buf(size, 0);

        // RAND_bytes() returns 1 on success, 0 on failure and -1 when the current
        // RAND method does not support the call. Only 1 means the buffer was filled.
        // `!RAND_bytes(...)` would accept -1 and return the zero-initialized buffer
        // as "random" data.
        if (RAND_bytes((unsigned char*)buf.data(), static_cast<int>(size)) != 1) {
            buf.clear(); // no random data
        }
        return buf;
    }

    // One-shot HMAC-SHA256 of `data` keyed by `key`, computed with OpenSSL's HMAC().
    // Yields the raw binary MAC, or an empty string when HMAC() reports failure.
    TString TCrypto::HmacSha256(const TStringBuf key, const TStringBuf data) {
        unsigned int written = 0;
        TString mac(EVP_MAX_MD_SIZE, 0);

        const auto* const message = reinterpret_cast<const unsigned char*>(data.data());
        auto* const sink = (unsigned char*)mac.data();

        if (::HMAC(EVP_sha256(), key.data(), key.size(), message, data.size(), sink, &written) == nullptr) {
            return TString();
        }

        // The scratch buffer is sized for the largest digest OpenSSL supports;
        // keep only the bytes HMAC() actually produced.
        if (written != EVP_MAX_MD_SIZE) {
            mac.resize(written);
        }
        return mac;
    }

    // One-shot HMAC-SHA1 of `data` keyed by `key`, computed with OpenSSL's HMAC().
    // Yields the raw binary MAC, or an empty string when HMAC() reports failure.
    TString TCrypto::HmacSha1(const TStringBuf key, const TStringBuf data) {
        TString mac(EVP_MAX_MD_SIZE, 0);
        unsigned int written = 0;

        const bool computed = ::HMAC(EVP_sha1(),
                                     key.data(),
                                     key.size(),
                                     reinterpret_cast<const unsigned char*>(data.data()),
                                     data.size(),
                                     (unsigned char*)mac.data(),
                                     &written) != nullptr;
        if (!computed) {
            return TString();
        }

        mac.resize(written); // trim the EVP_MAX_MD_SIZE scratch buffer to the real MAC length
        return mac;
    }

    // Raw (binary, not hex-encoded) SHA-256 digest of `str`; SHA256_DIGEST_LENGTH bytes long.
    TString TCrypto::Sha256(const TStringBuf str) {
        TString digest(SHA256_DIGEST_LENGTH, 0);
        const auto* const bytes = reinterpret_cast<const unsigned char*>(str.data());
        SHA256(bytes, str.size(), (unsigned char*)digest.data());
        return digest;
    }

    // Raw (binary, not hex-encoded) MD5 digest of `str`; MD5_DIGEST_LENGTH bytes long.
    // MD5 is not collision resistant; keep it to non-security uses such as checksums.
    TString TCrypto::Md5(const TStringBuf str) {
        TString digest(MD5_DIGEST_LENGTH, 0);
        const auto* const bytes = reinterpret_cast<const unsigned char*>(str.data());
        MD5(bytes, str.size(), (unsigned char*)digest.data());
        return digest;
    }

    // Size floors, in bytes, enforced by the AES-256-GCM helpers in this file.
    static constexpr unsigned MinKeySize = 32; // AES-256 requires a 256-bit key
    static constexpr unsigned MinIVSize = 12;  // 96-bit IV, the GCM default length
    static constexpr unsigned MinTagSize = 16; // 128-bit tag, the full GCM tag length

    bool TCrypto::EncryptGcm(const TStringBuf key, const TStringBuf input, TCrypto::TCiphertext& output, TString* err_msg) {
        return EncryptGcm(key, input, output, TStringBuf(), err_msg);
    }

    bool TCrypto::EncryptGcm(const TStringBuf key, const TStringBuf input, TCrypto::TCiphertext& output, const TStringBuf aa_data, TString* err_msg) {
        output.Iv.clear();
        output.Text.clear();
        output.Tag.clear();

        // check parameters
        if (key.size() < MinKeySize) {
            if (err_msg) {
                err_msg->assign("too short key given, 256 bit needed");
            }
            return false;
        }

        // generate random IV, default length is 12 bytes
        TString iv = RandBytes(MinIVSize);

        if (iv.empty()) {
            if (err_msg) {
                err_msg->assign("unable to generate RAND_bytes()");
            }
            return false;
        }

        NOpenSSL::TEvpCipherCtx ctx;

        // init encrypt context with given key and iv
        if (!EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, (unsigned char*)key.data(), (unsigned char*)iv.data())) {
            if (err_msg) {
                err_msg->assign("unable to perform EVP_EncryptInit_ex()");
            }
            return false;
        }

        // openssl doc says we need to have at least one more block size available for encryption
        TString ciphertext(input.size() + EVP_MAX_BLOCK_LENGTH, 0);
        int plainLen = 0;

        if (aa_data) {
            // add AAD data (does not affect output)
            if (!EVP_EncryptUpdate(ctx, nullptr, &plainLen, (unsigned char*)aa_data.data(), aa_data.size())) {
                if (err_msg) {
                    err_msg->assign("unable to perform EVP_EncryptUpdate(AAD)");
                }
                return false;
            }
        }

        // actually encrypt data
        if (!EVP_EncryptUpdate(ctx, (unsigned char*)ciphertext.data(), &plainLen, (unsigned char*)input.data(), input.size())) {
            if (err_msg) {
                err_msg->assign("unable to perform EVP_EncryptUpdate()");
            }
            return false;
        }

        int tailLen = 0;
        // finalize the encryption
        if (!EVP_EncryptFinal_ex(ctx, (unsigned char*)ciphertext.data() + plainLen, &tailLen)) {
            if (err_msg) {
                err_msg->assign("unable to perform EVP_EncryptFinal_ex()");
            }
            return false;
        }

        TString tag(MinTagSize, 0);

        // get the tag
        if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, tag.size(), (void*)tag.data())) {
            if (err_msg) {
                err_msg->assign("unable to perform EVP_CIPHER_CTX_ctrl(gcm_get_tag)");
            }
            return false;
        }

        ciphertext.resize(plainLen + tailLen);
        output.Iv = std::move(iv);
        output.Text = std::move(ciphertext);
        output.Tag = std::move(tag);

        if (err_msg) {
            err_msg->assign("OK");
        }

        return true;
    }

    // Convenience overload: AES-256-GCM decryption of a ciphertext produced without
    // additional authenticated data.
    bool TCrypto::DecryptGcm(const TStringBuf key, const TCrypto::TCiphertext& input, TString& output, TString* err_msg) {
        const TStringBuf withoutAad{};
        return DecryptGcm(key, input, output, withoutAad, err_msg);
    }

    // Decrypts and authenticates an AES-256-GCM `input` (IV, ciphertext, tag) under
    // `key`. `aa_data` must be the same additional authenticated data used at encryption.
    //
    // The IV may be longer than the 12-byte default; its length is passed to OpenSSL.
    // Tags shorter than MinTagSize (16) bytes are rejected here, and OpenSSL rejects
    // GCM tags longer than 16 bytes.
    //
    // On success:
    //   - `output` holds the plaintext;
    //   - `*err_msg` (if given) is "OK".
    // On any failure, including a tag mismatch:
    //   - `output` stays empty;
    //   - `*err_msg` (if given) names the failing step;
    //   - false is returned.
    bool TCrypto::DecryptGcm(const TStringBuf key, const TCrypto::TCiphertext& input, TString& output, const TStringBuf aa_data, TString* err_msg) {
        // Records `msg` in the optional error out-parameter and reports failure.
        const auto fail = [err_msg](const char* msg) {
            if (err_msg) {
                err_msg->assign(msg);
            }
            return false;
        };
        // The EVP API takes every length as `int`. Anything larger would be
        // silently truncated by the implicit size_t -> int conversion.
        const auto fitsInt = [](size_t len) {
            return len <= static_cast<size_t>(std::numeric_limits<int>::max());
        };

        output.clear();

        // check parameters
        if (key.size() < MinKeySize) {
            return fail("too short key given, 256 bit needed");
        }
        if (input.Iv.size() < MinIVSize) {
            return fail("too short IV given, at least 96 bit needed");
        }
        if (input.Tag.size() < MinTagSize) {
            return fail("too short tag given, at least 128 bit needed");
        }
        if (!fitsInt(input.Iv.size()) || !fitsInt(input.Tag.size()) ||
            !fitsInt(input.Text.size()) || !fitsInt(aa_data.size())) {
            return fail("too long input given");
        }

        // only allocate the cipher context once the parameters are known to be usable
        NOpenSSL::TEvpCipherCtx ctx;

        // Select the cipher first, so the IV length can be set before key/IV are loaded.
        if (!EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr)) {
            return fail("unable to perform EVP_DecryptInit_ex()");
        }

        // try to set custom IV length if it happens to be non-standard
        if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(input.Iv.size()), nullptr)) {
            return fail("unable to perform EVP_CIPHER_CTX_ctrl(set_ivlen)");
        }

        // init decrypt context with given key and iv
        if (!EVP_DecryptInit_ex(ctx,
                                nullptr,
                                nullptr,
                                reinterpret_cast<const unsigned char*>(key.data()),
                                reinterpret_cast<const unsigned char*>(input.Iv.data()))) {
            return fail("unable to perform EVP_DecryptInit_ex()");
        }

        // openssl doc says we need to have at least one more block size available for decryption
        TString plaintext(input.Text.size() + EVP_MAX_BLOCK_LENGTH, 0);

        if (aa_data) {
            // AAD is only authenticated, so it is fed with a null output buffer
            int aadLen = 0;
            if (!EVP_DecryptUpdate(ctx,
                                   nullptr,
                                   &aadLen,
                                   reinterpret_cast<const unsigned char*>(aa_data.data()),
                                   static_cast<int>(aa_data.size()))) {
                return fail("unable to perform EVP_DecryptUpdate(AAD)");
            }
        }

        // actually decrypt data
        int plainLen = 0;
        if (!EVP_DecryptUpdate(ctx,
                               (unsigned char*)plaintext.data(),
                               &plainLen,
                               reinterpret_cast<const unsigned char*>(input.Text.data()),
                               static_cast<int>(input.Text.size()))) {
            return fail("unable to perform EVP_DecryptUpdate()");
        }

        // hand the expected tag to OpenSSL; it is checked in EVP_DecryptFinal_ex()
        if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, static_cast<int>(input.Tag.size()), (void*)input.Tag.data())) {
            return fail("unable to perform EVP_CIPHER_CTX_ctrl(gcm_set_tag)");
        }

        // Finalize the decryption. For GCM this verifies the tag, so a failure here
        // means the key, IV, AAD, ciphertext or tag does not match.
        int tailLen = 0;
        if (!EVP_DecryptFinal_ex(ctx, (unsigned char*)plaintext.data() + plainLen, &tailLen)) {
            return fail("unable to perform EVP_DecryptFinal_ex()");
        }

        plaintext.resize(plainLen + tailLen);
        output = std::move(plaintext);

        if (err_msg) {
            err_msg->assign("OK");
        }

        return true;
    }

    bool TCrypto::DecryptCbc(const TStringBuf key,
                             const TStringBuf iv,
                             const TStringBuf text,
                             TString& output,
                             TString* errMsg) {
        output.clear();
        NOpenSSL::TEvpCipherCtx ctx;

        // check parameters
        if (key.size() != 32) {
            if (errMsg) {
                errMsg->assign("key required to be 32 bytes");
            }
            return false;
        }

        if (iv.size() != 16) {
            if (errMsg) {
                errMsg->assign("iv required to be 16 bytes");
            }
            return false;
        }

        // init the decryption process
        if (!EVP_DecryptInit_ex(ctx,
                                EVP_aes_256_cbc(),
                                nullptr,
                                (const unsigned char*)key.data(),
                                (const unsigned char*)iv.data()))
        {
            if (errMsg) {
                errMsg->assign("unable to perform EVP_DecryptInit_ex()");
            }
            return false;
        }

        // openssl doc says we need to have at least one more block size available for decryption
        TString plaintext(text.size() + EVP_MAX_BLOCK_LENGTH, 0);
        int plainLen = 0;

        // actually decrypt data
        if (!EVP_DecryptUpdate(ctx,
                               (unsigned char*)plaintext.data(),
                               &plainLen,
                               (unsigned char*)text.data(),
                               text.size()))
        {
            if (errMsg) {
                errMsg->assign("unable to perform EVP_DecryptUpdate()");
            }
            return false;
        }

        int tailLen = 0;
        // finalize the decryption
        if (!EVP_DecryptFinal_ex(ctx, (unsigned char*)plaintext.data() + plainLen, &tailLen)) {
            if (errMsg) {
                errMsg->assign("unable to perform EVP_DecryptFinal_ex()");
            }
            return false;
        }

        plaintext.resize(plainLen + tailLen);
        output = std::move(plaintext);

        if (errMsg) {
            errMsg->assign("OK");
        }

        return true;
    }
}
