#include "password_checker.h"

#include "exception.h"
#include "utils.h"

#include <passport/infra/daemons/blackbox/src/crypt/crypt.h>

#include <passport/infra/libs/cpp/argon/argon.h>
#include <passport/infra/libs/cpp/argon/mem_pool.h>
#include <passport/infra/libs/cpp/utils/crypto/hash.h>
#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <contrib/libs/openssl/include/openssl/err.h>

#include <util/string/cast.h>

#include <memory>

namespace NPassport::NBb {
    // Number of salt characters in the md5crypt hashes generated and embedded by this file
    // ("$1$<MD5_SALT_LEN chars>$...").
    constexpr ui8 MD5_SALT_LEN = 8;
    // Separator between the fields of version 6/7 (Argon2-based) hash payloads.
    constexpr char ARGON_ATTR_DELIMITER = ':';

    TPasswordChecker::TPasswordChecker() = default;

    TPasswordChecker::TPasswordChecker(NAuth::TKeyMap&& secrets,
                                       ui32 hashlen,
                                       ui32 tCost,
                                       ui32 mCost)
        : TCost_(tCost)
        , MCost_(mCost)
        , Secrets_(std::move(secrets))
        , Argon2_(std::make_unique<NArgon::TArgon2Factory>(hashlen))
    {
    }

    TPasswordChecker::~TPasswordChecker() = default;

    bool TPasswordChecker::PasswordMatches(const TString& pwd,
                                           const TString& hash,
                                           const TString& uid,
                                           bool supportOldHashes) const {
        if (hash.size() < 30 || hash[1] != ':') {
            return false; // malformed password hash
        }

        char ver = hash[0];
        const char* hashCStr = hash.c_str() + 2;

        switch (ver) {
            case '1': // old md5crypt hashes
                return supportOldHashes ? CheckMD5Hash(pwd, hashCStr) : false;
            case '5': // md5crypt hashes bound to specific uid, space not allowed in uid or pwd so it is safe delimiter
                return supportOldHashes ? CheckMD5Hash(uid + ' ' + pwd, hashCStr) : false;
            case '6': // argon(md5crypt(pwd))
                return CheckArgonWithMD5Hash(pwd, TStringBuf(hash).Skip(2), uid);
            case '7': // argon(md5(pwd))
                return CheckArgonWithRawMD5Hash(pwd, TStringBuf(hash).Skip(2), uid);
        }

        return false; // unknown hashing version
    }

    // Builds a new stored hash for `pwd` in the requested format:
    //   6 -> "6:" + Argon2 over md5crypt(pwd) (md5crypt salt is freshly generated)
    //   7 -> "7:" + Argon2 over hex md5(pwd)
    // Both formats pass `uid` to Argon2, so a non-empty uid is required.
    // Throws TBlackboxError(InvalidParams) for an empty uid or an unsupported version.
    TString TPasswordChecker::MakeHash(const TString& pwd, const TString& uid, int ver) const {
        if (ver == 6) {
            if (uid.empty()) {
                throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "no UID for version=6 hash";
            }
            const TString md5cryptHash = Md5crypt(pwd);
            return "6:" + ArgonWithMd5Crypt(md5cryptHash, uid);
        }

        if (ver == 7) {
            if (uid.empty()) {
                throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "no UID for version=7 hash";
            }
            const TString md5Hex = Md5Raw(pwd);
            return "7:" + ArgonRaw(md5Hex, uid);
        }

        throw TBlackboxError(TBlackboxError::EType::InvalidParams)
            << "not supported hash version: " << InvalidValue(ver);
    }

    // Wraps an existing md5crypt hash (presumably the payload of a legacy version-1 hash)
    // into the version-6 format without needing the password itself.
    // Throws TBlackboxError(InvalidParams) when uid is empty, since it is passed to Argon2.
    TString TPasswordChecker::ConvertMD5ToArgon(const TString& md5hash, const TString& uid) const {
        if (!uid.empty()) {
            const TString argonPart = ArgonWithMd5Crypt(md5hash, uid);
            return "6:" + argonPart;
        }
        throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "no UID for version=6 hash";
    }

    // Wraps an existing hex md5 digest of a password into the version-7 format without
    // needing the password itself.
    // Throws TBlackboxError(InvalidParams) when uid is empty, since it is passed to Argon2.
    TString TPasswordChecker::ConvertRawMD5ToArgon(const TString& md5hash, const TString& uid) const {
        if (!uid.empty()) {
            const TString argonPart = ArgonRaw(md5hash, uid);
            return "7:" + argonPart;
        }
        throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "no UID for version=7 hash";
    }

    // Returns `size` random bytes from NUtils::TCrypto::RandBytes.
    // Throws TBlackboxError(Unknown) when RandBytes returns an empty string; the message
    // includes the reason of the first queued OpenSSL error, if any.
    static TString GenRandomData(size_t size) {
        TString res = NUtils::TCrypto::RandBytes(size);

        if (res.empty()) {
            // ERR_reason_error_string() returns NULL when no reason string is registered for
            // the code, including 0 (empty OpenSSL error queue); never stream a null pointer.
            const char* reason = ERR_reason_error_string(ERR_get_error());
            throw TBlackboxError(TBlackboxError::EType::Unknown)
                << "Rand failed: " << (reason ? reason : "unknown error");
        }
        return res;
    }

    // Symbols allowed in crypt(3) salts: [a-zA-Z0-9./].
    // There are exactly 64 of them, so `byte & 0x3F` picks one uniformly.
    // A constexpr array avoids a dynamically initialized, heap-allocating global.
    static constexpr char SALT_BASE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789./";
    static_assert(sizeof(SALT_BASE) - 1 == 64, "SALT_BASE must contain exactly 64 symbols");

    // MD5Crypt: hashes `pwd` with crypt_r() using `salt` and returns the result in unix
    // crypt format. If `salt` is null, a fresh "$1$<MD5_SALT_LEN chars>$" salt is generated
    // from random bytes (6 bits per character, 48 bits in total).
    // Returns an empty string if crypt_r() returns a null pointer, for both the given-salt
    // and the generated-salt paths. Throws (via GenRandomData) if random generation fails.
    TString TPasswordChecker::Md5crypt(const TString& pwd, const char* salt) {
        std::unique_ptr<crypt_data> crdata = std::make_unique<crypt_data>();
        crdata->initialized = 0;

        // Owns the generated salt; must stay alive until crypt_r() below has returned.
        TString generatedSalt;
        if (!salt) {
            const TString iv = GenRandomData(MD5_SALT_LEN);

            generatedSalt = "$1$";
            generatedSalt.reserve(generatedSalt.size() + iv.size() + 1);
            for (char c : iv) {
                // SALT_BASE has 64 symbols, so the low 6 bits of each byte select one.
                generatedSalt.push_back(SALT_BASE[static_cast<unsigned char>(c) & 0x3F]);
            }
            generatedSalt.push_back('$');
            salt = generatedSalt.c_str();
        }

        const char* hash = crypt_r(pwd.c_str(), salt, crdata.get());
        if (!hash) {
            // Never construct a TString from a null pointer.
            return {};
        }
        return hash;
    }

    // Unsalted MD5 digest of `pwd`, hex-encoded with NUtils::Bin2hex.
    TString TPasswordChecker::Md5Raw(const TString& pwd) {
        return NUtils::Bin2hex(NUtils::TCrypto::Md5(pwd));
    }

    // Main Argon2 password hash (payload of a version-6 hash).
    // Computes Argon2 over `md5hash` with the default secret, the configured tCost/mCost,
    // a fresh 16-byte random salt and `uid`, and serializes it as
    //   <secret_id>:<tcost>:<mcost>:<md5crypt_salt>:<argon_salt>:<hash>
    // with argon_salt and hash base64-encoded. The md5crypt salt is stored so the
    // md5crypt step can be repeated on verification.
    // Throws TBlackboxError if Argon2 is not configured, if md5hash is not an md5crypt hash
    // ("$1$" followed by at least MD5_SALT_LEN characters), or if random generation fails.
    TString TPasswordChecker::ArgonWithMd5Crypt(const TString& md5hash, const TString& uid) const {
        if (!Argon2_) {
            throw TBlackboxError(TBlackboxError::EType::Unknown) << "Argon is not configured";
        }

        // The md5crypt salt is copied from md5hash[3, 3 + MD5_SALT_LEN), and verification
        // always rebuilds it as "$1$<salt>$". A shorter or non-"$1$" input (e.g. an empty
        // string from a failed crypt_r) would produce a stored hash that can never match.
        constexpr size_t md5PrefixLen = 3; // "$1$"
        constexpr size_t minMd5HashLen = md5PrefixLen + MD5_SALT_LEN;
        if (md5hash.size() < minMd5HashLen || !md5hash.StartsWith("$1$")) {
            throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "malformed md5crypt hash";
        }

        TString argonSalt = GenRandomData(16);

        auto ctx = Argon2_->Get(Secrets_.GetDefaultKey(), TCost_, MCost_);
        TString argonHash = ctx.MakeHash(md5hash, argonSalt, uid);

        TString res;
        res.reserve(96);
        res.append(Secrets_.GetDefaultId()).push_back(ARGON_ATTR_DELIMITER);
        res.append(IntToString<10>(TCost_)).push_back(ARGON_ATTR_DELIMITER);
        res.append(IntToString<10>(MCost_)).push_back(ARGON_ATTR_DELIMITER);
        res.append(md5hash, md5PrefixLen, MD5_SALT_LEN).push_back(ARGON_ATTR_DELIMITER); // md5 salt
        res.append(NUtils::BinToBase64(argonSalt, false)).push_back(ARGON_ATTR_DELIMITER);
        res.append(NUtils::BinToBase64(argonHash, false));

        return res;
    }

    // Argon2 over an arbitrary value (payload of a version-7 hash).
    // Uses the default secret, the configured tCost/mCost, a fresh 16-byte random salt
    // and `uid`, and serializes the result as
    //   <secret_id>:<tcost>:<mcost>:<argon_salt>:<hash>
    // with argon_salt and hash base64-encoded.
    // Throws TBlackboxError if Argon2 is not configured or random generation fails.
    TString TPasswordChecker::ArgonRaw(const TString& value, const TString& uid) const {
        if (!Argon2_) {
            throw TBlackboxError(TBlackboxError::EType::Unknown) << "Argon is not configured";
        }

        const TString salt = GenRandomData(16);

        auto argonCtx = Argon2_->Get(Secrets_.GetDefaultKey(), TCost_, MCost_);
        const TString digest = argonCtx.MakeHash(value, salt, uid);

        TString out;
        out.reserve(96);
        // Every field except the last one is followed by the delimiter.
        const auto appendField = [&out](const auto& field) {
            out.append(field).push_back(ARGON_ATTR_DELIMITER);
        };
        appendField(Secrets_.GetDefaultId());
        appendField(IntToString<10>(TCost_));
        appendField(IntToString<10>(MCost_));
        appendField(NUtils::BinToBase64(salt, false));
        out.append(NUtils::BinToBase64(digest, false));

        return out;
    }

    bool TPasswordChecker::CheckMD5Hash(const TString& pwd, const char* hash) {
        TLog::Debug() << "PasswordChecker: MD5Hash was used";
        return Matches(Md5crypt(pwd, hash).c_str(), hash);
    }

    bool TPasswordChecker::CheckArgonWithMD5Hash(const TString& pwd, TStringBuf attrView, const TString& uid) const {
        if (!Argon2_) {
            throw TBlackboxError(TBlackboxError::EType::Unknown) << "Argon is not configured";
        }

        // Read attr
        TStringBuf argonSecretIdView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonSecretIdView || !Secrets_.HasKey(argonSecretIdView)) {
            return false;
        }
        const TStringBuf argonSecret = Secrets_.GetKey(argonSecretIdView);

        TStringBuf argonTCostView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        long tCost = 0;
        if (!TryIntFromString<10>(argonTCostView, tCost) || tCost < 1) {
            return false;
        }

        TStringBuf argonMCostView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        long mCost = 0;
        if (!TryIntFromString<10>(argonMCostView, mCost) || mCost < 8) {
            return false;
        }

        TStringBuf md5SaltView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!md5SaltView) {
            return false;
        }
        TString md5salt = "$1$";
        md5salt.append(md5SaltView.begin(), md5SaltView.end()).push_back('$');

        TStringBuf argonSaltView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonSaltView) {
            return false;
        }
        TString argonSalt = NUtils::Base64ToBin(argonSaltView);

        TStringBuf argonHashView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonHashView) {
            return false;
        }
        TString argonHash = NUtils::Base64ToBin(argonHashView);
        if (attrView) {
            return false;
        }

        // Check
        TString md5hash = Md5crypt(pwd, md5salt.c_str());
        return Argon2_->Get(argonSecret, tCost, mCost).VerifyHash(md5hash, argonHash, argonSalt, uid);
    }

    bool TPasswordChecker::CheckArgonWithRawMD5Hash(const TString& pwd, TStringBuf attrView, const TString& uid) const {
        if (!Argon2_) {
            throw TBlackboxError(TBlackboxError::EType::Unknown) << "Argon is not configured";
        }

        // Read attr
        TStringBuf argonSecretIdView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonSecretIdView || !Secrets_.HasKey(argonSecretIdView)) {
            return false;
        }
        const TStringBuf argonSecret = Secrets_.GetKey(argonSecretIdView);

        TStringBuf argonTCostView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        long tCost = 0;
        if (!TryIntFromString<10>(argonTCostView, tCost) || tCost < 1) {
            return false;
        }

        TStringBuf argonMCostView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        long mCost = 0;
        if (!TryIntFromString<10>(argonMCostView, mCost) || mCost < 8) {
            return false;
        }

        TStringBuf argonSaltView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonSaltView) {
            return false;
        }
        TString argonSalt = NUtils::Base64ToBin(argonSaltView);

        TStringBuf argonHashView = attrView.NextTok(ARGON_ATTR_DELIMITER);
        if (!argonHashView) {
            return false;
        }
        TString argonHash = NUtils::Base64ToBin(argonHashView);
        if (attrView) {
            return false;
        }

        // Check
        TString md5hash = Md5Raw(pwd);
        return Argon2_->Get(argonSecret, tCost, mCost).VerifyHash(md5hash, argonHash, argonSalt, uid);
    }

    // Compares a freshly computed crypt() result with a stored hash (both NUL-terminated).
    // Returns true only if `crypted` is non-null and both strings are identical.
    //
    // To hinder timing analysis, the loop always walks the whole stored `hash` and
    // accumulates differences with XOR/OR instead of stopping at the first mismatching
    // byte. A null `crypted` (failed hashing) is compared as an empty string and forced
    // to fail, so it goes through the same loop.
    bool TPasswordChecker::Matches(const char* crypted, const char* hash) {
        unsigned char diff = crypted ? 0 : 1;
        const char* cur = crypted ? crypted : "";

        for (; *hash; ++hash) {
            diff |= static_cast<unsigned char>(*hash ^ *cur);
            // Stop advancing at the terminating NUL of `cur`. Any remaining bytes of `hash`
            // are non-zero, so XOR with NUL registers them as differences.
            cur += (*cur != '\0');
        }
        // Characters of `crypted` beyond the end of `hash` are a mismatch as well.
        diff |= static_cast<unsigned char>(*cur);

        return diff == 0;
    }

}
