#include "symmetric.h"
#include "util/system/yassert.h"

#include <openssl/conf.h>
#include <openssl/err.h>
#include <openssl/evp.h>

#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>

namespace {
    /// Minimum length (in bytes) of the secret accepted by AESCrypto128.
    constexpr std::uint64_t AES_FULL_RANDOM_KEY_LENGTH = 32;
    /// Number of leading secret bytes copied into the CBC initialization vector.
    constexpr std::uint64_t AES_IV_LENGTH = 16;
    /// Number of secret bytes (immediately after the IV) copied into the AES-128 key.
    constexpr std::uint64_t AES_KEY_LENGTH = 16;
    // The secret is split as [IV | key]; make sure a minimum-length secret holds both.
    static_assert(AES_IV_LENGTH + AES_KEY_LENGTH <= AES_FULL_RANDOM_KEY_LENGTH,
                  "secret must be long enough to hold both the IV and the key");

    /// AES-128-CBC implementation of quasar::SymmetricCrypto built on OpenSSL's EVP API
    /// (EVP's default block padding is left enabled).
    ///
    /// The secret must be at least AES_FULL_RANDOM_KEY_LENGTH bytes long: its first
    /// AES_IV_LENGTH bytes become the IV and the following AES_KEY_LENGTH bytes the key.
    ///
    /// NOTE(review): the IV is fixed per secret, so equal plaintexts produce equal
    /// ciphertexts. Changing that would break compatibility with existing ciphertexts.
    ///
    /// Not thread-safe: every call reuses the single `ctx` member. Use LockedCrypto
    /// when an instance is shared between threads.
    class AESCrypto128: public quasar::SymmetricCrypto {
        // unique_ptr deleter releasing the OpenSSL cipher context (accepts nullptr).
        struct EVP_CIPHER_CTX_Deleter {
            void operator()(EVP_CIPHER_CTX* ptr) const noexcept {
                EVP_CIPHER_CTX_free(ptr);
            }
        };
        using EVP_CIPHER_CTX_Ptr = std::unique_ptr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_Deleter>;

        static constexpr unsigned KeySize = 128 / 8;
        static constexpr unsigned BlockSize = 128 / 8;
        // setup() memcpy's AES_KEY_LENGTH / AES_IV_LENGTH bytes into these buffers.
        static_assert(KeySize == AES_KEY_LENGTH, "key buffer size must equal AES_KEY_LENGTH");
        static_assert(BlockSize == AES_IV_LENGTH, "IV buffer size must equal AES_IV_LENGTH");
        static_assert(AES_IV_LENGTH + AES_KEY_LENGTH <= AES_FULL_RANDOM_KEY_LENGTH,
                      "secret must be long enough to hold both the IV and the key");

        std::uint8_t key_[KeySize]{};
        std::uint8_t iv_[BlockSize]{};
        EVP_CIPHER_CTX_Ptr ctx;

        // Allocates a cipher context; EVP_CIPHER_CTX_new returns nullptr on failure,
        // which would otherwise surface later as a null-context crash inside EVP.
        static EVP_CIPHER_CTX_Ptr makePtr() {
            EVP_CIPHER_CTX_Ptr ptr(EVP_CIPHER_CTX_new());
            if (!ptr) {
                throw std::runtime_error("Failed to EVP_CIPHER_CTX_new");
            }
            return ptr;
        }

        // EVP_*Update takes the input length as int, and the output may be up to one
        // block longer than the input: reject sizes that cannot be represented.
        static int toEvpLength(std::size_t size) {
            if (size > static_cast<std::size_t>(std::numeric_limits<int>::max()) - BlockSize) {
                throw std::runtime_error("Input is too large for EVP");
            }
            return static_cast<int>(size);
        }

        // Splits the secret into [IV | key]. Y_VERIFY rejects secrets shorter
        // than AES_FULL_RANDOM_KEY_LENGTH.
        void setup(const std::string& secret) {
            Y_VERIFY(secret.length() >= AES_FULL_RANDOM_KEY_LENGTH);
            std::memcpy(iv_, secret.c_str(), AES_IV_LENGTH);
            std::memcpy(key_, secret.c_str() + AES_IV_LENGTH, AES_KEY_LENGTH);
        }

    public:
        /// @param secret at least AES_FULL_RANDOM_KEY_LENGTH bytes; see class comment for layout.
        /// @throws std::runtime_error if the OpenSSL cipher context cannot be allocated.
        explicit AESCrypto128(const std::string& secret)
            : ctx{makePtr()}
        {
            setup(secret);
        }

        /// Encrypts `src` with AES-128-CBC using the current key and IV.
        /// @throws std::runtime_error on any OpenSSL failure or an input too large for EVP.
        std::string encrypt(const std::string& src) override {
            const int srcLen = toEvpLength(src.size());
            if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_cbc(), nullptr, key_, iv_) != 1) {
                throw std::runtime_error("Failed to EVP_EncryptInit");
            }
            // Update + Final together emit at most srcLen + one block of ciphertext.
            std::string rval(src.size() + BlockSize, ' ');
            auto* out = reinterpret_cast<std::uint8_t*>(&rval[0]);
            int updateLen = 0;
            if (EVP_EncryptUpdate(ctx.get(), out, &updateLen,
                                  reinterpret_cast<const std::uint8_t*>(src.data()), srcLen) != 1) {
                throw std::runtime_error("Failed to EVP_EncryptUpdate");
            }
            int finalLen = 0;
            if (EVP_EncryptFinal_ex(ctx.get(), out + updateLen, &finalLen) != 1) {
                throw std::runtime_error("Failed to EVP_EncryptFinal");
            }
            rval.resize(static_cast<std::size_t>(updateLen + finalLen));
            return rval;
        }

        /// Decrypts AES-128-CBC ciphertext produced by encrypt() with the same secret.
        /// An empty input yields an empty string.
        /// @throws std::runtime_error if `src` is not a whole number of blocks, or if
        ///         OpenSSL rejects it (e.g. wrong key or malformed padding).
        std::string decrypt(const std::string& src) override {
            if (src.size() % BlockSize) {
                throw std::runtime_error("Not padded encrypted data");
            }
            const int srcLen = toEvpLength(src.size());
            if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_cbc(), nullptr, key_, iv_) != 1) {
                throw std::runtime_error("Failed to EVP_DecryptInit");
            }
            // OpenSSL documents that the decrypt output buffer must have room for
            // inl + cipher_block_size bytes.
            std::string rval(src.size() + BlockSize, ' ');
            auto* out = reinterpret_cast<std::uint8_t*>(&rval[0]);
            int updateLen = 0;
            if (EVP_DecryptUpdate(ctx.get(), out, &updateLen,
                                  reinterpret_cast<const std::uint8_t*>(src.data()), srcLen) != 1) {
                throw std::runtime_error("Failed to EVP_DecryptUpdate");
            }
            if (updateLen == srcLen) {
                // With padding enabled OpenSSL keeps the last block until Final, so in
                // practice this only triggers for empty input; Final is skipped there.
                rval.resize(static_cast<std::size_t>(updateLen));
                return rval;
            }
            int finalLen = 0;
            if (EVP_DecryptFinal_ex(ctx.get(), out + updateLen, &finalLen) != 1) {
                throw std::runtime_error("Failed to EVP_DecryptFinal");
            }
            rval.resize(static_cast<std::size_t>(updateLen + finalLen));
            return rval;
        }

        /// Replaces the IV and key with those derived from `secret` (same rules as the constructor).
        void changeSecret(const std::string& secret) override {
            setup(secret);
        }
    };

    // extract it to header when needed
    /// Decorator that serializes every call into Crypto_ behind a single mutex.
    ///
    /// Crypto_ must declare virtual encrypt/decrypt/changeSecret with the signatures
    /// overridden below. Each override holds `mutex_` for the whole base-class call,
    /// so calls on one LockedCrypto instance never run concurrently.
    template <typename Crypto_>
    class LockedCrypto: public Crypto_ {
        std::mutex mutex_;
        using Guard = std::lock_guard<std::mutex>;

    public:
        /// Perfect-forwards all arguments to Crypto_'s constructor. This avoids an extra
        /// copy of e.g. the secret string and allows move-only arguments.
        template <typename... Params_>
        explicit LockedCrypto(Params_&&... params)
            : Crypto_(std::forward<Params_>(params)...)
        {
        }

        std::string encrypt(const std::string& src) override {
            const Guard lock(mutex_);
            return Crypto_::encrypt(src);
        }

        std::string decrypt(const std::string& src) override {
            const Guard lock(mutex_);
            return Crypto_::decrypt(src);
        }

        void changeSecret(const std::string& secret) override {
            const Guard lock(mutex_);
            Crypto_::changeSecret(secret);
        }
    };

    class NullCrypto: public quasar::SymmetricCrypto {
    public:
        std::string encrypt(const std::string& src) override {
            return src;
        }

        std::string decrypt(const std::string& src) override {
            return src;
        }

        void changeSecret(const std::string& /*secret*/) override {
        }
    };
} // namespace

namespace quasar {
    SymmetricCrypto::UniquePtr SymmetricCrypto::makeNull() {
        return std::make_unique<NullCrypto>();
    }

    SymmetricCrypto::UniquePtr SymmetricCrypto::makeAES128(const std::string& cryptoKey) {
        return std::make_unique<AESCrypto128>(cryptoKey);
    }

    SymmetricCrypto::UniquePtr SymmetricCrypto::makeAES128Locked(const std::string& cryptoKey) {
        return std::make_unique<LockedCrypto<AESCrypto128>>(cryptoKey);
    }

} // namespace quasar
