#include "tickets.h"

// Values returned by SslTicketKeysRoutine to OpenSSL's session-ticket key callback.
//   FAILED - no usable key was found (or none is configured); the ticket is not used.
//   OK     - the cipher/HMAC contexts were initialised with the matching key.
//   RENEW  - the ticket was decrypted with a non-default key and should be reissued.
constexpr int SSL_SNI_TICKET_FAILED = 0;
constexpr int SSL_SNI_TICKET_OK = 1;
constexpr int SSL_SNI_TICKET_RENEW = 2;

namespace NModSsl {
    void TSslTicketKey::Encrypt(unsigned char* iv, evp_cipher_ctx_st* cipherCtx, hmac_ctx_st* hmacCtx) const noexcept {
        // Prepares the contexts for sealing a new ticket:
        // HMAC-SHA256 keyed with the 16-byte HmacKey, AES-128-CBC keyed with AesKey and the given IV.
        // NOTE(review): OpenSSL return codes are not checked; this method cannot report them (void).
        constexpr int hmacKeyLength = 16;
        HMAC_Init_ex(hmacCtx, HmacKey, hmacKeyLength, EVP_sha256(), nullptr);

        const auto* const cipher = EVP_aes_128_cbc();
        EVP_EncryptInit_ex(cipherCtx, cipher, nullptr, AesKey, iv);
    }

    void TSslTicketKey::Decrypt(unsigned char* iv, evp_cipher_ctx_st* cipherCtx, hmac_ctx_st* hmacCtx) const noexcept {
        // Prepares the contexts for opening a presented ticket:
        // HMAC-SHA256 keyed with the 16-byte HmacKey, AES-128-CBC keyed with AesKey and the given IV.
        // NOTE(review): OpenSSL return codes are not checked; this method cannot report them (void).
        constexpr int hmacKeyLength = 16;
        HMAC_Init_ex(hmacCtx, HmacKey, hmacKeyLength, EVP_sha256(), nullptr);

        const auto* const cipher = EVP_aes_128_cbc();
        EVP_DecryptInit_ex(cipherCtx, cipher, nullptr, AesKey, iv);
    }

    const TSslTicketKey* TSslTicketKeyStorage::Find(unsigned char* name) const noexcept {
        // Scans every key group in order and returns the first key whose 16-byte
        // name equals `name`; nullptr when no key matches.
        for (const auto& group : Keys) {
            for (const auto& candidate : group) {
                const bool sameName = memcmp(candidate.Name, name, 16) == 0;
                if (sameName) {
                    return &candidate;
                }
            }
        }
        return nullptr;
    }

    int SslTicketKeysRoutine(unsigned char* name, unsigned char* iv, evp_cipher_ctx_st* cipherCtx, hmac_ctx_st* hmacCtx,
                             int init, const TSslTicketKeyStorage* ticketKeys) noexcept
    {
        if (!ticketKeys->Default()) {
            // Should never really happen as ssl_sni disables the extension in this case.
            // Note that returning 0 here makes OpenSSL 1.1.0 continue with the full handshake,
            // while 1.0.2 is buggy and treats 0 as 1 ("decrypted the session parameters") instead.
            return SSL_SNI_TICKET_FAILED;
        }

        if (init == 1) {
            // Generate ticket for the new client
            RAND_bytes(iv, EVP_MAX_IV_LENGTH);
            memcpy(name, ticketKeys->Default()->Name, 16);
            ticketKeys->Default()->Encrypt(iv, cipherCtx, hmacCtx);
            return SSL_SNI_TICKET_OK;
        } else {
            // If ticket was preset
            const TSslTicketKey* const key = ticketKeys->Find(name);
            // Check this key in current tickets list
            if (!key) {
                return SSL_SNI_TICKET_FAILED;
            }

            // If it's not a default key then renew it
            key->Decrypt(iv, cipherCtx, hmacCtx);
            if (key == ticketKeys->Default()) {
                return SSL_SNI_TICKET_OK;
            } else {
                return SSL_SNI_TICKET_RENEW;
            }
        }
    }

    void ThreadedReadTicketsKeyCallback::StoreKey(const unsigned char* data) {
        TSslTicketKey key;
        memcpy(key.Name, data, 16);
        memcpy(key.HmacKey, data + 16, 16);
        memcpy(key.AesKey, data + 32, 16);
        Keys.push_back(std::move(key));
    }

    void ThreadedReadTicketsKeyCallback::DoRun() {
        // Picks the parser by file size and stores any failure in Error
        // instead of letting it escape the worker.
        Y_TRY(NSrvKernel::TError, error) {
            // A file of exactly 48 bytes is a single raw key; everything else is parsed as PEM.
            const bool looksLikeRawKey = TFileStat(Path_).Size == 48;
            if (looksLikeRawKey) {
                // for nginx compatibility
                Y_PROPAGATE_ERROR(ReadRawKey());
            } else {
                Y_PROPAGATE_ERROR(ReadPemKeys());
            }
            return {};
        } Y_CATCH {
            Error = std::move(error);
        };
    }

    // Reads a raw 48-byte key file (16-byte name, 16-byte HMAC secret, 16-byte AES key)
    // and stores the key. Returns an error if the file is unreadable or shorter than 48 bytes.
    [[nodiscard]] NSrvKernel::TError ThreadedReadTicketsKeyCallback::ReadRawKey() {
        Y_PROPAGATE_ERROR(NSrvKernel::CheckFileReadable(Path_));
        unsigned char buf[48];
        TUnbufferedFileInput in(Path_);
        const size_t bytes = in.Load(buf, sizeof(buf));
        if (bytes != sizeof(buf)) {
            // The file can change between DoRun()'s size check and this read. Report it
            // rather than silently succeeding with no key (ReadPemKeys rejects empty files too).
            return Y_MAKE_ERROR(yexception{} << "Could not read raw key file: unexpected size");
        }
        StoreKey(buf);
        return {};
    }

    // Reads every PEM block from Path_. Each block must be named "SESSION TICKET KEY"
    // (prefix match on 18 chars) and carry exactly 48 bytes, which are stored via StoreKey().
    // Returns an error if the file cannot be opened, a block has the wrong name or size,
    // or no block was read at all.
    [[nodiscard]] NSrvKernel::TError ThreadedReadTicketsKeyCallback::ReadPemKeys() {
        Y_PROPAGATE_ERROR(NSrvKernel::CheckFileReadable(Path_));
        struct TBioDeleter {
            static void Destroy(BIO* p) { BIO_free(p); }
        };
        // Owns the buffers PEM_read_bio() allocates for one block. They are released on
        // every exit path, including the early error return inside the loop below.
        struct TPemBlock {
            char* Name = nullptr;
            char* Header = nullptr;
            unsigned char* Data = nullptr;
            long Length = 0;

            TPemBlock() = default;
            TPemBlock(const TPemBlock&) = delete;
            TPemBlock& operator=(const TPemBlock&) = delete;

            ~TPemBlock() {
                OPENSSL_free(Name);
                OPENSSL_free(Header);
                OPENSSL_free(Data);
            }
        };

        const THolder<BIO, TBioDeleter> bio{BIO_new_file(Path_, "r")};

        if (bio == nullptr) {
            return Y_MAKE_ERROR(yexception{} << "Could not open file");
        }

        ui64 ticketsCount = 0;

        while (true) {
            TPemBlock block;
            // PEM_read_bio() returns 0 at end of file (and on a parse error).
            // NOTE(review): it likely leaves an error (e.g. "no start line") on the OpenSSL
            // error queue at EOF; consider clearing it — confirm.
            if (!PEM_read_bio(bio.Get(), &block.Name, &block.Header, &block.Data, &block.Length)) {
                break;
            }
            if (block.Length != 48 || strncmp("SESSION TICKET KEY", block.Name, 18) != 0) {
                return Y_MAKE_ERROR(yexception{} << "Could not read pem file");
            }
            StoreKey(block.Data);
            ++ticketsCount;
        }

        if (ticketsCount == 0) {
            return Y_MAKE_ERROR(yexception{} << "Empty or broken pem file");
        }
        return {};
    }

    // Reads the ticket key file at `path` on the threaded queue, bounded by `deadline`.
    // Returns the parsed keys, the error reported by the reader, or "Could not read file"
    // when queue->Run() yields no callback.
    NSrvKernel::TErrorOr<TList<TSslTicketKey>>
    ThreadedReadTicketsKey(const char* path, NSrvKernel::TThreadedQueue* queue, TInstant deadline) {
        // The heap-allocated callback is handed to queue->Run() (ownership presumably
        // transferred — TODO confirm). When it comes back it carries either Error or Keys.
        if (auto returnedCallback = queue->Run(new ThreadedReadTicketsKeyCallback{path}, deadline, deadline)) {
            if (returnedCallback->Error) {
                return std::move(returnedCallback->Error);
            }
            return std::move(returnedCallback->Keys);
        }

        return Y_MAKE_ERROR(yexception{} << "Could not read file");
    }

    // Collects the ticket key file entries of a config section. Each entry has
    // "priority" and "keyfile" keys, and priorities must be unique.
    class TTicketsParser final : public NConfig::IConfig::IFunc {
    public:
        explicit TTicketsParser(NConfig::IConfig* config) {
            config->ForEach(this);
        }

        // Returns the collected entries ordered by descending priority.
        TTicketFiles Construct() {
            // Negate in a wider type: `-Priority` on the minimum int value is signed overflow (UB).
            auto getKey = [](const TTicketFile& ticket) { return -static_cast<long long>(ticket.Priority); };
            SortBy(Tickets, getKey);
            return std::move(Tickets);
        }

    private:
        // Parses a single ticket key file entry.
        struct TTicketParser final : public NConfig::IConfig::IFunc {

            explicit TTicketParser(NConfig::IConfig* config) {
                config->ForEach(this);
            }

            START_PARSE {
                ON_KEY("priority", Ticket.Priority) {
                    return;
                }

                ON_KEY("keyfile", Ticket.FileName) {
                    return;
                }
            } END_PARSE

            TTicketFile Ticket;
        };

        START_PARSE {

            TTicketParser parser(value->AsSubConfig());
            // insert() reports whether the priority is new. Recording it before the move
            // means the moved-from ticket is never read afterwards.
            if (!Priority_.insert(parser.Ticket.Priority).second) {
                ythrow NConfig::TConfigParseError()
                    << "tls ticket key file " << parser.Ticket.FileName << " has non unique priority (" << parser.Ticket.Priority  << ')';
            }
            Tickets.emplace_back(std::move(parser.Ticket));
            return;

        } END_PARSE

        TTicketFiles Tickets;
        THashSet<int> Priority_;
    };

    // Parses the ticket key file entries of `config` and returns them ordered by
    // descending priority. Throws NConfig::TConfigParseError on duplicate priorities.
    TTicketFiles ParseTickets(NConfig::IConfig* config) {
        return TTicketsParser{config}.Construct();
    }
}

