#pragma once

#include "sslextdataindeces.h"
#include "protocol_version.h"

#include <balancer/kernel/io/iobase.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/http/parser/client_proto.h>

#include <util/datetime/base.h>
#include <util/generic/hash.h>
#include <util/generic/maybe.h>
#include <util/generic/string.h>
#include <util/generic/yexception.h>
#include <util/string/hex.h>

#include <openssl/tls1.h>

#include <utility>

struct ssl_ctx_st;
struct ssl_st;
struct bio_st;
struct evp_cipher_ctx_st;
struct hmac_ctx_st;
struct ssl_method_st;
struct x509_st;

namespace NSrvKernel {

/// Minimum plaintext length, in bytes, used by the balancer for SSL records.
/// NOTE(review): the consumer is not visible in this header — presumably a floor for
/// record / send-fragment sizing (cf. TSslServerContext::SetMaxSendFragment); confirm.
constexpr int SSL3_RT_MIN_PLAIN_LENGTH = 512;

/// Session id context shared by balancer SSL contexts.
/// This is a NUL-terminated character array, so sizeof() includes the trailing '\0'.
constexpr unsigned char BALANCER_SESSION_CTX_ID[] = "balancer session ctx id";

class TSslIo;

// Initializes the SSL library for the balancer.
// NOTE(review): whether this must be called exactly once / before any context is created
// is not visible here — confirm in the implementation.
void InitSsl() noexcept;

// Returns the most recent SSL error code; render it with SslErrorText().
// NOTE(review): presumably reads the OpenSSL error queue — confirm in the .cpp.
int GetLastSslError() noexcept;
// Returns a human-readable description for an error code from GetLastSslError().
// NOTE(review): lifetime/ownership of the returned string is not visible here.
const char* SslErrorText(int error) noexcept;

namespace NPrivate {

// Deleter policy for THolder<x509_st>: Destroy() releases a certificate (defined in the .cpp).
struct TX509Deleter {
    static void Destroy(x509_st* p) noexcept;
};

}  // namespace NPrivate

// Length in bytes of a TLS session ticket key name; TSslIo::SetTicketName hex-encodes
// exactly this many bytes.
constexpr int TLS_TICKET_NAME_LENGTH = 16;
// Owning pointer to an x509_st certificate, released through NPrivate::TX509Deleter.
using TX509Holder = THolder<x509_st, NPrivate::TX509Deleter>;


class IAlpnProtos {
public:
    virtual ~IAlpnProtos() = default;

    virtual bool HasH2(const NAddr::IRemoteAddr*, bool expEnabled) const noexcept = 0;
};

namespace NPrivate {

// Factory declared as a friend of the SSL context classes; apart from moves it is the only
// way to construct them, since their constructors are private. Defined further below.
struct TSslContextAccessor;

}  // namespace NPrivate

/**
 * Move-only owner of an OpenSSL ssl_ctx_st plus balancer-specific settings:
 * primary/secondary certificates and the HTTP/2 (ALPN) policy.
 *
 * Apart from moves, construction is reserved for NPrivate::TSslContextAccessor.
 */
class TSslContext : public TMoveOnly {
private:
    friend struct NPrivate::TSslContextAccessor;

    // Takes ownership of `ctx`; it is released through TDestroy::Destroy.
    explicit TSslContext(ssl_ctx_st* ctx)
        : Ctx_(ctx)
    {
    }

    TSslContext() noexcept = default;

public:
    TSslContext(TSslContext&& rhs) noexcept = default;
    TSslContext& operator=(TSslContext&& rhs) noexcept = default;

    // Raw context pointer; ownership stays with this object.
    ssl_ctx_st* Ctx() const noexcept {
        ssl_ctx_st* const raw = Ctx_.Get();
        return raw;
    }

    void EnableSslTicketKeys() noexcept;
    void DisableSslTicketKeys() noexcept;

    // Static h2 switch; only consulted when no IAlpnProtos policy is installed.
    void EnableH2() noexcept {
        HasH2_ = true;
    }

    const x509_st* GetFirstCert() const noexcept;
    const x509_st* GetNextCert() const noexcept;

    // Error may be checked with GetLastSslError() and SslErrorText()
    TErrorOr<bool> LoadVerifyLocations(const char* caFile);
    TError SetVerifyMode(int mode) noexcept;
    TError SetVerifyPeer() noexcept;
    TError SetVerifyNone() noexcept;
    TError SetVerifyDepth(int depth);
    TError SetClientCaFile(const char* ca);
    TError SetClientCrlFile(const char* crlFileName);

    // Installs a non-owning h2 policy; nullptr falls back to the EnableH2() flag.
    // NOTE(review): the pointee must outlive this context — not enforced here.
    void SetAlpnProtos(const IAlpnProtos* alpnProtos) noexcept {
        AlpnProtos_ = alpnProtos;
    }

protected:
    // Delegates to the installed IAlpnProtos policy when present, else reports EnableH2().
    bool HasH2(const NAddr::IRemoteAddr* addr, bool expEnabled) const noexcept {
        return AlpnProtos_ != nullptr
            ? AlpnProtos_->HasH2(addr, expEnabled)
            : HasH2_;
    }

    // Mutable access to the secondary certificate slot for derived contexts.
    TX509Holder& SecondaryCert() noexcept {
        TX509Holder& slot = SecondaryCert_;
        return slot;
    }

private:
    struct TDestroy {
        static void Destroy(ssl_ctx_st* ctx) noexcept;
    };

    THolder<ssl_ctx_st, TDestroy> Ctx_;
    TX509Holder PrimaryCert_;
    TX509Holder SecondaryCert_;
    const IAlpnProtos* AlpnProtos_ = nullptr;  // non-owning
    bool HasH2_ = false;                       // set by EnableH2()
};

class TSslServerContext final : public TSslContext, TNonCopyable {
private:
    friend struct NPrivate::TSslContextAccessor;

    explicit TSslServerContext(TSslContext&& ctx)
        : TSslContext(std::move(ctx))
    {}

public:
    using TServernameCallback = ssl_ctx_st* (*)(void* data, const TStringBuf& hostname);
    using TOcspResponseCallback = const TChunk* (*)(void*, ssl_st*);
    using TSslTicketKeysCallback = int (*)(void*, ssl_st* ssl, unsigned char* name, unsigned char* iv,
                                           evp_cipher_ctx_st* cipherCtx, hmac_ctx_st* hmacCtx, int init);
    using TSslKeylogCallback = void (*)(void*, const TStringBuf);

    TError AddSecondaryCert(const char* keyFile, const char* certFile, const char* caFile, bool validateCertDate);

    template <ssl_ctx_st* (*F)(void* data, const TStringBuf& hostname)>
    TError SetServernameCallback(void* data) {
        return SetServernameCallbackImpl(&ServernameCallback<F>, data);
    }

    template <const TChunk* (*F)(void*, ssl_st*)>
    TError SetOcspResponseCallback(void* data) {
        return SetOcspResponseCallbackImpl(&OcspResponseCallback<F>, data);
    }

    /**
     * Set tickets callback and store pointer to current TSslItem
     *
     * @tparam F                         Function with main callback code
     *
     * @param[in] data                   Pointer to TSslItem
     *
     * @return                           Error
     */
    template <TSslTicketKeysCallback F>
    TError SetTicketKeysCallback(void* data) {
        // Wrapper and template here to avoid circular dependencies in includes
        return SetTicketKeysCallbackImpl(&SslTicketKeysCallback<F>, data);
    }

    /**
     * Set keylog callback and store pointer to current TSslItem
     *
     * @tparam F                         Function with main callback code
     *
     * @param[in] data                   Pointer to TSslItem
     *
     * @return                           Error
     */
    template <TSslKeylogCallback F>
    void SetKeylogCallback(void* data) {
        SetKeylogCallbackImpl(&KeylogCallback<F>, data);
    }

    void SetTimeout(TDuration timeout) noexcept;

    void SetMaxSendFragment(long maxSendFragment) noexcept;

private:
    TError InitServerContext(const char* curves, bool ja3Enabled);

    template <TServernameCallback F>
    static int ServernameCallback(ssl_st* ssl, int*, void* data) noexcept {
        TStringBuf servername = GetServername(ssl);
        return ServernameRoutine(ssl, F(data, servername));
    }

    template <TOcspResponseCallback F>
    static int OcspResponseCallback(ssl_st* ssl, void* data) {
        int ret;
        // TODO(tender-bum): return TError
        TryRethrowError(OcspResponseRoutine(ssl, F(data, ssl)).AssignTo(ret));
        return ret;
    }

    /**
     * Wrapper for main callback function outside this module.
     *
     * @tparam F                         Function with main callback code
     *
     * @param[in] name                   key name from TSslTicketKey
     * @param[in] iv                     initialization vector
     * @param[in] cipherCtx              AES key for EVP
     * @param[in] hmacCtx                SHA256 key for HMAC
     * @param[in] init                   renew key state
     *
     * @return                           tickets state
     */
    template <TSslTicketKeysCallback F>
    static int SslTicketKeysCallback(ssl_st* ssl, unsigned char* name, unsigned char* iv,
        evp_cipher_ctx_st* cipherCtx, hmac_ctx_st* hmacCtx, int init)
    {
        // See OpenSSL documentation:
        // https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_tlsext_ticket_key_cb.html
        return F(GetSslTicketKeysData(GetSslCtx(ssl)), ssl, name, iv, cipherCtx, hmacCtx, init);
    }

    /**
     * Wrapper for main callback function outside this module.
     *
     * @tparam F                         Function with main callback code
     *
     * @param[in] ssl                    ssl context
     * @param[in] line                   keylog line from openssl
     */
    template <TSslKeylogCallback F>
    static void KeylogCallback(const ssl_st* ssl, const char* line)
    {
        // See OpenSSL documentation:
        // https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_get_keylog_callback.html
        F(GetSslKeylogData(GetSslCtx(ssl)), line);
    }

    static TStringBuf GetServername(ssl_st* s) noexcept;
    static int ServernameRoutine(ssl_st* ssl, ssl_ctx_st* ctx) noexcept;
    static TErrorOr<int> OcspResponseRoutine(ssl_st* ssl, const TChunk* response);

    using TServernameImplCallback = int (*)(ssl_st*, int*, void*);
    using TOcspResponseImplCallback = int (*)(ssl_st*, void*);
    using TTicketKeysImplCallback = int (*)(ssl_st*, unsigned char*, unsigned char*,
                                            evp_cipher_ctx_st*, hmac_ctx_st*, int);
    using TSetKeylogImplCallback = void (*)(const ssl_st *, const char *);

    TError SetServernameCallbackImpl(TServernameImplCallback callback, void* data);
    TError SetOcspResponseCallbackImpl(TOcspResponseImplCallback callback, void* data);
    /**
     * Wrapper for tickets callback setter
     *
     * @param[in] callback               Pointer to callback function from outside module
     * @param[in] data                   Pointer to TSslItem
     *
     * @return                           Error
     */
    TError SetTicketKeysCallbackImpl(TTicketKeysImplCallback callback, void* data);

    /**
     * Wrapper for NSS keylogger
     *
     * @param[in] callback               Pointer to callback function from outside module
     * @param[in] data                   Pointer to TSslItem
     * @return                           Error
     */
    void SetKeylogCallbackImpl(TSetKeylogImplCallback callback, void* data);

    static int AlpnCb(ssl_st* ssl, const unsigned char** out, unsigned char* outlen,
                      const unsigned char* in, unsigned int inlen, void* arg);
    /**
     * Early stage client hello callback for extraction of ja3 info.
     *
     * @param[in] ssl                    SSL state
     * @param[in] al                     Unused
     * @param[out] arg                   SSL server ctx with pointer to TSslIO class
     * @return                           SSL_CLIENT_HELLO_SUCCESS(1)
     */
    static int HelloCb(ssl_st* s, int* al, void* arg);

};

/**
 * Client-side SSL context. It adds no behavior to TSslContext; it only restricts
 * construction to NPrivate::TSslContextAccessor (see CreateSslClientCtx).
 */
class TSslClientContext final : public TSslContext {
    friend struct NPrivate::TSslContextAccessor;

    // Adopts an already created base context.
    explicit TSslClientContext(TSslContext&& ctx)
        : TSslContext(std::move(ctx))
    {
    }
};

/*
 * Holder of ssl_ctx init parameters (consumed by CreateSslServerCtx).
 *
 * String members are non-owning C strings: the pointed-to data must outlive every use
 * of this struct.
 */
struct TSslServerCtxInitParams {
    // Protocol selection forwarded to context creation (NOTE(review): presumably a bitmask).
    ui64 Protocols = 0;
    // Cipher list.
    const char* Ciphers = nullptr;
    // Ciphersuites (NOTE(review): presumably the TLS 1.3 suites string).
    const char* Suites = nullptr;
    // Curves list passed on to server context initialization.
    const char* Curves = nullptr;
    // Primary private key, certificate and CA files.
    const char* KeyFile = nullptr;
    const char* CertFile = nullptr;
    const char* CaFile = nullptr;
    // Enables collection of ClientHello fields for ja3.
    bool Ja3Enabled = false;
    // Whether certificate validity dates are checked when loading.
    bool ValidateCertDate = false;

    TSslServerCtxInitParams(ui64 protocols, const char* ciphers, const char* suites, const char* curves, const char* keyFile,
                            const char* certFile, const char* caFile, bool ja3Enabled, bool validateCertDate)
        : Protocols(protocols)
        , Ciphers(ciphers)
        , Suites(suites)
        , Curves(curves)
        , KeyFile(keyFile)
        , CertFile(certFile)
        , CaFile(caFile)
        , Ja3Enabled(ja3Enabled)
        , ValidateCertDate(validateCertDate)
    {}
};

namespace NPrivate {

// Factory with access to the private constructors of the SSL context classes.
struct TSslContextAccessor {
    // Creates a bare context for `method` restricted by `protocols`.
    // NOTE(review): `protocols` is presumably a bitmask of enabled versions — confirm in the .cpp.
    static TErrorOr<TSslContext> CreateSslCtx(const ssl_method_st* method, ui64 protocols);
    // Creates an initialized server context from `params`. It is returned in a THolder because
    // the object's address is used by ssl callbacks and must stay stable.
    static TErrorOr<THolder<TSslServerContext>> CreateServerCtx(TSslServerCtxInitParams& params);
    // Creates a client context with the given cipher list and ciphersuites.
    static TSslClientContext CreateClientCtx(const char* ciphers, const char* suites);
};

}  // namespace NPrivate

// TODO(tender-bum)
// We're forced to use THolder here because the address of TSslServerContext is used by some ssl
// callbacks, so moving the object would invalidate the original pointer. In any case, this function
// is used in only one place in the whole balancer, and TSslServerContext is stored in a THolder there.
//
// Marked [[nodiscard]] for consistency with CreateSslClientCtx: discarding the result would both
// drop the freshly created context and ignore any creation error.
[[nodiscard]] inline TErrorOr<THolder<TSslServerContext>> CreateSslServerCtx(TSslServerCtxInitParams& params)
{
    return NPrivate::TSslContextAccessor::CreateServerCtx(params);
}

/**
 * Create a client-side SSL context.
 *
 * @param[in] ciphers                cipher list
 * @param[in] suites                 ciphersuites
 *
 * @return                           the new context, returned by value (the type is move-only)
 */
[[nodiscard]] inline TSslClientContext CreateSslClientCtx(const char* ciphers, const char* suites)
{
    using NPrivate::TSslContextAccessor;
    return TSslContextAccessor::CreateClientCtx(ciphers, suites);
}

/*
 * Helper for storing TSslIo pointer in ssl_st
 *
 * Allocates an OpenSSL ex_data index once, at construction, and fails fatally (Y_FAIL)
 * if no index can be allocated.
 *
 * NOTE(review): the index comes from SSL_CTX_get_ex_new_index (the SSL_CTX ex_data class),
 * while this helper is described as storing data in ssl_st. OpenSSL numbers ex_data indices
 * independently per class, so using it with SSL_set_ex_data could collide with an index from
 * SSL_get_ex_new_index. Confirm against RegisterSslIo.
 */
struct TSslIoIndex {
    TSslIoIndex()
        : Idx(AllocateIndex())
    {}

    const int Idx;

private:
    // Returns a fresh ex_data index; never returns -1.
    static int AllocateIndex() {
        const int idx = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
        if (idx == -1) {
            Y_FAIL("SSL_CTX_get_ex_new_index");
        }
        return idx;
    }
};

/**
 * SSL connection over a pair of underlying byte streams.
 *
 * Reads (DoRecv) and writes (DoSend) go through the SSL object held in Ssl_, while
 * Input()/Output() are the underlying transport streams. The object also records
 * ClientHello fields (see "ja3 routines") and session ticket details for logging.
 */
class TSslIo final : public IIoInput, public IHttpOutput {
public:
    TSslIo(
        const TSslContext& ctx,
        IIoInput& input,
        IIoOutput& output,
        const NAddr::IRemoteAddr* addr = nullptr,
        bool expEnabled = false,
        bool cpuLimiterEnabled = true,
        const TSslEarlyDataParams& earlyDataParams = {false, 0, 0}
    );

    ~TSslIo() override;

    TError Connect(TInstant deadline = TInstant::Max());
    TError Accept();
    void Shutdown() noexcept;
    int GetShutdown() noexcept;
    // For use with clients only
    TError SetSniServername(const char* serverName);

    TError SetClientAlpn(const ui8 *protocols, size_t sz);

    // call on established connections only
    TString CipherName() const;
    /**
     * Get current OpenSSL-specific cipher id
     *
     * @return                           cipher id
     */
    ui16 CipherId() const;
    TString ClientRandom() const;
    TString ServerRandom() const;
    TString MasterSecret() const;
    TString SessionId() const;
    TString SslVersion() const;
    EClientProto AlpnProto() const;
    NSsl::EProtocolVersion GetProtocolVersion() const;

private:
    TMaybe<TString> ClientCertField(TStringBuf fieldName) const;

public:
    TMaybe<TString> ClientCertCN() const;
    TMaybe<TString> ClientCertSubject() const;
    TMaybe<TString> ClientCertSerialNumber() const;
    TMaybe<long> ClientCertVerifyResult() const;
    /**
     * Check if client certificate was present
     *
     * @return                           client cert state
     */
    bool ClientCertPresent() const;

    const ssl_ctx_st* CtxRaw() const noexcept;

public:
    const NAddr::IRemoteAddr* RemoteAddr() const noexcept {
        return RemoteAddr_;
    }

    bool ExpEnabled() const noexcept {
        return ExpEnabled_;
    }

    bool CpuLimiterEnabled() const noexcept {
        return CpuLimiterEnabled_;
    }

    IIoInput& Input() const noexcept {
        return *Input_;
    }

    IIoOutput& Output() const noexcept {
        return *Output_;
    }

    // for internal use only
    void SetIo(IIoInput* input, IIoOutput* output) noexcept {
        Input_ = input;
        Output_ = output;
    }

    TInstant RecvDeadline() const noexcept {
        return RecvDeadline_;
    }

    TInstant SendDeadline() const noexcept {
        return SendDeadline_;
    }

    // Valid only after the handshake; aborts (Y_VERIFY) otherwise.
    bool UsedTlsTickets() const noexcept {
        Y_VERIFY(Established_, "Connection not established");
        return UsedTlsTickets_;
    }

    void SetUsedTlsTickets() noexcept {
        UsedTlsTickets_ = true;
    }

    // Hex-encoded ticket key name recorded by SetTicketName().
    TString TicketName() const noexcept  {
        return TicketName_;
    }

    /**
     * Record the session ticket key name.
     *
     * @param[in] name                   exactly TLS_TICKET_NAME_LENGTH raw bytes; stored hex-encoded
     */
    void SetTicketName(const unsigned char* name) {
        TicketName_ = HexEncode(reinterpret_cast<const char*>(name), TLS_TICKET_NAME_LENGTH);
    }

    // Hex-encoded ticket IV recorded by SetTicketIV().
    TString TicketIV() const noexcept {
        return TicketIV_;
    }

    /**
     * Record the session ticket initialization vector.
     *
     * @param[in] iv                     exactly EVP_MAX_IV_LENGTH raw bytes; stored hex-encoded
     */
    void SetTicketIV(const unsigned char* iv) {
        TicketIV_ = HexEncode(reinterpret_cast<const char*>(iv), EVP_MAX_IV_LENGTH);
    }

    bool HasStoredError() const noexcept {
        return !!StoredError;
    }

    void StoreError(TError error) noexcept {
        StoredError = std::move(error);
    }

    // ja3 routines
    //
    // Each Set* below copies the given array; a null pointer or zero size leaves the
    // previously stored value unchanged.

    /**
     * Copy client extensions field
     *
     * @param[in] ext                    raw extensions list
     * @param[in] sz                     size of array
     */
    void SetClientExtensions(const int* ext, size_t sz) noexcept {
        if (ext != nullptr && sz > 0) {
            ClientExtensions_.assign(ext, ext + sz);
        }
    }

    /**
     * Copy client ciphers field
     *
     * @param[in] ciphers                raw ciphers list
     * @param[in] sz                     size of array
     */
    void SetCiphers(const ui8* ciphers, size_t sz) noexcept {
        if (ciphers != nullptr && sz > 0) {
            Ciphers_.assign(ciphers, ciphers + sz);
        }
    }

    /**
     * Copy client elliptic curves field
     *
     * @param[in] curves                 raw curves list
     * @param[in] sz                     size of array
     */
    void SetEllipticCurves(const ui8 *curves, size_t sz) noexcept {
        if (curves != nullptr && sz > 0) {
            EllipticCurves_.assign(curves, curves + sz);
        }
    }

    /**
     * Copy client elliptic curves point formats field
     *
     * @param[in] curves                 raw curves point formats list
     * @param[in] sz                     size of array
     */
    void SetEllipticCurvesPointFormats(const ui8 *curves, size_t sz) noexcept {
        if (curves != nullptr && sz > 0) {
            EllipticCurvesPointFormats_.assign(curves, curves + sz);
        }
    }

    /**
     * Copy client signature algorithms field
     *
     * @param[in] algorithms             raw signature algorithms list
     * @param[in] sz                     size of array
     */
    void SetSignatureAlgorithms(const ui8 *algorithms, size_t sz) noexcept {
        if (algorithms != nullptr && sz > 0) {
            SignatureAlgorithms_.assign(algorithms, algorithms + sz);
        }
    }

    /**
     * Copy client certificates signature algorithms field
     *
     * @param[in] algorithms             raw certificates signature algorithms list
     * @param[in] sz                     size of array
     */
    void SetSignatureAlgorithmsCert(const ui8 *algorithms, size_t sz) noexcept {
        if (algorithms != nullptr && sz > 0) {
            SignatureAlgorithmsCert_.assign(algorithms, algorithms + sz);
        }
    }

    /**
     * Copy client supported TLS versions field
     *
     * @param[in] versions               raw supported TLS versions list
     * @param[in] sz                     size of array
     */
    void SetSupportedVersions(const ui8 *versions, size_t sz) noexcept {
        if (versions != nullptr && sz > 0) {
            SupportedVersions_.assign(versions, versions + sz);
        }
    }

    /**
     * Copy client list of protocols
     *
     * @param[in] protocols              raw list of protocols advertised by client
     * @param[in] sz                     size of array
     */
    void SetApplicationLayerProtocolNegotiation(const ui8 *protocols, size_t sz) noexcept {
        if (protocols != nullptr && sz > 0) {
            ApplicationLayerProtocolNegotiation_.assign(protocols, protocols + sz);
        }
    }

    /**
     * Copy client key shares
     *
     * @param[in] keys                   raw list of key shares
     * @param[in] sz                     size of array
     */
    void SetKeyShare(const ui8 *keys, size_t sz) noexcept {
        if (keys != nullptr && sz > 0) {
            KeyShare_.assign(keys, keys + sz);
        }
    }

    /**
     * Copy client pre-shared key exchange modes
     *
     * @param[in] modes                  raw list of pre-shared key exchange modes
     * @param[in] sz                     size of array
     */
    void SetPskKeyExchangeModes(const ui8 *modes, size_t sz) noexcept {
        if (modes != nullptr && sz > 0) {
            PskKeyExchangeModes_.assign(modes, modes + sz);
        }
    }

    /**
     * Copy TLS client version
     *
     * @param[in] version                client tls version
     */
    void SetLegacyVersion(ui32 version) noexcept {
        LegacyVersion_ = version;
    }

    /**
     * Get TLS client version
     *
     * @return                           client version
     */
    ui32 LegacyVersion() const noexcept {
        return LegacyVersion_;
    }

    /**
     * Get ciphers list
     *
     * @return                           constant reference to vector of ciphers
     */
    const TVector<ui8>& Ciphers() const noexcept {
        return Ciphers_;
    }

    /**
     * Get client extensions list
     *
     * @return                           constant reference to vector of client extensions
     */
    const TVector<int>& ClientExtensions() const noexcept {
        return ClientExtensions_;
    }

    /**
     * Get elliptic curves list
     *
     * @return                           constant reference to vector of elliptic curves
     */
    const TVector<ui8>& EllipticCurves() const noexcept {
        return EllipticCurves_;
    }

    // Misspelled name kept for backward compatibility; prefer EllipticCurves().
    const TVector<ui8>& EllipticCurvers() const noexcept {
        return EllipticCurves_;
    }

    /**
     * Get elliptic curves point formats list
     *
     * @return                           constant reference to vector of elliptic curves point formats
     */
    const TVector<ui8>& EllipticCurvesPointFormats() const noexcept {
        return EllipticCurvesPointFormats_;
    }

    // Misspelled name kept for backward compatibility; prefer EllipticCurvesPointFormats().
    // Now const, like every other getter, so it is usable on const TSslIo.
    const TVector<ui8>& EllipticCurversPointFormats() const noexcept {
        return EllipticCurvesPointFormats_;
    }

    /**
     * Get signature algorithms list
     *
     * @return                           constant reference to vector of signature algorithms
     */
    const TVector<ui8>& SignatureAlgorithms() const noexcept {
        return SignatureAlgorithms_;
    }

    /**
     * Get certificates signature algorithms list
     *
     * @return                           constant reference to vector of certificates signature algorithms
     */
    const TVector<ui8>& SignatureAlgorithmsCert() const noexcept {
        return SignatureAlgorithmsCert_;
    }

    /**
     * Get supported TLS versions list
     *
     * @return                           constant reference to vector of supported TLS versions
     */
    const TVector<ui8>& SupportedVersions() const noexcept {
        return SupportedVersions_;
    }

    /**
     * Get list of protocols advertised by client
     *
     * @return                           constant reference to vector of protocols advertised by client
     */
    const TVector<ui8>& ApplicationLayerProtocolNegotiation() const noexcept {
        return ApplicationLayerProtocolNegotiation_;
    }

    /**
     * Get key shares list
     *
     * @return                           constant reference to vector of key shares
     */
    const TVector<ui8>& KeyShare() const noexcept {
        return KeyShare_;
    }

    /**
     * Get pre-shared key exchange modes list
     *
     * @return                           constant reference to vector of pre-shared key exchange modes
     */
    const TVector<ui8>& PskKeyExchangeModes() const noexcept {
        return PskKeyExchangeModes_;
    }

    /**
     * Get early data state
     *
     * @return                           True if early data was accepted (the EarlyDataFinished_ flag)
     */
    bool EarlyData() const noexcept {
        return EarlyDataFinished_;
    }
private:
    TError DoRecv(TChunkList& lst, TInstant deadline) noexcept override;

    // HTTP-level output is not meaningful at the SSL layer; reaching these is a programming error.
    TError DoSendHead(TResponse&&, const bool, TInstant) override {
        Y_FAIL("ssl module doesn't know what to do in SendHead.");
    }

    TError DoSendTrailers(THeaders&&, TInstant) override {
        Y_FAIL("ssl module doesn't know what to do in SendTrailers.");
    }

    TError DoSend(TChunkList lst, TInstant deadline) noexcept override;

    /*
     * Store TSslIo in ssl_st to use it in openssl's callbacks
     */
    void RegisterSslIo() noexcept;
private:
    struct TDestroy {
        static void Destroy(ssl_st* ssl) noexcept;
        static void Destroy(bio_st* bio) noexcept;
    };

    IIoInput* Input_ = nullptr;
    IIoOutput* Output_ = nullptr;
    const NAddr::IRemoteAddr* RemoteAddr_ = nullptr;

    const THolder<ssl_st, TDestroy> Ssl_;
    const THolder<bio_st, TDestroy> Bio_;

    TInstant RecvDeadline_ = TInstant::Max();
    TInstant SendDeadline_ = TInstant::Max();

    bool Established_ = false;
    bool UsedTlsTickets_ = false;
    TString TicketName_;
    TString TicketIV_;
    TError StoredError;
    bool ExpEnabled_ = false;
    bool CpuLimiterEnabled_ = true;
    TChunkList EarlyDataChunkList_;
    bool EarlyDataEnabled_ = false;
    bool EarlyDataFinished_ = false;
    bool EarlyDataProcessed_ = false;

    // ClientHello fields captured by the ja3 routines.
    ui32 LegacyVersion_ = 0;
    TVector<ui8> Ciphers_;
    TVector<int> ClientExtensions_;
    TVector<ui8> EllipticCurves_;
    TVector<ui8> EllipticCurvesPointFormats_;
    TVector<ui8> SignatureAlgorithms_;
    TVector<ui8> SignatureAlgorithmsCert_;
    TVector<ui8> SupportedVersions_;
    TVector<ui8> ApplicationLayerProtocolNegotiation_;
    TVector<ui8> KeyShare_;
    TVector<ui8> PskKeyExchangeModes_;
};

// Checks an OCSP response chunk and reports the result as a bool.
// NOTE(review): the exact checks performed are defined in the .cpp — presumably parse/validity
// checks of the stapled response; confirm.
bool CheckOcspResponse(const TChunkPtr& response) noexcept;

}  // namespace NSrvKernel
