#include "tls.h"

#include <balancer/serval/contrib/cone/cone.h>

#include <contrib/libs/openssl/include/openssl/dh.h>
#include <contrib/libs/openssl/include/openssl/err.h>
#include <contrib/libs/openssl/include/openssl/rand.h>
#include <contrib/libs/openssl/include/openssl/ssl.h>
#include <contrib/libs/openssl/include/openssl/x509v3.h>

// Cipher list (OpenSSL cipher-string syntax, colon-separated, in preference order)
// passed to SSL_CTX_set_cipher_list when the options do not provide their own list.
// Entries prefixed with `!` are exclusions.
// NOTE(review): `!SSLv3` in a cipher string appears to exclude suites by their minimum
// protocol version, which may also remove some of the plain *-SHA suites listed here --
// confirm with `openssl ciphers -v` before relying on those entries.
static const char DefaultCiphers[] =
    "ECDHE-ECDSA-CHACHA20-POLY1305"
    ":ECDHE-RSA-CHACHA20-POLY1305"
    ":ECDHE-ECDSA-AES128-GCM-SHA256"
    ":ECDHE-RSA-AES128-GCM-SHA256"
    ":ECDHE-ECDSA-AES256-GCM-SHA384"
    ":ECDHE-RSA-AES256-GCM-SHA384"
    ":DHE-RSA-AES128-GCM-SHA256"
    ":DHE-RSA-AES256-GCM-SHA384"
    ":ECDHE-ECDSA-AES128-SHA256"
    ":ECDHE-RSA-AES128-SHA256"
    ":ECDHE-ECDSA-AES128-SHA"
    ":ECDHE-RSA-AES256-SHA384"
    ":ECDHE-RSA-AES128-SHA"
    ":ECDHE-ECDSA-AES256-SHA384"
    ":ECDHE-ECDSA-AES256-SHA"
    ":ECDHE-RSA-AES256-SHA"
    ":DHE-RSA-AES128-SHA256"
    ":DHE-RSA-AES128-SHA"
    ":DHE-RSA-AES256-SHA256"
    ":DHE-RSA-AES256-SHA"
    ":ECDHE-ECDSA-DES-CBC3-SHA"
    ":ECDHE-RSA-DES-CBC3-SHA"
    ":EDH-RSA-DES-CBC3-SHA"
    ":AES128-GCM-SHA256"
    ":AES256-GCM-SHA384"
    ":AES128-SHA256:AES256-SHA256"
    ":AES128-SHA:AES256-SHA"
    ":DES-CBC3-SHA"
    ":!DSS:!SSLv3";

// Precomputed Diffie-Hellman group used by server contexts for the DHE-* cipher suites
// (ephemeral key exchange, i.e. forward secrecy).
// Generated by `openssl dhparam -C 2048`.
// DhP is the prime modulus as 256 big-endian bytes (it is fed to BN_bin2bn);
// DhG is the generator (2). Both are installed via DH_set0_pqg / SSL_CTX_set_tmp_dh.
static const ui8 DhP[] = {
    0x9C,0x0E,0x30,0xC4,0x70,0x1B,0x7B,0x0D,0xA5,0xFD,0xC2,0xAE,
    0x64,0xD6,0xCD,0x4D,0xE7,0xEE,0x7D,0xB3,0x9A,0xE7,0xA7,0xEA,
    0x9B,0x3E,0x60,0xD3,0x15,0x9A,0x78,0xC3,0x77,0x09,0x01,0x53,
    0xED,0xDA,0x2A,0x2C,0x86,0xAA,0x98,0x99,0xEE,0x32,0x02,0x31,
    0x2F,0x6C,0x81,0x83,0x47,0xE3,0xEB,0x5C,0xBC,0xFC,0x95,0x87,
    0x5B,0x6A,0x62,0xEA,0x31,0xB7,0xEF,0x48,0x5B,0xAE,0x39,0xB8,
    0xD2,0x83,0x02,0x51,0xB3,0x0A,0x73,0xA7,0xA1,0xD4,0x75,0x10,
    0x5F,0xBD,0x85,0x9E,0x89,0x4A,0x97,0x80,0x92,0x5F,0x45,0xFA,
    0xEA,0xF5,0x02,0x69,0xBB,0xF2,0x1F,0x16,0xCE,0x23,0x0E,0x5A,
    0xB9,0xDD,0x8F,0x34,0xDB,0x78,0xD1,0x43,0x05,0xA0,0xB2,0xE9,
    0x84,0x68,0x65,0xA1,0x5B,0xB2,0x73,0x8A,0x8B,0xBC,0x2A,0x68,
    0x3C,0xB9,0x29,0x0C,0x3A,0x0E,0x80,0x72,0x6A,0x96,0x87,0xF6,
    0xA6,0x61,0xB8,0xFB,0xA7,0xDC,0xA4,0x0A,0x85,0x1F,0xAD,0x80,
    0x09,0xD6,0xEA,0x0C,0x7A,0xB2,0x01,0x96,0x34,0x07,0xC0,0x23,
    0xC8,0x8B,0xA5,0x36,0xCD,0xC6,0x8D,0x00,0x2C,0x3D,0xE5,0x34,
    0xC6,0xCD,0x77,0x65,0xF6,0xFD,0xFF,0x64,0x07,0xB7,0xAB,0x19,
    0xAD,0xB2,0xD7,0xEC,0xC2,0x80,0x97,0xA5,0x5D,0xDA,0x7B,0x81,
    0x2A,0x49,0x62,0xD4,0x81,0x01,0x13,0x85,0xEB,0x24,0x76,0x16,
    0x48,0x15,0x74,0xAD,0x50,0x8D,0x9D,0x65,0xC7,0xB3,0xE9,0x51,
    0x0C,0xF1,0xB6,0xB9,0x4F,0xF9,0xA2,0x53,0xAA,0x72,0xF7,0x3F,
    0xED,0x19,0xBF,0x40,0xDA,0xC0,0x13,0xBD,0x7E,0x10,0x9F,0xD4,
    0x68,0xA3,0x79,0xFB
};
static const ui8 DhG[] = {0x02};

namespace {
    // Deleter policy for THolder: one overload per OpenSSL handle type,
    // each forwarding to the matching OpenSSL release routine.
    struct TDestroySSL {
        // Releases an SSL context.
        static void Destroy(SSL_CTX* context) noexcept {
            SSL_CTX_free(context);
        }

        // Releases a single TLS session object.
        static void Destroy(SSL* session) noexcept {
            SSL_free(session);
        }

        // Releases a Diffie-Hellman parameter set.
        static void Destroy(DH* params) noexcept {
            DH_free(params);
        }

        // Releases a big number via BN_clear_free.
        static void Destroy(BIGNUM* number) noexcept {
            BN_clear_free(number);
        }
    };

    // NSv::IO implementation on top of an OpenSSL session.
    // Owns the SSL object (freed through TDestroySSL); this class never closes the
    // file descriptor attached to the session itself.
    struct TEncryptedFile : NSv::IO {
    public:
        // Takes ownership of `c`.
        TEncryptedFile(SSL* c) noexcept
            : C_(c)
        {
        }

        // Lets the object be passed directly to OpenSSL functions expecting `SSL*`.
        operator SSL*() const noexcept {
            return C_.Get();
        }

        // Reads decrypted data into the memory referenced by `buf`.
        // Returns the number of bytes read, 0 on end of stream, or an empty TMaybe
        // on failure (with the mun error set by Repeat).
        TMaybe<size_t> ReadInto(TStringBuf buf) noexcept override {
            return Repeat([&]() {
                // TStringBuf only exposes const data; the caller hands in a writable buffer.
                return SSL_read(*this, const_cast<char*>(buf.data()), buf.size());
            });
        }

        // Encrypts and sends `buf`.
        // Returns the value reported by SSL_write, 0 if the peer closed the connection,
        // or an empty TMaybe on failure.
        // NOTE(review): the retry limit of 3 means a write that needs more than three
        // wait/retry rounds is reported as an error -- confirm this is intended.
        TMaybe<size_t> Write(TStringBuf buf) noexcept override {
            return Repeat([&]() {
                return SSL_write(*this, buf.data(), buf.size());
            }, 3);
        }

        // Address of the remote end of the underlying socket.
        // The getpeername result is deliberately ignored; on failure `ret` is returned
        // as default-constructed.
        NSv::IP Peer() const noexcept override {
            NSv::IP ret;
            socklen_t len = sizeof(ret.Data);
            Y_UNUSED(getpeername(SSL_get_fd(*this), &ret.Data.Base, &len));
            return ret;
        }

        // Protocol negotiated through ALPN; empty if none was selected.
        // The returned view points into memory owned by the SSL object.
        TStringBuf SelectedProtocol() const noexcept override {
            const ui8* data = nullptr;
            unsigned len = 0;
            SSL_get0_alpn_selected(*this, &data, &len);
            return {reinterpret_cast<const char*>(data), len};
        }

        // Runs the OpenSSL operation `f` until it completes, suspending the current
        // coroutine on the socket whenever OpenSSL asks for more input or output.
        //
        //   f     -- callable returning the int result of an SSL_* I/O function;
        //   limit -- maximum number of times `f` is invoked.
        //
        // Returns the positive result of `f`, 0 if the peer closed the connection,
        // or an empty TMaybe with a mun error set otherwise (including when `limit`
        // attempts were not enough).
        template <typename F>
        TMaybe<size_t> Repeat(F&& f, int limit = 100500) noexcept {
            while (limit--) {
                // SSL_get_error() inspects the thread's error queue, so it is only reliable
                // if the queue is empty before the I/O call; a stale entry left by an
                // unrelated OpenSSL call would turn WANT_READ/WANT_WRITE into a hard error.
                ERR_clear_error();
                auto code = f();
                if (code > 0) {
                    return code;
                }
                auto err = SSL_get_error(*this, code);
                if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
                    // Wait until the descriptor is readable/writable, then retry.
                    if (cone_iowait(SSL_get_fd(*this), err == SSL_ERROR_WANT_WRITE) MUN_RETHROW) {
                        return {};
                    }
                    continue;
                }
                // Orderly TLS shutdown, or EOF on the transport without one.
                if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && code == 0)) {
                    return 0;
                }
                if (err == SSL_ERROR_SYSCALL MUN_RETHROW_OS) {
                    return {};
                }
                break;
            }
            // ERR_reason_error_string() returns NULL when the queue is empty (e.g. the retry
            // limit ran out) or the code is unknown; passing NULL to "%s" is undefined.
            const char* reason = ERR_reason_error_string(ERR_get_error());
            mun_error(EEXCEPTION, "%s", reason ? reason : "unknown TLS error");
            return {};
        }

    private:
        THolder<SSL, TDestroySSL> C_;
    };
}

// Private state of a TTLSContext. A pointer to this object is also handed to the
// OpenSSL callbacks registered by the constructor.
struct NSv::TTLSContext::TImpl : TNonCopyable {
    // Next context in the chain consulted by the SNI and ticket-key callbacks
    // (set by SetNext); not owned.
    const TImpl* Next{nullptr};
    // The OpenSSL context itself.
    THolder<SSL_CTX, TDestroySSL> C_;
    // ALPN protocol list in wire format: each name preceded by its one-byte length.
    TString Protocols;
    // Session ticket keys, 48 bytes each: 16-byte key name, 16-byte AES-128-CBC key,
    // 16-byte HMAC-SHA256 key. The first entry is the one used to issue new tickets.
    TVector<std::array<ui8, 48>> TicketKeys;
    // True for server contexts (handshake via SSL_accept), false for clients (SSL_connect).
    bool Server{true};

    // Allows passing the object wherever OpenSSL expects an `SSL_CTX*`.
    operator SSL_CTX*() const noexcept {
        SSL_CTX* const raw = C_.Get();
        return raw;
    }
};

void NSv::TTLSContext::TDestroy::Destroy(TImpl* d) noexcept {
    delete d;
}

// Builds an OpenSSL context from `opts`.
//
// For every context: cipher list, certificates/keys, peer verification mode and
// common options are configured. Client contexts additionally advertise
// `opts.Protocols` via ALPN. Server contexts additionally install the SNI,
// session-ticket and ALPN-selection callbacks and the DH parameters.
//
// Throws (through Y_ENSURE) if the OpenSSL context cannot be created or configured,
// if a server context is given no certificates, or if an ALPN protocol name cannot
// be encoded (empty or longer than 255 bytes).
NSv::TTLSContext::TTLSContext(TTLSOptions opts) {
    // One-time library initialization, guarded by the function-local static.
    static int initOpenSSL = []() {
        // TODO switch to OPENSSL_init_ssl
        SSL_load_error_strings();
        SSL_library_init();
        return 0;
    }();
    Y_UNUSED(initOpenSSL);

    I_ = THolder<TImpl, TDestroy>(new TImpl);
    I_->C_ = THolder<SSL_CTX, TDestroySSL>(SSL_CTX_new(opts.Client ? SSLv23_client_method() : SSLv23_server_method()));
    I_->Server = !opts.Client;
    // Serialize the protocol names into the ALPN wire format: <length byte><name>...
    for (const TStringBuf x : opts.Protocols) {
        // The length prefix is a single byte and ALPN forbids empty names; anything else
        // would be silently truncated into a malformed list.
        Y_ENSURE(x.Size() >= 1 && x.Size() <= 255, "ALPN protocol names must be 1 to 255 bytes long");
        I_->Protocols.append(static_cast<i8>(x.Size()));
        I_->Protocols.append(x.begin(), x.end());
    }

#define Y_ENSURE_SSL(x) Y_ENSURE(x, ERR_reason_error_string(ERR_get_error()))
    // Checks that SSL_CTX_new above actually produced a context.
    Y_ENSURE_SSL(*I_);
    Y_ENSURE_SSL(SSL_CTX_set_cipher_list(*I_, opts.Ciphers ? opts.Ciphers.c_str() : DefaultCiphers) > 0);
    Y_ENSURE(opts.Client || opts.Certs, "server SSL contexts require at least one certificate");
    for (const auto& cert : opts.Certs) {
        Y_ENSURE_SSL(SSL_CTX_use_PrivateKey_file(*I_, cert.KeyFile.c_str(), SSL_FILETYPE_PEM) > 0);
        Y_ENSURE_SSL(SSL_CTX_use_certificate_chain_file(*I_, cert.CertFile.c_str()) > 0);
        Y_ENSURE_SSL(SSL_CTX_check_private_key(*I_) > 0);
    }
    SSL_CTX_set_verify(*I_, opts.VerifyPeer
        ? SSL_VERIFY_PEER | (!opts.Client && opts.RequireClientCert ? SSL_VERIFY_FAIL_IF_NO_PEER_CERT : 0)
        : SSL_VERIFY_NONE, nullptr);
    SSL_CTX_set_options(*I_,
        SSL_OP_NO_SSLv2 | SSL_OP_ALL | SSL_OP_SINGLE_DH_USE | SSL_OP_NO_COMPRESSION |
        SSL_OP_NO_RENEGOTIATION | SSL_OP_CIPHER_SERVER_PREFERENCE);
    SSL_CTX_set_read_ahead(*I_, 1);

    if (opts.Client) {
        // Nothing to advertise when the list is empty; some OpenSSL versions report a
        // failure for an empty list and leave entries in the error queue, which would
        // later confuse SSL_get_error().
        if (I_->Protocols) {
            // Unlike most OpenSSL calls, SSL_CTX_set_alpn_protos returns 0 on success.
            Y_ENSURE_SSL(SSL_CTX_set_alpn_protos(*I_, reinterpret_cast<const ui8*>(I_->Protocols.data()), I_->Protocols.size()) == 0);
        }
    } else {
        // SNI: walk the chain of contexts (this one, then Next, ...) and pick the first
        // certificate that matches the requested host name. If nothing matches, the
        // session is left on the last context tried and SSL_TLSEXT_ERR_NOACK is returned.
        // XXX in 1.1.1+ (?), the preferred method is SSL_CTX_set_client_hello_cb.
        Y_ENSURE_SSL(SSL_CTX_set_tlsext_servername_arg(*I_, I_.Get()) > 0);
        Y_ENSURE_SSL(SSL_CTX_set_tlsext_servername_callback(*I_, (int(*)(SSL*, int*, void*))
            [](SSL* ssl, int*, void* arg) {
                TStringBuf host = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
                for (const auto* data = static_cast<const TImpl*>(arg); data; data = data->Next) {
                    SSL_set_SSL_CTX(ssl, *data);
                    // (BALANCER-1144) Verification parameters are copied into SSL in SSL_new.
                    SSL_set_verify(ssl, SSL_CTX_get_verify_mode(*data), SSL_CTX_get_verify_callback(*data));
                    SSL_set_verify_depth(ssl, SSL_CTX_get_verify_depth(*data));
                    if (SSL_set_current_cert(ssl, SSL_CERT_SET_FIRST) == 1) {
                        do {
                            if (X509_check_host(SSL_get_certificate(ssl), host.data(), host.size(), 0, nullptr) == 1) {
                                return SSL_TLSEXT_ERR_OK;
                            }
                        } while (SSL_set_current_cert(ssl, SSL_CERT_SET_NEXT) == 1);
                    }
                }
                return SSL_TLSEXT_ERR_NOACK;
            }) > 0);

#if 0
        Y_ENSURE_SSL(SSL_CTX_set_tlsext_status_arg(*I_, I_.Get()) > 0);
        Y_ENSURE_SSL(SSL_CTX_set_tlsext_status_cb(*I_, [](SSL* ssl, void* arg) {
            const X509* cert = SSL_get_certificate(ssl);
            ...
            void* response = OPENSSL_malloc(size);
            memcpy(response, data, size);
            return SSL_set_tlsext_status_ocsp_resp(ssl, response, size) > 0 ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_NOACK;
        }) > 0);
#endif

        // Session tickets: use the configured keys, or generate a single random one.
        if (opts.TicketKeys) {
            I_->TicketKeys.assign(opts.TicketKeys.begin(), opts.TicketKeys.end());
        } else {
            I_->TicketKeys.resize(1);
            Y_ENSURE_SSL(RAND_bytes(I_->TicketKeys[0].data(), I_->TicketKeys[0].size()) > 0);
        }
        SSL_CTX_set_timeout(*I_, opts.TicketTTL.Seconds());

        // The ticket callback receives no user argument, so the TImpl pointer is attached
        // to the SSL_CTX as ex-data under an index allocated once per process.
        static const int contextEx = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
        Y_ENSURE(contextEx >= 0);
        Y_ENSURE_SSL(SSL_CTX_set_ex_data(*I_, contextEx, I_.Get()) > 0);
        Y_ENSURE_SSL(SSL_CTX_set_tlsext_ticket_key_cb(*I_, (int(*)(SSL*, ui8*, ui8*, EVP_CIPHER_CTX*, HMAC_CTX*, int))
            [](SSL* ssl, ui8* name, ui8* iv, EVP_CIPHER_CTX* ctx, HMAC_CTX* hctx, int enc) {
                const auto* data = static_cast<const TImpl*>(SSL_CTX_get_ex_data(SSL_get_SSL_CTX(ssl), contextEx));
                // Key layout: [0,16) key name, [16,32) AES-128-CBC key, [32,48) HMAC-SHA256 key.
                if (enc == 1) {
                    // Issuing a new ticket: always with the first key of the current context.
                    const auto& keys = data->TicketKeys;
                    if (!keys) {
                        return 0;
                    }
                    const auto& key = keys[0];
                    memcpy(name, &key[0], 16);
                    return RAND_bytes(iv, 16) <= 0 ? -1
                        : EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), nullptr, &key[16], iv) <= 0 ? -1
                        : HMAC_Init_ex(hctx, &key[32], 16, EVP_sha256(), nullptr) <= 0 ? -1
                        : 1;
                }
                // The ticket is for one of the contexts in the chain; we don't know which one.
                // SNI callback is triggered *after* this one for some reason.
                for (; data; data = data->Next) {
                    for (const auto& key : data->TicketKeys) {
                        if (!memcmp(&key[0], name, 16)) {
                            // 1: ticket accepted as is; 2: accepted, but should be reissued
                            // because it was not encrypted with the context's first key.
                            return HMAC_Init_ex(hctx, &key[32], 16, EVP_sha256(), nullptr) <= 0 ? -1
                                : EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), nullptr, &key[16], iv) <= 0 ? -1
                                : &key == &data->TicketKeys[0] ? 1 : 2;
                        }
                    }
                }
                return 0;
            }) > 0);

        // ALPN: choose a protocol from this context's list; with an empty list the
        // extension is not acknowledged, with no overlap the handshake is aborted.
        SSL_CTX_set_alpn_select_cb(*I_, [](SSL*, const ui8** out, ui8* olen, const ui8* in, ui32 ilen, void* arg) {
            if (const auto& p = static_cast<const TImpl*>(arg)->Protocols) {
                return SSL_select_next_proto((ui8**)out, olen, (const ui8*)p.data(), p.size(), in, ilen) == OPENSSL_NPN_NEGOTIATED
                    ? SSL_TLSEXT_ERR_OK
                    : SSL_TLSEXT_ERR_ALERT_FATAL;
            }
            return SSL_TLSEXT_ERR_NOACK;
        }, I_.Get());

        // Install the precomputed DH group for the DHE-* suites.
        THolder<BIGNUM, TDestroySSL> p(BN_bin2bn(DhP, sizeof(DhP), nullptr));
        THolder<BIGNUM, TDestroySSL> g(BN_bin2bn(DhG, sizeof(DhG), nullptr));
        THolder<DH, TDestroySSL> dh(DH_new());
        Y_ENSURE_SSL(dh && p && g && DH_set0_pqg(dh.Get(), p.Get(), nullptr, g.Get()) > 0);
        Y_UNUSED(p.Release()); // now owned by `dh`
        Y_UNUSED(g.Release());
        Y_ENSURE_SSL(SSL_CTX_set_tmp_dh(*I_, dh.Get()) > 0); // copies `dh`
    }
#undef Y_ENSURE_SSL
}

void NSv::TTLSContext::SetNext(const TTLSContext& next) {
    I_->Next = next.I_.Get();
}

// Creates a TLS session over the socket `fd` and performs the handshake
// (SSL_accept for server contexts, SSL_connect for client ones).
//
//   fd       -- socket to run TLS over; it is attached with SSL_set_fd and is not
//               closed by the returned object;
//   hostname -- if set, used both as the expected certificate host name
//               (SSL_set1_host) and as the SNI value sent to the peer.
//
// Returns the encrypted stream, or an empty holder with a mun error set if the
// session could not be created or the handshake failed.
THolder<NSv::IO> NSv::TTLSContext::Wrap(int fd, TMaybe<TString> hostname) const noexcept {
// TODO error code differentiation
    // ERR_reason_error_string() returns NULL for an empty error queue or an unknown
    // code, and NULL must never reach a "%s" conversion.
    const auto reason = []() noexcept -> const char* {
        const char* text = ERR_reason_error_string(ERR_get_error());
        return text ? text : "unknown TLS error";
    };
#define ToMunError(x) ((x) || !mun_error(EEXCEPTION, "%s", reason()))
    // SSL_new can fail (e.g. out of memory); every call below would dereference the result.
    SSL* ssl = SSL_new(*I_);
    if (!ToMunError(ssl)) {
        return {};
    }
    auto io = MakeHolder<TEncryptedFile>(ssl);
    if (hostname) {
        if (!ToMunError(SSL_set1_host(*io, hostname->c_str()) > 0)
         || !ToMunError(SSL_set_tlsext_host_name(*io, hostname->c_str()) > 0)) {
            return {};
        }
    }
    if (!ToMunError(SSL_set_fd(*io, fd))) {
        return {};
    }
    const auto handshake = io->Repeat([&]() {
         return I_->Server ? SSL_accept(*io) : SSL_connect(*io);
    }, 3);
    if (!handshake MUN_RETHROW) {
        return {};
    }
    // Repeat reports 0 when the peer closed the connection. A successful handshake
    // always yields a positive value, so 0 here means it never completed.
    if (*handshake == 0) {
        mun_error(EEXCEPTION, "%s", "connection closed during TLS handshake");
        return {};
    }
    return io;
#undef ToMunError
}
