#include "ssh_key.h"

#include <contrib/libs/openssl/include/openssl/dsa.h>
#include <contrib/libs/openssl/include/openssl/md5.h>
#include <contrib/libs/openssl/include/openssl/rsa.h>
#include <contrib/libs/openssl/include/openssl/sha.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <util/generic/yexception.h>
#include <util/generic/strbuf.h>
#include <util/stream/str.h>
#include <netinet/in.h>

namespace NSshKey {
    namespace {
        // Writes `data` as an SSH wire-format "string": a 32-bit length in network
        // byte order immediately followed by the raw bytes.
        inline void writeCString(IOutputStream& out, const TString& data) {
            const ui32 lengthBigEndian = htonl(static_cast<ui32>(data.size()));
            out.Write(&lengthBigEndian, sizeof(lengthBigEndian));
            out.Write(data.data(), data.size());
        }

        // Writes a BIGNUM as an SSH "mpint": a length-prefixed, big-endian, minimal-length
        // two's-complement value. Zero is encoded as an empty string. Only non-negative
        // values are handled (the sign of `b` is not inspected).
        inline void writeCBinNum(IOutputStream& out, const BIGNUM* b) {
            const int magnitudeBytes = BN_num_bytes(b);
            if (magnitudeBytes <= 0) {
                // Zero: empty payload. Handled explicitly instead of relying on reading
                // the string terminator one past the written bytes.
                writeCString(out, TString());
                return;
            }

            // Reserve one leading zero byte: if the top bit of the magnitude is set, the
            // value would otherwise be read back as negative.
            TString bStr(static_cast<size_t>(magnitudeBytes) + 1, '\x00');
            BN_bn2bin(b, reinterpret_cast<unsigned char*>(bStr.begin() + 1));

            const bool needsSignPad = (static_cast<unsigned char>(bStr[1]) & 0x80) != 0;
            writeCString(out, needsSignPad ? bStr : bStr.substr(1, bStr.size() - 1));
        }

        // Writes an EC public point as an SSH "string" holding its uncompressed octet
        // encoding. Throws instead of silently omitting the field, because a missing
        // field would yield a malformed key blob (and a wrong fingerprint).
        inline void writeCEcpoint(IOutputStream& out, const EC_GROUP* curve, const EC_POINT* point) {
            if (curve == nullptr || point == nullptr) {
                ythrow TSystemError() << "missing EC group or public point";
            }

            // A call with a null buffer only reports the required length; 0 signals an error.
            const size_t len = EC_POINT_point2oct(curve, point, POINT_CONVERSION_UNCOMPRESSED,
                                                  nullptr, 0, nullptr);
            if (len == 0) {
                ythrow TSystemError() << "failed to determine EC point length";
            }

            TString bStr(len, '\x00');
            const size_t written = EC_POINT_point2oct(curve, point, POINT_CONVERSION_UNCOMPRESSED,
                                                      reinterpret_cast<unsigned char*>(bStr.begin()), len, nullptr);
            if (written != len) {
                ythrow TSystemError() << "failed to encode EC point";
            }

            writeCString(out, bStr);
        }

        // Returns the curve identifier written after the key type in an ECDSA public key
        // blob. Throws TSystemError for any non-ECDSA key type.
        inline const TString CurveKeytypeName(const EKeyType type) {
            if (type == EKeyType::Ecdsa256) {
                return "nistp256";
            }
            if (type == EKeyType::Ecdsa384) {
                return "nistp384";
            }
            if (type == EKeyType::Ecdsa521) {
                return "nistp521";
            }
            ythrow TSystemError() << "unknown curve type";
        }

        // Legacy fingerprint format: MD5 digest of `data` as lowercase hex byte pairs
        // joined by ':' (e.g. "ab:cd:..."), with no trailing separator.
        inline const TString SshMd5(const char* data, size_t len) {
            unsigned char digest[MD5_DIGEST_LENGTH] = {0};
            MD5(reinterpret_cast<const unsigned char*>(data), len, digest);

            static const char hexDigits[] = "0123456789abcdef";
            TString fingerprint;
            fingerprint.reserve(MD5_DIGEST_LENGTH * 3);
            for (const unsigned char byte : digest) {
                if (!fingerprint.empty()) {
                    fingerprint.push_back(':');
                }
                fingerprint.push_back(hexDigits[byte >> 4]);
                fingerprint.push_back(hexDigits[byte & 0x0f]);
            }
            return fingerprint;
        }

        // SHA-256 digest of `data`, base64-encoded without '=' padding.
        inline const TString SshSha256(const char* data, size_t len) {
            TString dStr(SHA256_DIGEST_LENGTH, '\x00');
            SHA256(reinterpret_cast<const unsigned char*>(data), len, reinterpret_cast<unsigned char*>(dStr.begin()));

            const TString base64 = Base64Encode(dStr);
            // Strip every trailing '=' rather than assuming exactly one padding character,
            // so the result stays correct regardless of how much padding the encoder adds.
            size_t end = base64.size();
            while (end > 0 && base64[end - 1] == '=') {
                --end;
            }
            return base64.substr(0, end);
        }

    }

    // Legacy (colon-separated hex MD5) fingerprint of the serialized public key blob.
    const TString TSshKey::FingerprintLegacy() const {
        const TString keyBlob = pkey->PubKey();
        return SshMd5(keyBlob.data(), keyBlob.size());
    }

    // Fingerprint of the serialized public key blob in the "SHA256:<unpadded base64>" form.
    const TString TSshKey::Fingerprint() const {
        const TString keyBlob = pkey->PubKey();
        TString fingerprint("SHA256:");
        fingerprint += SshSha256(keyBlob.data(), keyBlob.size());
        return fingerprint;
    }

    // SSH public key algorithm name for this key's type. This is the first string written
    // into the public key blob. Throws TSystemError for an unrecognized type.
    TString TSshKey::TPKey::TypeName() const {
        const char* algorithm = nullptr;
        switch (type) {
            case EKeyType::Rsa:      algorithm = "ssh-rsa"; break;
            case EKeyType::Dsa:      algorithm = "ssh-dss"; break;
            case EKeyType::Ecdsa256: algorithm = "ecdsa-sha2-nistp256"; break;
            case EKeyType::Ecdsa384: algorithm = "ecdsa-sha2-nistp384"; break;
            case EKeyType::Ecdsa521: algorithm = "ecdsa-sha2-nistp521"; break;
            case EKeyType::Ed25519:  algorithm = "ssh-ed25519"; break;
            default:                 break;
        }
        if (algorithm == nullptr) {
            ythrow TSystemError() << "unknown Key type";
        }
        return algorithm;
    }

    // Returns the SSH wire-format public key blob, building and caching it on first use.
    // A non-empty `pubKey` is returned as-is, whether it was built here or set elsewhere.
    // Throws TSystemError when the blob cannot be built for this key type.
    TString TSshKey::TPKey::PubKey() {
        if (pubKey) {
            return pubKey;
        }

        TStringStream result;

        switch (type) {
            case EKeyType::Rsa:
                // string type, mpint e, mpint n
                writeCString(result, TypeName());
                writeCBinNum(result, RSA_get0_e(rsa));
                writeCBinNum(result, RSA_get0_n(rsa));
                break;
            case EKeyType::Dsa:
                // string type, mpint p, mpint q, mpint g, mpint y
                writeCString(result, TypeName());
                writeCBinNum(result, DSA_get0_p(dsa));
                writeCBinNum(result, DSA_get0_q(dsa));
                writeCBinNum(result, DSA_get0_g(dsa));
                writeCBinNum(result, DSA_get0_pub_key(dsa));
                break;
            case EKeyType::Ecdsa256:
            case EKeyType::Ecdsa384:
            case EKeyType::Ecdsa521:
                // string type, string curve identifier, string uncompressed public point
                writeCString(result, TypeName());
                writeCString(result, CurveKeytypeName(type));
                writeCEcpoint(result, EC_KEY_get0_group(ecdsa), EC_KEY_get0_public_key(ecdsa));
                break;
            case EKeyType::Ed25519:
                // No Ed25519 key material is available to this method, so the blob cannot
                // be built here. The previous code returned an empty blob, which is never a
                // valid key and fingerprints to a meaningless value.
                // NOTE(review): presumably Ed25519 loaders pre-populate `pubKey`; confirm.
                ythrow TSystemError() << "public key blob is not available for ssh-ed25519 key";
            default:
                ythrow TSystemError() << "unknown key type";
        }

        pubKey = result.Str();
        return pubKey;
    }

}
