#include "ssh_key.h"

#include <contrib/libs/openssl/include/openssl/pem.h>
#include <contrib/libs/openssl/include/openssl/ec.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <util/generic/string.h>
#include <util/generic/strbuf.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/stream/file.h>
#include <util/stream/str.h>

namespace NSshKey {
    namespace {
        // Borrowed from OpenSSH's sshbuf.h (hard maximum buffer size, 0x8000000 = 128 MiB).
        // Used to reject absurd length prefixes before reading string payloads.
        constexpr size_t kOpenSSHBufSizeMax = 128 * 1024 * 1024;

        // Key container format, selected by the first marker line found in the input.
        enum class ESshKeyFileType {
            NONE = 0,    // no supported marker line was found
            OPENSSL = 1, // OpenSSL-style PEM ("-----BEGIN ... PRIVATE KEY-----")
            OPENSSH = 2, // "-----BEGIN OPENSSH PRIVATE KEY-----" container
            SECSH = 3,   // "---- BEGIN SSH2 ENCRYPTED PRIVATE KEY ----" container
        };

        using TKeyFileHeader = struct {
            ESshKeyFileType type;
            const TString header;
        };

        const TVector<const TKeyFileHeader> kKeyfileHeaders = {
            {ESshKeyFileType::OPENSSL, "-----BEGIN RSA PRIVATE KEY-----"},
            {ESshKeyFileType::OPENSSL, "-----BEGIN DSA PRIVATE KEY-----"},
            {ESshKeyFileType::OPENSSL, "-----BEGIN EC PRIVATE KEY-----"},
            {ESshKeyFileType::OPENSSL, "-----BEGIN ENCRYPTED PRIVATE KEY-----"},
            {ESshKeyFileType::OPENSSL, "-----BEGIN PRIVATE KEY-----"},
            {ESshKeyFileType::OPENSSH, "-----BEGIN OPENSSH PRIVATE KEY-----"},
            {ESshKeyFileType::SECSH, "---- BEGIN SSH2 ENCRYPTED PRIVATE KEY ----"},
        };

        // Magic prefix of a decoded openssh-key-v1 blob. Note that sizeof() of this
        // array also counts the terminating NUL.
        constexpr char kOpenSSHAuthMagic[]{"openssh-key-v1"};

        // Armor lines around the base64 payload, each including its trailing newline.
        constexpr char kOpenSSHMarkBegin[]{"-----BEGIN OPENSSH PRIVATE KEY-----\n"};
        constexpr char kOpenSSHMarkEnd[]{"-----END OPENSSH PRIVATE KEY-----\n"};

        // Maps an OpenSSL curve NID to the matching ECDSA key type.
        // Only the three NIST curves listed below are recognized; anything else is Undef.
        inline EKeyType nidToKeytype(int nid) {
            if (nid == NID_X9_62_prime256v1) {
                return EKeyType::Ecdsa256;
            }
            if (nid == NID_secp384r1) {
                return EKeyType::Ecdsa384;
            }
            if (nid == NID_secp521r1) {
                return EKeyType::Ecdsa521;
            }
            return EKeyType::Undef;
        }

        // Decodes a big-endian 32-bit unsigned integer from the first four bytes of
        // `buff` (used for the length prefixes read by readCString/readCInt).
        // Bytes are accumulated in an unsigned type, so no shift ever operates on a
        // promoted signed `int` whose sign bit could be hit by bytes >= 0x80.
        inline size_t getUint32(const char* buff) {
            size_t value = 0;
            for (size_t i = 0; i < 4; ++i) {
                value = (value << 8) | static_cast<unsigned char>(buff[i]);
            }
            return value;
        }

        // Reads one length-prefixed string from `in`: a 4-byte big-endian length
        // followed by that many bytes.
        //
        // Returns an empty string if the stream ends before the prefix or the payload
        // is complete; callers cannot tell that apart from a genuinely empty string.
        // Throws TSystemError if the declared length is >= kOpenSSHBufSizeMax.
        inline TString readCString(TStringInput& in) {
            char sizeBuf[4];
            if (Y_UNLIKELY(in.Read(&sizeBuf, 4) != 4)) {
                return TString();
            }

            const size_t len = getUint32(sizeBuf);
            if (Y_UNLIKELY(len >= kOpenSSHBufSizeMax)) {
                ythrow TSystemError() << "unexpected openssh string size: " << len;
            }

            // The length comes from untrusted input, so never size a buffer by it:
            // a stack array of `len` bytes could overflow the stack, and an up-front
            // heap buffer would let a short input force a huge allocation. Copy
            // through a small fixed chunk instead; the result only grows as bytes
            // actually arrive.
            TString result;
            char chunk[4096];
            size_t remaining = len;
            while (remaining > 0) {
                const size_t want = remaining < sizeof(chunk) ? remaining : sizeof(chunk);
                if (Y_UNLIKELY(in.Read(chunk, want) != want)) {
                    return TString();
                }
                result.append(chunk, want);
                remaining -= want;
            }
            return result;
        }

        // Reads a 4-byte big-endian unsigned integer from `in`.
        // A truncated stream yields 0, which callers cannot distinguish from an encoded zero.
        inline size_t readCInt(TStringInput& in) {
            char raw[4];
            const size_t got = in.Read(raw, sizeof(raw));
            return Y_UNLIKELY(got != sizeof(raw)) ? 0 : getUint32(raw);
        }

        // Determines the key type from a public key blob, whose first field is a
        // length-prefixed algorithm name. Unknown names map to EKeyType::Undef.
        inline EKeyType pubkeyType(const TString& data) {
            TStringInput blob(data);
            const TString algorithm = readCString(blob);

            struct TAlgorithmEntry {
                const char* name;
                EKeyType type;
            };
            static const TAlgorithmEntry kAlgorithms[] = {
                {"rsa", EKeyType::Rsa},
                {"ssh-rsa", EKeyType::Rsa},
                {"dsa", EKeyType::Dsa},
                {"ssh-dss", EKeyType::Dsa},
                {"ecdsa-sha2-nistp256", EKeyType::Ecdsa256},
                {"ecdsa-sha2-nistp384", EKeyType::Ecdsa384},
                {"ecdsa-sha2-nistp521", EKeyType::Ecdsa521},
                {"ssh-ed25519", EKeyType::Ed25519},
            };

            for (const auto& entry : kAlgorithms) {
                if (algorithm == entry.name) {
                    return entry.type;
                }
            }
            return EKeyType::Undef;
        }

        // TAutoPtr destroy policy that releases a BIO via BIO_free_all().
        struct TBioDestroyer {
            static void Destroy(BIO* bio) {
                BIO_free_all(bio);
            }
        };

        // Owning handle for an OpenSSL BIO.
        using TBIOHolder = TAutoPtr<BIO, TBioDestroyer>;

        // TAutoPtr destroy policy that releases an EVP_PKEY via EVP_PKEY_free().
        struct TEvpKeyDestroyer {
            static void Destroy(EVP_PKEY* pkey) {
                EVP_PKEY_free(pkey);
            }
        };

        // Owning handle for an OpenSSL EVP_PKEY.
        using TEvpKeyHolder = TAutoPtr<EVP_PKEY, TEvpKeyDestroyer>;

    }

    // Reads the whole file at `path` and parses it with FromPem().
    TAutoPtr<TSshKey> TSshKey::FromFile(const TFsPath& path, bool throwError) {
        TFileInput file(path);
        const TString pem = file.ReadAll();
        return FromPem(pem, throwError);
    }

    // Detects the private key container format from the first marker line in `data`
    // and dispatches to the matching parser.
    //
    // Blank lines and non-marker lines before the first marker are skipped. The scan
    // stops at the first line starting with "----", whether or not it is recognized.
    // Returns nullptr on failure, or throws TSystemError instead when `throwError` is set.
    TAutoPtr<TSshKey> TSshKey::FromPem(TStringBuf data, bool throwError) {
        TString key;
        TStringBuf line;
        ESshKeyFileType keyType = ESshKeyFileType::NONE;
        while (keyType == ESshKeyFileType::NONE && data.ReadLine(line)) {
            if (line.empty()) {
                continue;
            }

            for (const auto& header : kKeyfileHeaders) {
                if (header.header == line) {
                    // Re-attach the marker line to the remaining input with a '\n'
                    // terminator (the OpenSSH parser searches for "...-----\n").
                    // TODO(buglloc): fix it!
                    key = TString{line} + '\n' + TString{data};
                    keyType = header.type;
                    break;
                }
            }

            if (line.StartsWith("----")) {
                // If we find our marker and it's not one of acceptable - stop iterate, probably we have garbage :/
                break;
            }
        }

        switch (keyType) {
            case ESshKeyFileType::OPENSSL:
                return ParseOpensslPem(key, throwError);
            case ESshKeyFileType::OPENSSH:
                return ParseOpenssh(key, throwError);
            case ESshKeyFileType::SECSH:
                // Can't work with encrypted key. Honour throwError like the OpenSSH
                // parser does for its encrypted keys instead of failing silently.
                if (throwError)
                    ythrow TSystemError() << "encrypted key";
                return nullptr;
            default:
                if (throwError)
                    ythrow TSystemError() << "unknown private key file";
                return nullptr;
        }
    }

    // Parses an OpenSSL-style PEM private key (RSA, DSA or EC) via PEM_read_bio_PrivateKey.
    // Returns nullptr on failure, or throws TSystemError instead when `throwError` is set.
    TAutoPtr<TSshKey> TSshKey::ParseOpensslPem(const TStringBuf data, bool throwError) {
        // Wrap the caller's buffer in a memory BIO for the PEM reader.
        TBIOHolder keyBio = BIO_new_mem_buf(data.data(), static_cast<int>(data.size()));
        if (Y_UNLIKELY(!keyBio)) {
            if (throwError)
                ythrow TSystemError() << "failed to allocate BIO with length: " << data.size();
            return nullptr;
        }

        // pem_password_cb that always reports an empty password (returns 0).
        const auto noPassphrase = [](char* /*buf*/, int /*size*/, int /*rwflag*/, void* /*u*/) -> int {
            return 0;
        };
        TEvpKeyHolder evp = PEM_read_bio_PrivateKey(keyBio.Get(), nullptr, noPassphrase, nullptr);
        if (Y_UNLIKELY(!evp)) {
            if (throwError)
                ythrow TSystemError() << "unknown OpenSSL private key file";
            return nullptr;
        }

        TAutoPtr<TPKey> pk = new TPKey();
        EVP_PKEY* const rawKey = evp.Get();
        const int keyId = EVP_PKEY_id(rawKey);
        if (keyId == EVP_PKEY_RSA) {
            pk->type = EKeyType::Rsa;
            pk->rsa = EVP_PKEY_get1_RSA(rawKey);
        } else if (keyId == EVP_PKEY_DSA) {
            pk->type = EKeyType::Dsa;
            pk->dsa = EVP_PKEY_get1_DSA(rawKey);
        } else if (keyId == EVP_PKEY_EC) {
            pk->ecdsa = EVP_PKEY_get1_EC_KEY(rawKey);
            const EC_GROUP* group = EC_KEY_get0_group(pk->ecdsa);
            pk->type = nidToKeytype(EC_GROUP_get_curve_name(group));
        }
        // Other key kinds leave pk->type untouched; presumably TPKey defaults it to
        // EKeyType::Undef (declared elsewhere -- confirm), which is rejected here.
        if (pk->type == EKeyType::Undef) {
            if (throwError)
                ythrow TSystemError() << "unknown OpenSSL private key file";
            return nullptr;
        }

        return new TSshKey(pk.Release());
    }

    // Parses an unencrypted "openssh-key-v1" private key container holding exactly
    // one key. Only the public key blob and its type are extracted; the private
    // section is not parsed.
    // Returns nullptr on failure, or throws TSystemError instead when `throwError` is set.
    TAutoPtr<TSshKey> TSshKey::ParseOpenssh(const TStringBuf data, bool throwError) {
        // Locate the base64 body between the armor lines. find() reports a miss as
        // TStringBuf::npos (never 0), so each result is checked before any
        // arithmetic on it; npos + marker length would wrap to a bogus offset.
        const size_t beginMark = data.find(kOpenSSHMarkBegin);
        if (beginMark == TStringBuf::npos) {
            if (throwError)
                ythrow TSystemError() << "OpenSSH begin marker not found";
            return nullptr;
        }

        const size_t begin = beginMark + sizeof(kOpenSSHMarkBegin) - 1;
        const size_t end = data.find(kOpenSSHMarkEnd, begin);
        if (end == TStringBuf::npos) {
            if (throwError)
                ythrow TSystemError() << "OpenSSH end marker not found";
            return nullptr;
        }

        // Drop line breaks so the payload is a single base64 string.
        TString encodedKey;
        encodedKey.reserve(end - begin);
        for (const char ch : data.SubStr(begin, end - begin)) {
            if (ch != '\r' && ch != '\n') {
                encodedKey += ch;
            }
        }

        // The decoded blob starts with "openssh-key-v1" plus its terminating NUL;
        // sizeof() counts that NUL, so both the check and the skip cover it.
        const TString decodedKey = Base64Decode(encodedKey);
        const TStringBuf authMagic(kOpenSSHAuthMagic, sizeof(kOpenSSHAuthMagic));
        if (!decodedKey.StartsWith(authMagic)) {
            if (throwError)
                ythrow TSystemError() << "OpenSSH auth magic not found";
            return nullptr;
        }

        TStringInput in(decodedKey);
        in.Skip(authMagic.size());

        const TString cipherName = readCString(in);
        if (cipherName != "none") {
            // We don't support encrypted keys
            if (throwError)
                ythrow TSystemError() << "encrypted key";
            return nullptr;
        }

        // skip KDF name
        readCString(in);

        // skip KDF options
        readCString(in);

        const size_t nKeys = readCInt(in);
        if (nKeys != 1) {
            if (throwError)
                ythrow TSystemError() << "only one key supported";
            return nullptr;
        }

        const TString pubkey = readCString(in);
        if (Y_UNLIKELY(pubkey.empty())) {
            if (throwError)
                ythrow TSystemError() << "missing OpenSSH public key";
            return nullptr;
        }

        const EKeyType type = pubkeyType(pubkey);
        if (type == EKeyType::Undef) {
            if (throwError)
                ythrow TSystemError() << "unsupported OpenSSH key type";
            return nullptr;
        }

        TAutoPtr<TPKey> pk = new TPKey();
        pk->type = type;
        pk->pubKey = pubkey;
        // TODO(buglloc): need to _parse_ private key?

        return new TSshKey(pk.Release());
    }

}
