#include "tls_key.h"

#include <contrib/libs/openssl/include/openssl/bio.h>
#include <contrib/libs/openssl/include/openssl/pem.h>
#include <contrib/libs/openssl/include/openssl/x509v3.h>
#include <contrib/libs/openssl/include/openssl/evp.h>

#include <util/stream/file.h>
#include <util/folder/filelist.h>

namespace NTlsKey {
    namespace {
        // TAutoPtr deleter policy for OpenSSL BIO objects: releases the BIO (and any BIOs
        // chained after it) with BIO_free_all().
        struct TBioDestroyer {
            static void Destroy(BIO* bio) {
                BIO_free_all(bio);
            }
        };

        // Owning holder for a BIO*.
        using TBIOHolder = TAutoPtr<BIO, TBioDestroyer>;

        // TAutoPtr deleter policy for OpenSSL EVP_PKEY objects (freed with EVP_PKEY_free()).
        struct TEvpKeyDestroyer {
            static void Destroy(EVP_PKEY* pkey) {
                EVP_PKEY_free(pkey);
            }
        };

        // Owning holder for an EVP_PKEY*.
        using TEvpKeyHolder = TAutoPtr<EVP_PKEY, TEvpKeyDestroyer>;

        // TAutoPtr deleter policy for ASN1_TIME values. ASN1_TIME is an ASN1_STRING, so it is
        // released with ASN1_STRING_free().
        struct TAsnTimeDestroyer {
            static void Destroy(ASN1_TIME* time) {
                ASN1_STRING_free(time);
            }
        };

        // Owning holder for an ASN1_TIME*.
        // NOTE(review): this alias has no user in this file as shown; consider removing it.
        using TAsn1TimeHolder = TAutoPtr<ASN1_TIME, TAsnTimeDestroyer>;

        // TAutoPtr deleter policy for a STACK_OF(X509_INFO): frees every element with
        // X509_INFO_free() and then the stack itself.
        struct TCertStackDestroyer {
            static void Destroy(STACK_OF(X509_INFO) * certStack) {
                sk_X509_INFO_pop_free(certStack, X509_INFO_free);
            }
        };

        // Owning holder for a parsed PEM certificate/key stack.
        using TCertStackHolder = TAutoPtr<STACK_OF(X509_INFO), TCertStackDestroyer>;

        // PEM passphrase callback that never supplies a passphrase. Per OpenSSL's pem_password_cb
        // contract, the return value is the passphrase length, so returning 0 makes decryption of
        // encrypted PEM blocks fail. Passing an explicit callback also keeps OpenSSL from using its
        // default (interactive) passphrase callback.
        inline int noPassCb(char* /*buf*/, int /*size*/, int /*rwflag*/, void* /*u*/) {
            constexpr int passphraseLength = 0;
            return passphraseLength;
        }

        // Formats the certificate serial number as upper-case hex: two digits per stored byte of
        // the INTEGER, in storage order (e.g. bytes {0x01, 0xAB} -> "01AB").
        // Returns an empty string when the certificate has no serial number or the serial is
        // negative (RFC 5280 forbids negative serials).
        inline TString getCertSerial(X509* x509) {
            const ASN1_INTEGER* serial = X509_get_serialNumber(x509);
            if (serial == nullptr) {
                // Explicit empty string instead of constructing TString from nullptr.
                return TString();
            }

            // Use the public accessors rather than poking at ASN1_STRING's fields directly.
            if (ASN1_STRING_type(serial) == V_ASN1_NEG_INTEGER) {
                // RFC5280: The serial number MUST be a positive integer assigned by the CA to each certificate
                return TString();
            }

            const unsigned char* bytes = ASN1_STRING_get0_data(serial);
            const int length = ASN1_STRING_length(serial);
            if (length <= 0 || bytes == nullptr) {
                return TString();
            }

            static constexpr char hex[] = "0123456789ABCDEF";
            TString result(static_cast<size_t>(length) * 2, '\x00');
            for (int i = 0; i < length; ++i) {
                result[i * 2] = hex[bytes[i] >> 4];
                result[i * 2 + 1] = hex[bytes[i] & 0x0f];
            }
            return result;
        }

        // Returns the raw bytes of the first attribute identified by `nid` (e.g. NID_commonName)
        // in `name`. Returns an empty string if the attribute is absent, `nid` is not a valid
        // object, or the value is empty.
        inline TString getNidValue(X509_NAME* name, int nid) {
            // X509_NAME_get_index_by_NID() returns -1 when the attribute is absent and -2 when
            // `nid` is invalid; both must stop here, otherwise a NULL entry would be dereferenced.
            const int pos = X509_NAME_get_index_by_NID(name, nid, -1);
            if (Y_UNLIKELY(pos < 0)) {
                return TString();
            }

            const X509_NAME_ENTRY* entry = X509_NAME_get_entry(name, pos);
            const ASN1_STRING* value = entry ? X509_NAME_ENTRY_get_data(entry) : nullptr;
            if (Y_UNLIKELY(value == nullptr)) {
                return TString();
            }

            // Copy exactly ASN1_STRING_length() bytes instead of relying on NUL termination:
            // attribute values may contain embedded NUL bytes, which would silently truncate
            // the result.
            const unsigned char* bytes = ASN1_STRING_get0_data(value);
            const int length = ASN1_STRING_length(value);
            if (bytes == nullptr || length <= 0) {
                return TString();
            }
            return TString(reinterpret_cast<const char*>(bytes), static_cast<size_t>(length));
        }

        // True when the certificate's notAfter time is at or before the current time.
        // NOTE: X509_cmp_current_time() returns 0 when the time cannot be evaluated, so a
        // malformed notAfter is reported as "not expired".
        inline bool isCertExpired(X509* x509) {
            const ASN1_TIME* const expiresAt = X509_get_notAfter(x509);
            const int cmp = X509_cmp_current_time(expiresAt);
            return cmp < 0;
        }

        // Returns the first X509 certificate in the stack, skipping entries that carry no
        // certificate (e.g. only a private key or a CRL); nullptr if there is none.
        // The returned pointer remains owned by the stack.
        inline X509* getFirstCert(STACK_OF(X509_INFO) * certStack) {
            const int count = sk_X509_INFO_num(certStack);
            for (int idx = 0; idx < count; ++idx) {
                if (X509* cert = sk_X509_INFO_value(certStack, idx)->x509) {
                    return cert;
                }
            }
            return nullptr;
        }

        // Looks in the directory of `path` for a certificate file whose first certificate
        // matches the private key `pkey`.
        //
        // Every file listed by TFileList in that directory, except `path` itself, whose name ends
        // in ".crt", ".pem" or ".cert" is parsed as PEM. Files that are empty, unparsable or hold
        // no certificate are skipped. The first file whose first certificate passes
        // X509_check_private_key() wins, and its whole parsed stack is returned.
        // NOTE(review): files are visited in TFileList order (Fill() is called without sorting),
        // so with several matching files the choice is presumably filesystem-dependent.
        //
        // Returns nullptr if no file matched. Throws TSystemError if a BIO cannot be allocated;
        // errors from TFileInput (open/read) are not caught here.
        inline TCertStackHolder searchCertStack(EVP_PKEY* pkey, const TFsPath& path) {
            const auto hasCertExtension = [](TStringBuf fileName) {
                return fileName.EndsWith(TStringBuf(".crt")) ||
                       fileName.EndsWith(TStringBuf(".pem")) ||
                       fileName.EndsWith(TStringBuf(".cert"));
            };

            const TFsPath dir(path.Dirname());
            TFileList files;
            files.Fill(dir);
            const auto& ownName = path.GetName();

            for (TStringBuf name; (name = files.Next());) {
                if (name == ownName || !hasCertExtension(name)) {
                    continue;
                }

                TFileInput input(dir / name);
                TBIOHolder bio = BIO_new(BIO_s_mem());
                if (Y_UNLIKELY(!bio)) {
                    ythrow TSystemError() << "failed to allocate BIO";
                }

                const auto& contents = input.ReadAll();
                BIO_write(bio.Get(), contents.data(), static_cast<int>(contents.size()));

                TCertStackHolder stack = PEM_X509_INFO_read_bio(bio.Get(), nullptr, noPassCb, nullptr);
                if (!stack || sk_X509_INFO_num(stack.Get()) == 0) {
                    // Empty or unparsable file
                    continue;
                }

                X509* const leaf = getFirstCert(stack.Get());
                if (leaf && X509_check_private_key(leaf, pkey)) {
                    return stack;
                }
            }

            return nullptr;
        }

    }

    // Reads the whole file at `path` and parses it with FromPem(). `path` is passed along so
    // that FromPem() can look for certificate files next to it.
    TAutoPtr<TTlsKey> TTlsKey::FromFile(const TFsPath& path, bool throwError) {
        TFileInput input(path);
        const TString contents = input.ReadAll();
        return FromPem(contents, path, throwError);
    }

    // Parses a PEM private key from `data` and tries to attach its certificate chain.
    //
    // Certificate sources, tried in order:
    //   1. PEM blocks that follow the private key in `data`;
    //   2. the whole of `data`, accepted only if its first certificate matches the key
    //      (PEM_read_bio_PrivateKey() consumes and skips PEM blocks that precede the key, so a
    //      certificate placed before the key, as in "cert, chain, key" bundles, is invisible
    //      to step 1);
    //   3. sibling certificate files of `path` (see searchCertStack()), if `path` is non-empty.
    //
    // Returns:
    //   * nullptr if no private key can be read (encrypted keys are rejected via noPassCb);
    //   * an empty TTlsKey if the certificate that directly follows the key does not match it,
    //     or if no matching certificate is found anywhere;
    //   * otherwise the result of FromCertStack() on the matching stack.
    // Throws TSystemError if a BIO cannot be allocated.
    TAutoPtr<TTlsKey> TTlsKey::FromPem(TStringBuf data, const TFsPath& path, bool throwError) {
        // PEM parsing consumes the BIO it reads from, so every pass over `data` needs its own BIO.
        const auto makeBio = [data]() {
            TBIOHolder bio = BIO_new(BIO_s_mem());
            if (Y_UNLIKELY(!bio)) {
                ythrow TSystemError() << "failed to allocate BIO";
            }
            BIO_write(bio.Get(), data.data(), static_cast<int>(data.size()));
            return bio;
        };

        TBIOHolder keyBio = makeBio();
        TEvpKeyHolder evp = PEM_read_bio_PrivateKey(
            keyBio.Get(),
            nullptr,
            noPassCb,
            nullptr);

        if (!evp) {
            return nullptr;
        }

        // 1. Certificates following the key: keyBio is positioned right after the key block.
        TCertStackHolder certStack = PEM_X509_INFO_read_bio(keyBio.Get(), nullptr, noPassCb, nullptr);
        X509* x509 = certStack ? getFirstCert(certStack.Get()) : nullptr;
        if (x509 && !X509_check_private_key(x509, evp.Get())) {
            // The certificate bundled right after the key belongs to a different key pair:
            // report the key without certificate information rather than guessing.
            return new TTlsKey();
        }

        if (!x509) {
            // 2. The certificate may precede the key; re-read all of `data` from the start and
            // accept the result only if its first certificate matches the key.
            TBIOHolder fullBio = makeBio();
            TCertStackHolder fullStack = PEM_X509_INFO_read_bio(fullBio.Get(), nullptr, noPassCb, nullptr);
            X509* const first = fullStack ? getFirstCert(fullStack.Get()) : nullptr;
            if (first && X509_check_private_key(first, evp.Get())) {
                certStack = fullStack;
                x509 = first;
            }
        }

        if (!x509) {
            // 3. Certificate files next to `path`.
            TCertStackHolder found;
            if (path) {
                found = searchCertStack(evp.Get(), path);
            }
            if (!found) {
                // return empty TLS Key w/o certificates
                return new TTlsKey();
            }
            certStack = found;
        }

        return FromCertStack(certStack.Get(), throwError);
    }

    // Builds a TTlsKey description from a parsed PEM stack (the stack is neither modified
    // nor freed).
    //
    // Only entries that carry an X509 certificate are considered; others (e.g. entries holding
    // only a private key or a CRL) are skipped.
    //   * The first certificate supplies serial, subject, selfSigned, clientAuth and serverAuth.
    //   * `expired` is set if any certificate in the stack is past its notAfter time.
    //   * `chain` receives the subject commonName of every certificate, followed by the issuer
    //     commonName of the last certificate.
    //   * `haveCert` is set when at least one certificate is present.
    // `throwError` is currently unused.
    TAutoPtr<TTlsKey> TTlsKey::FromCertStack(const stack_st_X509_INFO* certStack, bool /*throwError*/) {
        TAutoPtr<TTlsKey> key = new TTlsKey();

        // Most recent certificate seen; nullptr until the first one is found. Tracking actual
        // certificates (instead of stack indexes 0 and size-1) keeps leaf info and the closing
        // issuer correct even when the first or last stack entry is not a certificate, and
        // matches getFirstCert(), which FromPem() uses for the key check.
        X509* lastCert = nullptr;
        const int count = sk_X509_INFO_num(certStack);
        for (int i = 0; i < count; ++i) {
            X509* const x509 = sk_X509_INFO_value(certStack, i)->x509;
            if (!x509) {
                // Not x509 cert
                continue;
            }

            key->haveCert = true;

            if (!lastCert) {
                // Get some info about the first certificate in the chain
                key->serial = getCertSerial(x509);
                key->selfSigned = X509_NAME_cmp(X509_get_subject_name(x509), X509_get_issuer_name(x509)) == 0;
                key->clientAuth = X509_check_purpose(x509, X509_PURPOSE_SSL_CLIENT, 0) == 1;
                key->serverAuth = X509_check_purpose(x509, X509_PURPOSE_SSL_SERVER, 0) == 1;
                char buf[1024];
                char* subj = X509_NAME_oneline(X509_get_subject_name(x509), buf, sizeof(buf));
                if (subj) {
                    key->subject = TString(subj);
                }
            }

            // Treat the key as expired if any certificate in the chain has expired
            if (!key->expired) {
                key->expired = isCertExpired(x509);
            }

            key->chain.push_back(getNidValue(X509_get_subject_name(x509), NID_commonName));
            lastCert = x509;
        }

        // Close the chain with the issuer commonName of the last certificate.
        if (lastCert) {
            key->chain.push_back(getNidValue(X509_get_issuer_name(lastCert), NID_commonName));
        }

        return key.Release();
    }

}
