#include "totp_encryptor.h"

#include "totp_profile.h"

#include <passport/infra/daemons/blackbox/src/misc/strings.h>
#include <passport/infra/daemons/blackbox/src/misc/utils.h>
#include <passport/infra/daemons/blackbox/src/protobuf/totp_profile.pb.h>

#include <passport/infra/libs/cpp/utils/crypto/hash.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <contrib/libs/openssl/include/openssl/err.h>
#include <contrib/libs/openssl/include/openssl/evp.h>

#include <library/cpp/openssl/holders/evp.h>

#include <cstring>
#include <iomanip>
#include <sstream>

namespace NPassport::NBb {
    namespace {
        // Returns the reason string for the earliest error in OpenSSL's error queue,
        // removing that error from the queue (ERR_get_error pops it).
        // ERR_reason_error_string() returns nullptr when the queue is empty or the
        // code is unknown. Callers pass the result straight into a printf-style "%s",
        // where nullptr is undefined behaviour, so a placeholder is returned instead.
        const char* SSLErrorStr() {
            const char* reason = ERR_reason_error_string(ERR_get_error());
            return reason != nullptr ? reason : "unknown OpenSSL error";
        }
    }

    // Takes ownership of the AES (encryption) and HMAC (authentication) key maps
    // and logs how many keys of each kind were received.
    TTotpEncryptor::TTotpEncryptor(NAuth::TKeyMap&& aesKeys, NAuth::TKeyMap&& macKeys)
        : AesKeys_(std::move(aesKeys))
        , MacKeys_(std::move(macKeys))
    {
        const auto aesKeyCount = AesKeys_.GetSize();
        const auto macKeyCount = MacKeys_.GetSize();
        TLog::Info("TOTP: successfully loaded %lu AES key(s) and %lu HMAC key(s)",
                   aesKeyCount,
                   macKeyCount);
    }

    // Sentinel profile (uid 0, empty pin) returned by the decryption/parsing paths on any failure.
    static const TTotpProfile UNINITED_PROFILE(0, "");

    // Decrypts an encrypted TOTP value belonging to `uid` and parses it into a profile.
    // Versions 1 and 2 use the legacy fixed layout; version 3 carries a protobuf.
    // Returns UNINITED_PROFILE if decryption or parsing fails.
    TTotpProfile TTotpEncryptor::Decrypt(ui64 uid, const TString& value) const {
        TString decrypted;
        int version = 0;
        if (DecryptToBlock(decrypted, version, value, uid)) {
            return version >= 3
                       ? ParseBlockV3(uid, decrypted)
                       : ParseBlockV12(uid, decrypted);
        }
        return UNINITED_PROFILE;
    }

    // Encrypts a single legacy secret for `uid`; always produces a version 2 value.
    // Yields an empty string if the secret cannot be packed or encryption fails.
    TString TTotpEncryptor::Encrypt(ui64 uid, const TString& secret) const {
        const TString plaintext = PrepareBlockV12(uid, secret);
        return EncryptBlock(2, plaintext);
    }

    // Encrypts a whole profile as a version 3 value.
    // Returns an empty string for an uninitialized profile or one with no secret data.
    TString TTotpEncryptor::Encrypt(const TTotpProfile& profile) const {
        const bool hasPayload = profile.IsInited() && !profile.Data_.empty();
        if (hasPayload) {
            return EncryptBlock(3, PrepareBlockV3(profile));
        }
        return TStrings::EMPTY;
    }

    // Verifies and decrypts an encrypted TOTP value.
    // Expected format (6 colon-separated fields):
    //   <version digit>:<aes key id>:<hmac key id>:<base64 iv>:<base64 ciphertext>:<base64 hmac>
    // On success stores the decrypted plaintext in `block` and the version (1..3) in `ver`
    // and returns true. On any failure logs the reason, leaves `block`/`ver` untouched,
    // and returns false. `uid` is used only for log messages here.
    bool TTotpEncryptor::DecryptToBlock(TString& block, int& ver, const TString& value, ui64 uid) const {
        // check value format sanity
        std::vector<TString> fields = NUtils::ToVector(value, ':');

        if (fields.size() != 6 || fields[0].size() != 1) {
            TLog::Error("TOTP Error: bad encrypted secret value format: %s. uid: %lu", value.c_str(), uid);
            return false;
        }

        const TString& versionStr = fields[0];
        const int version = versionStr[0] - '0';
        if (version < 1 || version > 3) {
            TLog::Error("TOTP Error: bad encrypted secret version: %d. uid: %lu", version, uid);
            return false;
        }

        // find keys
        if (!AesKeys_.HasKey(fields[1])) {
            TLog::Error("TOTP Error: bad encryption key id: %s. uid: %lu", fields[1].c_str(), uid);
            return false;
        }
        const TStringBuf aesKey = AesKeys_.GetKey(fields[1]);

        const EVP_CIPHER* const cipher = EVP_aes_256_cbc();
        // OpenSSL reads EVP_CIPHER_key_length() bytes of key material unconditionally;
        // a shorter key would make it read past the end of the buffer.
        if (aesKey.size() < static_cast<size_t>(EVP_CIPHER_key_length(cipher))) {
            TLog::Error("TOTP Error: encryption key is too short: %s. uid: %lu", fields[1].c_str(), uid);
            return false;
        }

        if (!MacKeys_.HasKey(fields[2])) {
            TLog::Error("TOTP Error: bad hmac key id: %s. uid: %lu", fields[2].c_str(), uid);
            return false;
        }
        const TStringBuf macKey = MacKeys_.GetKey(fields[2]);

        TString iv = NUtils::Base64ToBin(fields[3]);
        TString aesblock = NUtils::Base64ToBin(fields[4]);
        TString macblock = NUtils::Base64ToBin(fields[5]);

        if (iv.size() != 16 || aesblock.size() < 32 || macblock.size() != 32) {
            TLog::Error("TOTP Error: bad encrypted secret value fields length: %s. uid: %lu", value.c_str(), uid);
            return false;
        }

        // check HMAC: version 1 authenticates only the ciphertext; later versions
        // authenticate the whole value prefix before the last ':' (version, key ids, iv, ciphertext)
        TString hmac;

        if (version == 1) {
            hmac = NUtils::TCrypto::HmacSha256(macKey, aesblock);
        } else {
            size_t lastcolon = value.rfind(':');
            hmac = NUtils::TCrypto::HmacSha256(macKey, TStringBuf(value.data(), lastcolon));
        }

        if (!NUtils::SecureCompare(hmac, macblock)) {
            TLog::Error("TOTP Error: hash sum mismatch for value: %s. uid: %lu", value.c_str(), uid);
            return false;
        }

        // decrypt value; OpenSSL documents that the output buffer must have room
        // for the input length plus one cipher block
        TString rawData(aesblock.size() + EVP_CIPHER_block_size(cipher), 0);
        int secretLen = 0;
        int tailLen = 0;

        NOpenSSL::TEvpCipherCtx ctx;

        if (!EVP_DecryptInit_ex(ctx, cipher, nullptr, (const unsigned char*)aesKey.data(), (const unsigned char*)iv.data())) {
            TLog::Error("TOTP Error: DecryptInit failed: %s. uid: %lu", SSLErrorStr(), uid);
            return false;
        }

        if (!EVP_DecryptUpdate(ctx, (unsigned char*)rawData.data(), &secretLen, (unsigned char*)aesblock.data(), aesblock.size())) {
            TLog::Error("TOTP Error: DecryptUpdate failed: %s. uid: %lu", SSLErrorStr(), uid);
            return false;
        }

        if (!EVP_DecryptFinal_ex(ctx, (unsigned char*)rawData.data() + secretLen, &tailLen)) {
            TLog::Error("TOTP Error: DecryptFinal failed: %s. uid: %lu", SSLErrorStr(), uid);
            return false;
        }

        rawData.resize(secretLen + tailLen);
        block = std::move(rawData);
        ver = version;
        return true;
    }

    // Parses a decrypted version 1/2 block into a single-secret profile.
    // Layout (as written by PrepareBlockV12): 16 secret bytes followed by the uid
    // stored as a native-endian 64-bit integer. `block` is truncated in place to
    // the 16 secret bytes. Returns UNINITED_PROFILE on any error.
    TTotpProfile TTotpEncryptor::ParseBlockV12(ui64 uid, TString& block) {
        constexpr size_t SECRET_SIZE = 16;

        // A 32-byte ciphertext whose padding fills a whole block decrypts to only
        // 16 bytes, so the length must be checked before reading the uid.
        if (block.size() < SECRET_SIZE + sizeof(i64)) {
            TLog::Error("TOTP Error: decrypted block v12 is too short: %lu bytes. uid: %lu", block.size(), uid);
            return UNINITED_PROFILE;
        }

        // check decrypted uid; memcpy avoids an unaligned, type-punned read
        i64 rawUid = 0;
        std::memcpy(&rawUid, block.data() + SECRET_SIZE, sizeof(rawUid));
        const ui64 secretUid = static_cast<ui64>(rawUid);

        if (uid != secretUid) {
            TLog::Error("TOTP Error: value uid mismatch: encrypted for %lu, checked with %lu", secretUid, uid);
            return UNINITED_PROFILE;
        }

        // actual secret is first 128 bit block of decrypted value
        block.resize(SECRET_SIZE);

        TSecretAndPin secretAndPin = SplitToSecretAndPin(block);

        TTotpProfile res(secretUid, secretAndPin.second);
        TTotpProfile::TSecretData* secr = res.AddSecret();
        if (!secr) {
            TLog::Error("TOTP Error: can't parse decrypted block v12: %s. uid: %lu",
                        res.LastError().c_str(),
                        secretUid);
            return UNINITED_PROFILE;
        }
        secr->TotpKey_ = std::move(block);
        secr->Secret_ = std::move(secretAndPin.first);
        secr->Pin_ = std::move(secretAndPin.second);

        return res;
    }

    // Parses a decrypted version 3 block (a serialized totp_proto::TotpProfile)
    // into a profile. The message must carry a uid, a pin, and at least one secret,
    // and its uid must equal `uid`. Returns UNINITED_PROFILE on any error.
    TTotpProfile TTotpEncryptor::ParseBlockV3(ui64 uid, const TString& block) {
        totp_proto::TotpProfile proto;
        // Sanitizing
        if (!proto.ParseFromArray(block.data(), block.size())) {
            TLog::Error("TOTP Error: can't parse protobuf. uid: %lu", uid);
            return UNINITED_PROFILE;
        }
        if (!proto.has_uid() || !proto.has_pin() || proto.secrets_size() == 0) {
            // uid is unsigned: use %lu like every other message in this file
            TLog::Error("TOTP Error: protobuf is broken. Has uid: %s. Has pin: %s. Has secrets: %s. uid from request: %lu",
                        proto.has_uid() ? "true" : "false",
                        proto.has_pin() ? "true" : "false",
                        proto.secrets_size() > 0 ? "true" : "false",
                        uid);
            return UNINITED_PROFILE;
        }

        // Product logic
        if (proto.uid() != uid) {
            TLog::Error("TOTP Error: value uid mismatch: encrypted for %lu, checked with %lu", proto.uid(), uid);
            return UNINITED_PROFILE;
        }

        TTotpProfile profile(proto.uid(), proto.pin());
        for (const auto& secret : proto.secrets()) {
            TTotpProfile::TSecretData* sd = profile.AddSecret();
            if (!sd) {
                TLog::Error("TOTP Error: can't parse decrypted block v3: %s. uid: %lu",
                            profile.LastError().c_str(),
                            proto.uid());
                return UNINITED_PROFILE;
            }
            sd->Secret_ = secret.secret();
            sd->SecretId_ = secret.id();
            sd->Created_ = secret.created();
        }

        return profile;
    }

    // Encrypts `plaintext` with AES-256-CBC under the default AES key and a fresh
    // random IV, then appends an HMAC-SHA256 under the default HMAC key. Produces
    //   <version>:<aes key id>:<hmac key id>:<base64 iv>:<base64 ciphertext>:<base64 hmac>
    // For version 1 the HMAC covers only the ciphertext; otherwise it covers the
    // whole prefix before the final ':'. Returns an empty string on empty input or
    // any failure (logged).
    TString TTotpEncryptor::EncryptBlock(ui32 version, const TString& plaintext) const {
        static const TString DELIMITER = ":";
        if (plaintext.empty()) {
            return TStrings::EMPTY;
        }

        const EVP_CIPHER* const cipher = EVP_aes_256_cbc();
        const auto& aesKey = AesKeys_.GetDefaultKey();
        // OpenSSL reads EVP_CIPHER_key_length() bytes of key material unconditionally;
        // a shorter key would make it read past the end of the buffer.
        if (aesKey.size() < static_cast<size_t>(EVP_CIPHER_key_length(cipher))) {
            TLog::Error("TOTP Error: default AES key is too short: %lu bytes", static_cast<size_t>(aesKey.size()));
            return TStrings::EMPTY;
        }

        // generate random 128-bit IV
        TString iv = NUtils::TCrypto::RandBytes(16);

        if (iv.empty()) {
            TLog::Error("TOTP Error: failed to generate random IV: %s", SSLErrorStr());
            return TStrings::EMPTY;
        }

        // encrypt value; padded ciphertext never exceeds plaintext size plus one block
        TString block(plaintext.size() + EVP_CIPHER_block_size(cipher), 0);
        int blockLen = 0;
        int tailLen = 0;

        NOpenSSL::TEvpCipherCtx ctx;

        if (!EVP_EncryptInit_ex(ctx, cipher, nullptr, (const unsigned char*)aesKey.data(), (const unsigned char*)iv.data())) {
            TLog::Error("TOTP Error: EncryptInit failed: %s", SSLErrorStr());
            return TStrings::EMPTY;
        }

        if (!EVP_EncryptUpdate(ctx, (unsigned char*)block.data(), &blockLen, (unsigned char*)plaintext.data(), plaintext.size())) {
            TLog::Error("TOTP Error: EncryptUpdate failed: %s", SSLErrorStr());
            return TStrings::EMPTY;
        }

        if (!EVP_EncryptFinal_ex(ctx, (unsigned char*)block.data() + blockLen, &tailLen)) {
            TLog::Error("TOTP Error: EncryptFinal failed: %s", SSLErrorStr());
            return TStrings::EMPTY;
        }

        blockLen += tailLen;

        block.resize(blockLen);

        // make up the resulting value
        TString value = NUtils::CreateStrExt(
            256,
            version,
            DELIMITER,
            AesKeys_.GetDefaultId(),
            DELIMITER,
            MacKeys_.GetDefaultId(),
            DELIMITER,
            NUtils::BinToBase64(iv),
            DELIMITER,
            NUtils::BinToBase64(block));

        const TString hmac = version == 1
                                 ? NUtils::TCrypto::HmacSha256(MacKeys_.GetDefaultKey(), block)
                                 : NUtils::TCrypto::HmacSha256(MacKeys_.GetDefaultKey(), value);

        NUtils::Append(value,
                       DELIMITER,
                       NUtils::BinToBase64(hmac));

        return value;
    }

    // Builds the legacy (version 1/2) plaintext: the 16-byte secret followed by
    // `uid` as a native-endian signed 64-bit integer (24 bytes total).
    // Returns an empty string (and logs) if the secret is not exactly 16 bytes.
    TString TTotpEncryptor::PrepareBlockV12(ui64 uid, const TString& secret) {
        if (secret.size() != 16) {
            TLog::Error("TOTP Error: secret of unsupported length: %lu", secret.size());
            return TStrings::EMPTY;
        }

        // Append the uid's bytes rather than storing through a cast i64 pointer,
        // which was a potentially misaligned, strict-aliasing-violating write.
        const i64 rawUid = static_cast<i64>(uid);
        TString plaintext(secret);
        plaintext.append(reinterpret_cast<const char*>(&rawUid), sizeof(rawUid));
        return plaintext;
    }

    // Serializes `profile` into a totp_proto::TotpProfile message (the format
    // ParseBlockV3 reads back). The message's version field is always set to 1.
    // Each secret contributes its id, secret bytes, and creation time.
    TString TTotpEncryptor::PrepareBlockV3(const TTotpProfile& profile) {
        totp_proto::TotpProfile message;
        message.set_uid(profile.Uid_);
        message.set_pin(TString(profile.Pin()));
        message.set_version(1);

        for (const TTotpProfile::TSecretData& data : profile.Data_) {
            totp_proto::Secret* const entry = message.add_secrets();
            if (!entry) {
                TLog::Error("TOTP Error: Can't add new secret to protobuf. uid: %lu", profile.Uid());
                return TStrings::EMPTY;
            }
            entry->set_id(data.SecretId());
            entry->set_secret(TString(data.Secret()));
            entry->set_created(data.Created());
        }

        return message.SerializeAsString();
    }

    // Interprets `value` as a big-endian unsigned 128-bit integer K and splits it:
    //   second: the four lowest decimal digits of K, zero-padded ("0042");
    //   first:  K / 10000 as big-endian bytes with leading zero bytes dropped.
    // If `value` is longer than 16 bytes only its last 16 bytes affect K.
    TTotpEncryptor::TSecretAndPin TTotpEncryptor::SplitToSecretAndPin(const TString& value) {
        unsigned __int128 key = 0;
        for (const char c : value) {
            key <<= 8;
            key |= static_cast<unsigned char>(c);
        }

        TSecretAndPin pair;

        // Pin: formatted by hand rather than via std::ostringstream, whose output
        // depends on the process-global locale (a grouping locale would turn 1234
        // into "1,234").
        unsigned pinValue = static_cast<unsigned>(key % 10000);
        char pinDigits[4];
        for (int i = 3; i >= 0; --i) {
            pinDigits[i] = static_cast<char>('0' + pinValue % 10);
            pinValue /= 10;
        }
        pair.second.assign(pinDigits, sizeof(pinDigits));

        // Secret
        key /= 10000;

        // Bug in passport (python) on generating of secret may return less than 15 bytes,
        // so the byte length is derived from the value instead of being fixed.
        int secretSize = 0;
        for (unsigned __int128 rest = key; secretSize < int(sizeof(key)) && rest; ++secretSize) {
            rest >>= 8;
        }

        TString secretStr;
        secretStr.resize(secretSize);
        for (int i = secretSize - 1; i >= 0; --i) {
            secretStr[i] = static_cast<char>(key & 0xFF);
            key >>= 8;
        }
        pair.first = std::move(secretStr);

        return pair;
    }

}
