#include "codec.h"

#include <travel/hotels/lib/cpp/util/base64.h>
#include <travel/hotels/lib/cpp/util/compress.h>
#include <travel/hotels/lib/cpp/util/secret_reader.h>

#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#include <openssl/sha.h>

#include <library/cpp/openssl/holders/evp.h>

#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/stream/str.h>
#include <util/thread/singleton.h>

namespace NTravel {
namespace NEncryption {

// Length in bytes of the random nonce generated by MakeNonce() and used as the GCM IV.
constexpr size_t NonceLen{12};
// Length in bytes of the GCM authentication tag produced by Encrypt().
constexpr size_t TagLen{16};

// Generates a fresh random NonceLen-byte nonce (used as the AES-GCM IV).
// Throws yexception if the OpenSSL RNG reports failure: a GCM nonce must never be
// predictable or repeated, so continuing with an unfilled buffer would be unsafe.
TString MakeNonce() {
    unsigned char buf[NonceLen];
    // RAND_bytes returns 1 on success; 0 or -1 signal failure.
    if (RAND_bytes(buf, sizeof(buf)) != 1) {
        throw yexception() << "Failed to RAND_bytes";
    }
    return TString(reinterpret_cast<const char*>(buf), sizeof(buf));
}

// Distinct wrapper type over NOpenSSL::TEvpCipherCtx, used as the key of the
// thread-local singleton in GetEvpCipherCtx() (presumably so this file owns its
// own instance instead of sharing one with other users of the base type).
struct TEvpCipherCtx : public NOpenSSL::TEvpCipherCtx {};

// Returns the calling thread's cached OpenSSL cipher context, held by the
// thread-local FastTlsSingleton. Encrypt() and Decrypt() re-initialise it with
// the cipher at the start of every call, so no state carries over between uses.
static EVP_CIPHER_CTX* GetEvpCipherCtx() {
    TEvpCipherCtx& perThreadCtx = *FastTlsSingleton<TEvpCipherCtx>();
    return perThreadCtx;
}

/*
 * Encrypts `plaintext` with AES-128-GCM (no additional authenticated data).
 *
 * @param plaintext data to encrypt.
 * @param key       raw cipher key; must be exactly the AES-128 key length (16 bytes).
 * @param iv        GCM IV (nonce); its length is configured via EVP_CTRL_GCM_SET_IVLEN.
 * @param tag       [out] receives the TagLen-byte GCM authentication tag.
 * @return          the ciphertext bytes produced by EVP_EncryptUpdate + EVP_EncryptFinal_ex.
 * @throws yexception if the key size is wrong or any OpenSSL call fails.
 *
 * Uses the context from GetEvpCipherCtx(); it is re-initialised with the cipher
 * at the start of every call.
 */
TString Encrypt(const TString& plaintext, const TString& key, const TString& iv, TString& tag)
{
    EVP_CIPHER_CTX* ctx = GetEvpCipherCtx();
    const EVP_CIPHER* cipher = EVP_aes_128_gcm();

    // EVP_EncryptInit_ex reads exactly EVP_CIPHER_key_length(cipher) bytes from the
    // key pointer, so a shorter key would be read out of bounds.
    const size_t keyLen = static_cast<size_t>(EVP_CIPHER_key_length(cipher));
    if (key.size() != keyLen) {
        throw yexception() << "Invalid cipher key size: " << key.size() << " instead of " << keyLen;
    }

    // Select the cipher first; key and IV are supplied once the IV length is set.
    if (1 != EVP_EncryptInit_ex(ctx, cipher, nullptr, nullptr, nullptr)) {
        throw yexception() << "Failed to EVP_EncryptInit_ex";
    }

    if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(iv.size()), nullptr)) {
        throw yexception() << "Failed to EVP_CIPHER_CTX_ctrl";
    }

    EVP_CIPHER_CTX_set_padding(ctx, 0);

    if (1 != EVP_EncryptInit_ex(ctx, nullptr, nullptr,
                                reinterpret_cast<const unsigned char*>(key.data()),
                                reinterpret_cast<const unsigned char*>(iv.data()))) {
        throw yexception() << "Failed to EVP_EncryptInit_ex";
    }

    // No AAD is supplied. GCM is a stream mode, so the output does not exceed the
    // input; the extra headroom is only a safety margin.
    TVector<unsigned char> cipherText;
    cipherText.resize(plaintext.size() + 64);

    int len = 0;
    if (1 != EVP_EncryptUpdate(ctx, cipherText.data(), &len,
                               reinterpret_cast<const unsigned char*>(plaintext.data()),
                               static_cast<int>(plaintext.size()))) {
        throw yexception() << "Failed to EVP_EncryptUpdate";
    }
    int cipherTextLen = len;

    // GCM writes no further ciphertext here, but Final must still be called
    // before the tag can be retrieved.
    if (1 != EVP_EncryptFinal_ex(ctx, cipherText.data() + cipherTextLen, &len)) {
        throw yexception() << "Failed to EVP_EncryptFinal_ex";
    }
    cipherTextLen += len;

    unsigned char tagBuf[TagLen];
    if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, static_cast<int>(TagLen), tagBuf)) {
        throw yexception() << "Failed to EVP_CIPHER_CTX_ctrl";
    }
    tag = TString(reinterpret_cast<const char*>(tagBuf), TagLen);

    return TString(reinterpret_cast<const char*>(cipherText.data()), cipherTextLen);
}

/*
 * Decrypts AES-128-GCM `encrypted` (no additional authenticated data).
 *
 * @param encrypted ciphertext, e.g. as produced by Encrypt().
 * @param key       raw cipher key; must be exactly the AES-128 key length (16 bytes).
 * @param iv        the GCM IV (nonce) used at encryption time.
 * @param tag       expected GCM authentication tag. If it is EMPTY, the tag is not
 *                  checked and the returned plaintext is UNAUTHENTICATED.
 * @return          the decrypted plaintext.
 * @throws yexception if the key size is wrong, any OpenSSL call fails, or a
 *                  non-empty tag does not verify.
 */
TString Decrypt(const TString& encrypted, const TString& key, const TString& iv, const TString& tag)
{
    EVP_CIPHER_CTX* ctx = GetEvpCipherCtx();
    const EVP_CIPHER* cipher = EVP_aes_128_gcm();

    // EVP_DecryptInit_ex reads exactly EVP_CIPHER_key_length(cipher) bytes from the
    // key pointer, so a shorter key would be read out of bounds.
    const size_t keyLen = static_cast<size_t>(EVP_CIPHER_key_length(cipher));
    if (key.size() != keyLen) {
        throw yexception() << "Invalid cipher key size: " << key.size() << " instead of " << keyLen;
    }

    // Results are compared against 1 (as in Encrypt) so that any other value,
    // including a negative one, is treated as failure rather than slipping past `!ret`.
    if (1 != EVP_DecryptInit_ex(ctx, cipher, nullptr, nullptr, nullptr)) {
        throw yexception() << "Failed to EVP_DecryptInit_ex";
    }

    EVP_CIPHER_CTX_set_padding(ctx, 0);

    if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(iv.size()), nullptr)) {
        throw yexception() << "Failed to EVP_CIPHER_CTX_ctrl";
    }

    if (1 != EVP_DecryptInit_ex(ctx, nullptr, nullptr,
                                reinterpret_cast<const unsigned char*>(key.data()),
                                reinterpret_cast<const unsigned char*>(iv.data()))) {
        throw yexception() << "Failed to EVP_DecryptInit_ex";
    }

    TVector<unsigned char> plainText;
    plainText.resize(encrypted.size() + 1024);

    int len = 0;
    if (1 != EVP_DecryptUpdate(ctx, plainText.data(), &len,
                               reinterpret_cast<const unsigned char*>(encrypted.data()),
                               static_cast<int>(encrypted.size()))) {
        throw yexception() << "Failed to EVP_DecryptUpdate";
    }
    const int plainTextLen = len;

    const bool verifyTag = !tag.empty();
    if (verifyTag) {
        // The ctrl API takes a non-const pointer even though SET_TAG only reads it.
        if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, static_cast<int>(tag.size()),
                                     const_cast<char*>(tag.data()))) {
            throw yexception() << "Failed to EVP_CIPHER_CTX_ctrl";
        }
    }

    // With a tag set, a non-positive result means authentication failed and the
    // plaintext must not be trusted.
    const int ret = EVP_DecryptFinal_ex(ctx, plainText.data() + plainTextLen, &len);
    if (verifyTag && ret <= 0) {
        throw yexception() << "Encoded data verification failed: " << ret;
    }

    return TString(reinterpret_cast<const char*>(plainText.data()), plainTextLen + len);
}

// Returns the raw (binary, not hex-encoded) SHA-224 digest of `str`,
// SHA224_DIGEST_LENGTH bytes long, computed with OpenSSL's one-shot SHA224().
TString CalcSHA224(const TString& str) {
    unsigned char digest[SHA224_DIGEST_LENGTH];
    SHA224(reinterpret_cast<const unsigned char*>(str.data()), str.size(), digest);
    return TString(reinterpret_cast<const char*>(digest), SHA224_DIGEST_LENGTH);
}

// Returns the raw (binary, not hex-encoded) SHA-1 digest of `str`,
// SHA_DIGEST_LENGTH (20) bytes long.
// Uses the one-shot SHA1(): the SHA1_Init/Update/Final family is deprecated
// since OpenSSL 3.0, while the one-shot function is not.
TString CalcSHA1(const TString& str) {
    unsigned char digest[SHA_DIGEST_LENGTH];
    SHA1(reinterpret_cast<const unsigned char*>(str.data()), str.size(), digest);
    return TString(reinterpret_cast<const char*>(digest), sizeof(digest));
}

// Reads a cipher key from `keyFile` via ReadSecret().
// With `doSha1` set, the key is derived as the first 16 bytes of SHA1(secret);
// otherwise the secret is used verbatim. Throws yexception unless the resulting
// key is exactly 16 bytes (the key size Encrypt/Decrypt use for AES-128-GCM).
TString ReadKeyFromFile(const TString& keyFile, bool doSha1) {
    constexpr size_t expectedKeySize = 16;
    const TString secret = ReadSecret(keyFile);
    TString key = doSha1 ? CalcSHA1(secret).substr(0, expectedKeySize) : secret;
    if (key.size() != expectedKeySize) {
        throw yexception() << "Invalid cipher key size: " << key.size() << " instead of 16";
    }
    return key;
}

// Loads the 16-byte cipher key from `keyFile` verbatim (no SHA1 derivation).
TUrlCodec::TUrlCodec(const TString& keyFile)
    : Key_(ReadKeyFromFile(keyFile, /* doSha1 = */ false))
{
}

// Compresses `plainData` with ZLibCompress, encrypts it under a fresh random nonce
// and returns Base64EncodeUrlShort of the tuple: ciphertext || tag || nonce.
TString TUrlCodec::Encode(const TString& plainData) const {
    const TString compressed = ZLibCompress(plainData, ZLib::StreamType::ZLib, 0);
    const TString nonce = MakeNonce();
    TString tag;
    const TString ciphertext = Encrypt(compressed, Key_, nonce, tag);
    return NTravel::Base64EncodeUrlShort(ciphertext + tag + nonce);
}

// Inverse of Encode(): Base64-decodes `encodedData`, splits it into
// ciphertext || tag (TagLen) || nonce (NonceLen), decrypts and verifies it, then
// zlib-decompresses the result.
// Throws yexception if the decoded data is too short to contain a tag and a nonce,
// or if decryption/verification fails.
TString TUrlCodec::Decode(const TString& encodedData) const {
    const TString encodedTuple = NTravel::Base64DecodeAny(encodedData);
    // Without this check the size_t arithmetic below underflows on short input,
    // which can produce an empty tag; Decrypt() treats an empty tag as
    // "do not verify", so forged data would be accepted.
    if (encodedTuple.size() < TagLen + NonceLen) {
        throw yexception() << "Encoded data is too short: " << encodedTuple.size()
                           << " bytes, expected at least " << (TagLen + NonceLen);
    }
    const size_t encodedLen = encodedTuple.size() - TagLen - NonceLen;
    const TString encoded = encodedTuple.substr(0, encodedLen);
    const TString tag = encodedTuple.substr(encodedLen, TagLen);
    const TString nonce = encodedTuple.substr(encodedLen + TagLen, NonceLen);
    const TString decodedZipped = Decrypt(encoded, Key_, nonce, tag);
    return ZLibDecompress(decodedZipped, ZLib::StreamType::ZLib);
}

// Derives the cipher key from `keyFile`: the first 16 bytes of SHA1 of the secret.
TTokenCodec::TTokenCodec(const TString& keyFile)
    : Key_(ReadKeyFromFile(keyFile, /* doSha1 = */ true))
{
}

// Encrypts `plainData` (uncompressed) under a fresh random nonce and returns
// Base64EncodeUrlShort of the tuple: nonce || ciphertext || tag.
TString TTokenCodec::Encode(const TString& plainData) const {
    const TString nonce = MakeNonce();
    TString tag;
    const TString ciphertext = Encrypt(plainData, Key_, nonce, tag);
    return NTravel::Base64EncodeUrlShort(nonce + ciphertext + tag);
}

// Inverse of Encode(): Base64-decodes `encodedData`, splits it into
// nonce (NonceLen) || ciphertext || tag (TagLen), then decrypts and verifies it.
// Throws yexception if the decoded data is too short to contain a nonce and a tag,
// or if decryption/verification fails.
TString TTokenCodec::Decode(const TString& encodedData) const {
    const TString encodedTuple = NTravel::Base64DecodeAny(encodedData);
    // Without this check the size_t arithmetic below underflows on short input,
    // which can produce an empty tag; Decrypt() treats an empty tag as
    // "do not verify", so forged tokens would be accepted.
    if (encodedTuple.size() < NonceLen + TagLen) {
        throw yexception() << "Encoded data is too short: " << encodedTuple.size()
                           << " bytes, expected at least " << (NonceLen + TagLen);
    }
    const size_t encodedLen = encodedTuple.size() - NonceLen - TagLen;
    const TString nonce = encodedTuple.substr(0, NonceLen);
    const TString encoded = encodedTuple.substr(NonceLen, encodedLen);
    const TString tag = encodedTuple.substr(NonceLen + encodedLen, TagLen);
    return Decrypt(encoded, Key_, nonce, tag);
}

// No encryption: just Base64EncodeUrlShort of the input.
TString TBase64Codec::Encode(const TString& plainData) const {
    const TString encoded = NTravel::Base64EncodeUrlShort(plainData);
    return encoded;
}

// No decryption: just Base64DecodeAny of the input.
TString TBase64Codec::Decode(const TString& encodedData) const {
    const TString decoded = NTravel::Base64DecodeAny(encodedData);
    return decoded;
}

} //NEncryption
} //NTravel
