#include "sni.h"
#include "stats.h"
#include "matcher.h"
#include "tickets.h"
#include "client_settings.h"
#include "sslitem.h"

#include <balancer/kernel/fs/shared_files.h>
#include <balancer/kernel/fs/threadedfile.h>
#include <balancer/kernel/fs/watched_state.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/net/socket.h>
#include <balancer/kernel/process/thread_info.h>
#include <balancer/kernel/ssl/protocol_version.h>
#include <balancer/kernel/ssl/sslio.h>
#include <balancer/kernel/ssl/sslextdataindeces.h>
#include <balancer/kernel/thread/threadedqueue.h>

#include <library/cpp/config/sax.h>
#include <library/cpp/containers/flat_hash/flat_hash.h>

#include <openssl/ssl.h>
#include <openssl/x509.h>

#include <util/datetime/base.h>
#include <util/generic/hash.h>
#include <util/generic/map.h>
#include <util/generic/maybe.h>
#include <util/generic/ptr.h>
#include <util/generic/scope.h>
#include <util/generic/set.h>
#include <util/generic/strbuf.h>
#include <util/generic/string.h>
#include <util/stream/str.h>
#include <util/thread/singleton.h>

#include <utility>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NModSsl;
using namespace NSsl;

// Key of the items-map section that holds the top-level (non-experiment) contexts.
constexpr int DEFAULT_CONTEXT_ID = 0;

// Per-worker state of the ssl_sni module.
Y_TLS(ssl_sni) {
    // Every argument is handed over to Stats; the counters are taken by value and moved in.
    TTls(TSharedCounter emptyRequests,
         TSharedCounter httpsRequests,
         TSharedCounter httpRequests,
         TSharedCounter droppedExperiments,
         TSharedCounter errorsTotal,
         TSharedCounter zeroErrors,
         NFH::TFlatHashMap<EProtocolVersion, THolder<NSrvKernel::TSharedCounter>>& protocolVersions,
         TSharedStatsManager& statsManager,
         size_t workerId)
        : Stats(
            std::move(emptyRequests),
            std::move(httpsRequests),
            std::move(httpRequests),
            std::move(droppedExperiments),
            std::move(errorsTotal),
            std::move(zeroErrors),
            protocolVersions,
            statsManager,
            workerId)
    {
    }

    // Threshold consulted by HasH2() for the Rand and IpHash modes.
    NSrvKernel::TWatchedState<double> H2AlpnFreq = 0;
    // Strategy HasH2() uses to decide whether h2 is offered.
    NSrvKernel::TWatchedState<EH2AlpnRandMode> H2AlpnRandMode = EH2AlpnRandMode::IpHash;
    // Request / error counters of this worker.
    NSsl::TStats Stats;
};

MODULE_WITH_TLS(ssl_sni) {
private:
    class TH2ProtoSelector;
    using TItems = TVector<TSslItem>;
    // Section id -> ssl items of that section. Top-level "contexts" live under
    // DEFAULT_CONTEXT_ID, every experiment under its own exp_id.
    using TItemsMap = NFH::TFlatHashMap<ui64, TItems>;
    // Section id -> the item of that section which reported IsDefault().
    using TDefaultItems = NFH::TFlatHashMap<ui64, TSslItem*>;

    TItemsMap Items_;
    // Filled in the constructor from FindDefaultItem(); points into Items_.
    TDefaultItems DefaultItems_;

    // Module loaded from the config key that is not recognised as an option of ssl_sni itself.
    THolder<IModule> Submodule_;
    // Created in the constructor only when an "http2" submodule is configured.
    THolder<TH2ProtoSelector> H2ProtoSelector_;

    NSrvKernel::TSharedStatsManager& StatsManager_;

    // Shared counters; per-worker views of them are created in DoInitTls().
    NSrvKernel::TSharedCounter EmptyRequests_;
    NSrvKernel::TSharedCounter HttpsRequests_;
    NSrvKernel::TSharedCounter HttpRequests_;
    NSrvKernel::TSharedCounter DroppedExperiments_;
    NSrvKernel::TSharedCounter ErrorsTotal_;
    NSrvKernel::TSharedCounter ZeroErrors_;
    // One "ssl_sni-<version>_requests" counter per EProtocolVersion value.
    NFH::TFlatHashMap<EProtocolVersion, THolder<NSrvKernel::TSharedCounter>> ProtocolVersions_;

    // "http2_alpn_file" / "http2_alpn_freq": file and default value behind TTls::H2AlpnFreq.
    TString H2AlpnFreqFile_;
    double H2AlpnFreqDefault_ = 0;

    // "http2_alpn_rand_mode_file" / "http2_alpn_rand_mode": file and default value behind TTls::H2AlpnRandMode.
    TString H2AlpnRandModeFile_;
    EH2AlpnRandMode H2AlpnRandModeDefault_ = EH2AlpnRandMode::IpHash;

    // "http2_alpn_exp_id": experiment whose presence on a connection sets expEnabled in DoRun().
    ui64 H2AlpnExpId_ = 0;

    // Set when the configured submodule key is "http2".
    bool HaveH2_ = false;
    // When true, DoRunUnencrypted() rejects plain-text requests.
    bool ForceSsl_ = true;
    // Both flags are forwarded unchanged to every TSslItem constructor.
    bool Ja3Enabled_ = false;
    bool ValidateCertDate_ = false;

    // "max_send_fragment"; when defined it is offered to every item via TryUpdateMaxSendFragment().
    TMaybe<long> MaxSendFragment_;

    // Module-wide early data defaults.
    // NOTE(review): initializer order is presumably {Enabled, MaxEarlyData, RecvMaxEarlyData} — confirm against TSslEarlyDataParams.
    TSslEarlyDataParams EarlyDataParams_ = {false, 16384, 16384};
private:
    // IAlpnProtos implementation that forwards the "offer h2?" question to the
    // owning module. It only keeps a reference, so the module must outlive it.
    class TH2ProtoSelector: public IAlpnProtos {
    public:
        // explicit: a TModule must never be converted into a selector implicitly.
        explicit TH2ProtoSelector(TModule& parent)
            : Parent_(parent)
        {}

        // Returns whatever TModule::HasH2() decides for this peer address and
        // experiment flag.
        bool HasH2(const NAddr::IRemoteAddr* addr, bool expEnabled) const noexcept override {
            return Parent_.HasH2(addr, expEnabled);
        }

    private:
        const TModule& Parent_;
    };

    // Config visitor for the "exps" section. Every key of that section is one
    // experiment; parsing it registers the experiment's ssl items in the shared
    // items map under the experiment id.
    class TExperiments : public NConfig::IConfig::IFunc {
    private:
        // Parses a single experiment section ("exp_id", "max_send_fragment",
        // "contexts"). All the work is done in the constructor; the object is only
        // queried for the parsed id afterwards.
        class TExperimentContexts : public NConfig::IConfig::IFunc {
        public:
            // name             - config key of the experiment, used in error messages
            // mp               - module params whose Config is the experiment section;
            //                    only used during construction
            // items            - destination map; receives the parsed items under exp_id
            // ja3Enabled, validateCertDate - forwarded to every created TSslItem
            // Throws TConfigParseError on an invalid or missing exp_id and on an
            // out-of-range max_send_fragment.
            TExperimentContexts(TString name, const TModuleParams& mp, TItemsMap& items, bool ja3Enabled, bool validateCertDate)
                : Name_(std::move(name))
                , ModuleParams_(mp)
                , Items_(items)
                , Ja3Enabled_(ja3Enabled)
                , ValidateCertDate_(validateCertDate)
            {
                ModuleParams_.Config->ForEach(this);

                // The zero check in the "exp_id" handler only fires when the key is
                // present. Without this check an experiment that omits exp_id would be
                // stored under DEFAULT_CONTEXT_ID and clobber the top-level contexts.
                if (ExpId_ == 0) {
                    ythrow TConfigParseError() << "experiment '" << Name_ << "' must specify a non-zero exp_id";
                }

                // It's hard to copy IConfig, rewrite it after parsing
                for (auto& item: ItemsTmpl_) {
                    item.SetItemsSection(ExpId_);
                }

                if (MaxSendFragment_.Defined()) {
                    for (auto &it: ItemsTmpl_) {
                        it.SetMaxSendFragment(*MaxSendFragment_);
                    }
                }

                Items_[ExpId_] = std::move(ItemsTmpl_);
            }

            // Id read from the "exp_id" key; never 0 for a successfully built object.
            ui64 GetExpId() const noexcept {
                return ExpId_;
            }
        public:
            START_PARSE {
                ON_KEY("exp_id", ExpId_) {
                    if (ExpId_ == 0) {
                        ythrow TConfigParseError() << " experiment id could not be 0";
                    }
                    return;
                }

                int maxSendFragment = 0;
                ON_KEY("max_send_fragment", maxSendFragment) {
                    if (maxSendFragment < SSL3_RT_MIN_PLAIN_LENGTH || maxSendFragment > SSL3_RT_MAX_PLAIN_LENGTH) {
                        ythrow TConfigParseError() << " Max send frame should be in range 512 - 16384";
                    }
                    MaxSendFragment_ = maxSendFragment;
                    return;
                }

                if (key == "contexts") {
                    // Items are created with DEFAULT_CONTEXT_ID because exp_id may not
                    // have been seen yet; the constructor re-labels them afterwards.
                    ParseMap(value->AsSubConfig(), [this](const auto& key, auto* val) {
                        ItemsTmpl_.emplace_back(
                            Items_, DEFAULT_CONTEXT_ID, key, ModuleParams_.Copy(val->AsSubConfig()), Ja3Enabled_, ValidateCertDate_);
                    });

                    return;
                }

            } END_PARSE;

        private:
            TString Name_;
            const TModuleParams& ModuleParams_;
            ui64 ExpId_ = 0;
            TItemsMap& Items_;
            TVector<TSslItem> ItemsTmpl_;
            TMaybe<long> MaxSendFragment_;
            bool Ja3Enabled_ = false;
            bool ValidateCertDate_ = false;
        };

    public:
        // mp    - params of the owning module; copied per experiment section
        // items - map that receives the items of every parsed experiment
        TExperiments(const TModuleParams& mp, TItemsMap& items, bool ja3Enabled, bool validateCertDate)
            : ModuleParams_(mp)
            , Items_(items)
            , Ja3Enabled_(ja3Enabled)
            , ValidateCertDate_(validateCertDate)
        {}

    public:
        // Called once per experiment key. Parses the section and rejects a repeated
        // experiment id with TConfigParseError.
        START_PARSE {
            TExperimentContexts expContexts(key, ModuleParams_.Copy(value->AsSubConfig()), Items_, Ja3Enabled_, ValidateCertDate_);

            ui64 expId = expContexts.GetExpId();
            if (expIds_.contains(expId)) {
                ythrow TConfigParseError() << "experiments must have unique ids";
            }

            expIds_.insert(expId);
            return;
        } END_PARSE;

    private:
        const TModuleParams& ModuleParams_;
        TItemsMap& Items_;
        TSet<ui64> expIds_;
        bool Ja3Enabled_ = false;
        bool ValidateCertDate_ = false;
    };

public:
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
        , StatsManager_(mp.Control->SharedStatsManager())
        , EmptyRequests_(StatsManager_.MakeCounter("ssl_sni-empty_requests").AllowDuplicate().Build())
        , HttpsRequests_(StatsManager_.MakeCounter("ssl_sni-https_requests").AllowDuplicate().Build())
        , HttpRequests_(StatsManager_.MakeCounter("ssl_sni-http_requests").AllowDuplicate().Build())
        , DroppedExperiments_(StatsManager_.MakeCounter("ssl_sni-dropped_experiments").AllowDuplicate().Build())
        , ErrorsTotal_(StatsManager_.MakeCounter("ssl_error-total").AllowDuplicate().Build())
        , ZeroErrors_(StatsManager_.MakeCounter("ssl_error-zero_code").AllowDuplicate().Build())
        , ProtocolVersions_{}
    {
        ProtocolVersions_.reserve(GetEnumItemsCount<EProtocolVersion>());
        for (const auto e : GetEnumAllValues<EProtocolVersion>()) {
            ProtocolVersions_[e] = MakeHolder<TSharedCounter>(StatsManager_.MakeCounter("ssl_sni-" + ToString(e) + "_requests").AllowDuplicate().Build());
        }
        for (const auto& errorDescription : TSslErrorStats::Instance().ActiveErrors()) {
            StatsManager_.MakeCounter("ssl_error-" + errorDescription).AllowDuplicate().Build();
        }

        Config->ForEach(this);

        if (Items_.empty()) {
            ythrow TConfigParseError() << "no modules configured";
        }

        if ((H2AlpnRandModeDefault_ == EH2AlpnRandMode::ExpStatic || H2AlpnRandModeFile_)
            && H2AlpnExpId_ == 0)
        {
            ythrow TConfigParseError() << "empty exp_static experiment id";
        }

        if (HaveH2_) {
            H2ProtoSelector_.Reset(MakeHolder<TH2ProtoSelector>(*this));
        } else { // no need to read probability file in abscense of http2 mod itself
            H2AlpnFreqFile_.clear();
        }

        if (EarlyDataParams_.MaxEarlyData == 0 || EarlyDataParams_.RecvMaxEarlyData == 0) {
            EarlyDataParams_.Enabled = false;
        }

        /*
         * TODO: remove events handlers from TModule
         * Right now events are stored in the class. To trigger handler we
         * should iterate through all modules and match event name by regexp.
         * It will be better to store callbacks in Trie outside the class.
         */
        for (auto&[key, items]: Items_) {
            for (auto& ssl_item: items) {
                // In case when early data defined inside context item do not
                // update max data/max recv data parameters
                if (EarlyDataParams_.Enabled && !ssl_item.IsEarlyDataEnabled()) {
                    ssl_item.SetItemEarlyDataEnabled(true);
                    ssl_item.SetItemEarlyDataMaxData(EarlyDataParams_.MaxEarlyData);
                    ssl_item.SetItemEarlyDataRecvMaxData(EarlyDataParams_.RecvMaxEarlyData);
                }
                for(auto& event: ssl_item.Events()) {
                    RegisterEvent(event.Regexp, event.Name, event.Event, &ssl_item);
                }
            }
            DefaultItems_[key] = FindDefaultItem(items);
        }

        RegisterEvent("reload_ticket_keys", "ssl.ReloadTicketKeys", &TModule::ReloadTicketKeys, this);
        RegisterEvent("force_reload_ticket_keys", "ssl.ForceReloadTicketKeys", &TModule::ForceReloadTicketKeys, this);
        RegisterEvent("reload_ocsp_response", "ssl.ReloadOcspResponse", &TModule::ReloadOcspResponse, this);
    }

private:
    // Config handler of the module, invoked once per top-level key.
    // Recognised option keys are consumed and return early; any other key is
    // treated as the name of the submodule to load.
    START_PARSE {
        // Top-level ssl contexts go into the DEFAULT_CONTEXT_ID section.
        if (key == "contexts") {
            ParseMap(value->AsSubConfig(), [this](const auto& key, auto* val) {
                Items_[DEFAULT_CONTEXT_ID].emplace_back(
                    Items_, DEFAULT_CONTEXT_ID, key, Copy(val->AsSubConfig()), Ja3Enabled_, ValidateCertDate_
                );
            });
            return;
        }

        // Extra, user-named event bindings; unknown handler names are ignored.
        if (key == "events") {
            ParseMap(value->AsSubConfig(), [this](const auto& key, auto* value) {
                if (key == "reload_ticket_keys") {
                    RegisterEvent(value->AsString(), "ssl_sni.ReloadTicketKeys", &TModule::ReloadTicketKeys, this);
                } else if (key == "force_reload_ticket_keys") {
                    RegisterEvent(value->AsString(), "ssl_sni.ForceReloadTicketKeys", &TModule::ForceReloadTicketKeys, this);
                } else if (key == "reload_ocsp_response") {
                    RegisterEvent(value->AsString(), "ssl_sni.ReloadOcspResponse", &TModule::ReloadOcspResponse, this);
                }
            });
            return;
        }

        ON_KEY("force_ssl", ForceSsl_) {
            return;
        }

        ON_KEY("http2_alpn_file", H2AlpnFreqFile_) {
            return;
        }

        ON_KEY("http2_alpn_freq", H2AlpnFreqDefault_) {
            return;
        }

        ON_KEY("http2_alpn_rand_mode_file", H2AlpnRandModeFile_) {
            return;
        }

        ON_KEY("http2_alpn_rand_mode", H2AlpnRandModeDefault_) {
            return;
        }

        ON_KEY("http2_alpn_exp_id", H2AlpnExpId_) {
            return;
        }

        // Parsed into a local first so that an out-of-range value never reaches MaxSendFragment_.
        int maxSendFragment = 0;
        ON_KEY("max_send_fragment", maxSendFragment) {
            if (maxSendFragment < SSL3_RT_MIN_PLAIN_LENGTH || maxSendFragment > SSL3_RT_MAX_PLAIN_LENGTH) {
                ythrow TConfigParseError() << " Max send frame should be in range 512 - 16384";
            }
            MaxSendFragment_ = maxSendFragment;
            return;
        }

        // NOTE(review): Ja3Enabled_ / ValidateCertDate_ are read when "contexts" and "exps"
        // are parsed, so their effect presumably depends on the key order in the config — confirm.
        ON_KEY("ja3_enabled", Ja3Enabled_) {
            return;
        }

        ON_KEY("validate_cert_date", ValidateCertDate_) {
            return;
        }

        ON_KEY("earlydata_enabled", EarlyDataParams_.Enabled) {
            return;
        }

        ON_KEY("earlydata_max", EarlyDataParams_.MaxEarlyData) {
            return;
        }

        ON_KEY("earlydata_recv_max", EarlyDataParams_.RecvMaxEarlyData) {
            return;
        }

        // Each key of "exps" is one experiment; TExperiments stores its items into Items_.
        if (key == "exps") {
            auto experiments = MakeHolder<TExperiments>(*this, Items_, Ja3Enabled_, ValidateCertDate_);
            value->AsSubConfig()->ForEach(experiments.Get());
            return;
        }

        // Remember that the submodule is http2 so that ALPN selection gets enabled.
        if (key == "http2") {
            HaveH2_ = true;
        }

        // Everything not handled above is loaded as the submodule.
        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    } END_PARSE

    // Binds one of the three known handler names to the given regexp.
    // Any other handler name is silently ignored.
    void DoConsumeEvent(TString handler, TString regexp) override {
        if (handler == "reload_ticket_keys") {
            RegisterEvent(regexp, "ssl_sni.ReloadTicketKeys", &TModule::ReloadTicketKeys, this);
            return;
        }

        if (handler == "force_reload_ticket_keys") {
            RegisterEvent(regexp, "ssl_sni.ForceReloadTicketKeys", &TModule::ForceReloadTicketKeys, this);
            return;
        }

        if (handler == "reload_ocsp_response") {
            RegisterEvent(regexp, "ssl_sni.ReloadOcspResponse", &TModule::ReloadOcspResponse, this);
        }
    }

    // Walks one section of ssl items: checks that priorities are unique,
    // applies the module-wide h2 and max-send-fragment settings to each item
    // and returns the item flagged as default (the last one if several are).
    // Throws TConfigParseError on a duplicate priority or a missing default.
    TSslItem* FindDefaultItem(TItems &items) {
        TSslItem* chosen = nullptr;
        TSet<int> seenPriorities;

        for (auto& sslItem : items) {
            if (sslItem.IsDefault()) {
                chosen = &sslItem;
            }

            const bool priorityIsNew = seenPriorities.insert(sslItem.Priority()).second;
            if (!priorityIsNew) {
                ythrow TConfigParseError() << " no ssl item should have the same priority";
            }

            if (HaveH2_) {
                sslItem.EnableH2(H2ProtoSelector_.Get());
            }

            if (MaxSendFragment_.Defined()) {
                sslItem.TryUpdateMaxSendFragment(*MaxSendFragment_);
            }
        }

        if (!chosen) {
            ythrow TConfigParseError() << "there should be default module";
        }

        return chosen;
    }

    // Picks the default ssl item for a connection.
    // Returns nullptr when no DEFAULT_CONTEXT_ID section exists. Otherwise
    // returns the default item of the single section whose id is among
    // `experiments`; with no match, or with more than one, falls back to the
    // DEFAULT_CONTEXT_ID item (a multiple match bumps DroppedExperiments).
    TSslItem* SelectDefaultItem(const THashMap<ui64, i64>& experiments, TTls& tls) const {
        const auto fallback = DefaultItems_.find(DEFAULT_CONTEXT_ID);
        if (fallback == DefaultItems_.end()) {
            return nullptr;
        }

        TSslItem* matched = nullptr;
        for (const auto& [expId, expDefaultItem] : DefaultItems_) {
            if (!experiments.contains(expId)) {
                continue;
            }

            if (matched != nullptr) {
                // If there are several experiments to take part in, drop all of them.
                ++tls.Stats.DroppedExperiments;
                matched = nullptr;
                break;
            }

            matched = expDefaultItem;
        }

        return matched != nullptr ? matched : fallback->second;
    }

    // Event handler: asks every ssl item of every section to reload its
    // primary and then its secondary OCSP response.
    void ReloadOcspResponse(TEventData& event) noexcept {
        event.RawOut() << "Triggering ocsp response reloading\n";

        for (auto& [sectionId, sectionItems] : Items_) {
            for (auto& sslItem : sectionItems) {
                sslItem.ReloadPrimaryOcspResponse(event);
                sslItem.ReloadSecondaryOcspResponse(event);
            }
        }
    }

    // Event handler: asks every ssl item of every section to reload its
    // ticket keys, passing `false` as the force flag.
    void ReloadTicketKeys(TEventData& event) noexcept {
        event.RawOut() << "Triggering tls session ticket key reloading\n";

        for (auto& [sectionId, sectionItems] : Items_) {
            for (auto& sslItem : sectionItems) {
                sslItem.ReloadTicketKeys(event.RawOut(), false);
            }
        }
    }

    // Event handler: same walk as ReloadTicketKeys, but passes `true` as the
    // force flag to every item.
    void ForceReloadTicketKeys(TEventData& event) noexcept {
        event.RawOut() << "Triggering tls session ticket key force reloading\n";

        for (auto& [sectionId, sectionItems] : Items_) {
            for (auto& sslItem : sectionItems) {
                sslItem.ReloadTicketKeys(event.RawOut(), true);
            }
        }
    }

    // Entry point for a connection. Peeks at the first received chunk to tell
    // an ssl connection from a plain one, puts the data back and dispatches to
    // DoRunEncrypted() or DoRunUnencrypted().
    //
    // Returns an error for an empty first read, for a missing default ssl
    // context, and for whatever the selected branch reports. TSslError and
    // yexception thrown by the encrypted branch are turned into an error.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        TChunkList lst;

        Y_PROPAGATE_ERROR(descr.Input->Recv(lst, TInstant::Max()));

        if (lst.Empty()) {
            ++tls.Stats.EmptyRequests;
            return Y_MAKE_ERROR(yexception{} << "empty ssl request");
        }

        // A first byte below 0x20 is taken as the mark of an ssl connection.
        const bool encrypted = *(lst.Front()->Data()) < 0x20;

        // Only peeking: hand the data back so the real consumer sees it again.
        descr.Input->UnRecv(std::move(lst));

        if (encrypted) {
            ++tls.Stats.HttpsRequests;

            {
                const auto& experiments = descr.Properties->Parent.Experiments;
                const bool expEnabled = experiments.contains(H2AlpnExpId_);

                // http2 stays allowed unless a cpu limiter exists and reports it as disabled.
                bool cpuLimiterEnabled = true;
                if (descr.CpuLimiter() && descr.CpuLimiter()->CheckHTTP2Disabled()) {
                    cpuLimiterEnabled = false;
                }

                try {
                    TSslItem* item = SelectDefaultItem(experiments, tls);
                    // SelectDefaultItem() yields nullptr when no top-level contexts
                    // are configured; DoRunEncrypted() dereferences the item
                    // unconditionally, so fail the request instead of crashing.
                    if (item == nullptr) {
                        return Y_MAKE_ERROR(yexception{} << "no default ssl context configured");
                    }
                    Y_PROPAGATE_ERROR(DoRunEncrypted(item, descr, expEnabled, cpuLimiterEnabled));
                } Y_TRY_STORE(TSslError, yexception);
            }

        } else {
            Y_PROPAGATE_ERROR(DoRunUnencrypted(descr));
        }
        return {};
    }

    // Performs the ssl handshake on top of descr's input/output using the
    // context of `item`, records the handshake results in descr.Ssl(), and runs
    // the submodule over the encrypted stream.
    //
    // item              - ssl item whose context starts the handshake; dereferenced
    //                     unconditionally, so it must not be null
    // expEnabled        - forwarded to TSslIo
    // cpuLimiterEnabled - forwarded to TSslIo
    //
    // Returns the handshake or submodule error, if any. TSslError statuses are
    // additionally registered in the per-worker stats.
    TError DoRunEncrypted(const TSslItem* item, const TConnDescr& descr, bool expEnabled, bool cpuLimiterEnabled) const {
        TSslIo io(item->Ctx(descr.Process()), *descr.Input, *descr.Output, &descr.RemoteAddr(), expEnabled,
                  cpuLimiterEnabled, item->ItemEarlyDataParams());

        auto& tls = GetTls(&descr.Process());

        TInstant handshakeStarted = TInstant::Now();
        // NOTE(review): code after this Y_TRY/Y_CATCH is expected to run on success, so
        // `return {}` inside the Y_TRY body presumably leaves only that body, not the
        // whole function — confirm against the macro definition.
        Y_TRY(TError, error) {
            Y_PROPAGATE_ERROR(io.Accept());
            ++*tls.Stats.ProtocolVersions[io.GetProtocolVersion()];
            descr.Ssl().HandshakeCompleted = TInstant::Now();
            descr.Ssl().HandshakeDuration = descr.Ssl().HandshakeCompleted - handshakeStarted;
            descr.Ssl().HandshakeUsedTlsTickets = io.UsedTlsTickets();
            descr.Ssl().NextProto = io.AlpnProto();
            descr.Ssl().CurrentCipherId = io.CipherId();
            descr.Ssl().CurrentProtocolId = static_cast<ui16>(io.GetProtocolVersion());
            descr.Ssl().EarlyData = io.EarlyData();
            descr.Ssl().TicketName = io.TicketName();
            descr.Ssl().TicketIV = io.TicketIV();

            return {};
        } Y_CATCH {
            // Handshake failed: count ssl errors by status, log, and give up.
            if (const auto* e = error.GetAs<TSslError>()) {
                tls.Stats.RegisterError(e->Status());
            }
            LOG_ERROR(TLOG_ERR, descr, "ssl handshake failed" << GetErrorMessage(error));
            return error;
        }

        descr.Ssl().ThisConnIsSsl = true;

        // NOTE(review): descr.Ssl() receives pointers to the two locals below
        // (clientCertData, ja3Data); they stop being valid when this function
        // returns — confirm nothing reads them afterwards.
        TSslClientCertData clientCertData;
        // Copy client cert data if it was present in ssl context instead of checking
        // TSslItem HasClientCert because for SNI TSslIo ssl_ctx will be swapped to proper
        // TSslItem context in ServernameCallback.
        if (Y_UNLIKELY(io.ClientCertPresent())) {
            clientCertData.CN = io.ClientCertCN();
            clientCertData.Subject = io.ClientCertSubject();
            clientCertData.VerifyResult = io.ClientCertVerifyResult();
            clientCertData.SerialNumber = io.ClientCertSerialNumber();
            descr.Ssl().ClientCertData = &clientCertData;
        }

        // Collected and published regardless of the Ja3Enabled_ flag of the module.
        TSslJa3Data ja3Data(io.LegacyVersion(), io.Ciphers(), io.ClientExtensions(),
            io.EllipticCurvers(), io.EllipticCurversPointFormats(),
            io.SignatureAlgorithms(), io.SignatureAlgorithmsCert(),
            io.SupportedVersions(), io.ApplicationLayerProtocolNegotiation(),
            io.KeyShare(), io.PskKeyExchangeModes());
        descr.Ssl().Ja3Data = &ja3Data;

        try {
            UpdateStats(descr, io);
        } Y_TRY_STORE(yexception)

        Y_TRY(TError, error) {
            Y_VERIFY(Submodule_);
            // The submodule reads from and writes to the ssl stream instead of the raw connection.
            return Submodule_->Run(descr.Copy(io, io));
        } Y_CATCH {
            if (const auto* e = error.GetAs<TSslError>()) {
                tls.Stats.RegisterError(e->Status());
            }
            return error;
        }

        return {};
    }

    // Handles a connection that does not look like ssl: counts and logs it,
    // fails when force_ssl is on, otherwise passes the raw connection on to
    // the submodule.
    TError DoRunUnencrypted(const TConnDescr& descr) const {
        auto& tls = GetTls(&descr.Process());
        ++tls.Stats.HttpRequests;

        LOG_ERROR(TLOG_ERR, descr, "Unencrypted query");

        Y_REQUIRE(!ForceSsl_, yexception{} << "Unencrypted query");

        Y_VERIFY(Submodule_);
        return Submodule_->Run(descr);
    }

    // Per-worker initialisation: initialises every ssl item for the worker,
    // then builds the worker's TTls with worker-bound counters and the watched
    // http2 ALPN settings (frequency and rand mode).
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        for (auto& [sectionId, sectionItems] : Items_) {
            for (auto& sslItem : sectionItems) {
                sslItem.Init(*process);
            }
        }

        const size_t workerId = process->WorkerId();

        auto workerTls = MakeHolder<TTls>(
            TSharedCounter(EmptyRequests_, workerId),
            TSharedCounter(HttpsRequests_, workerId),
            TSharedCounter(HttpRequests_, workerId),
            TSharedCounter(DroppedExperiments_, workerId),
            TSharedCounter(ErrorsTotal_, workerId),
            TSharedCounter(ZeroErrors_, workerId),
            ProtocolVersions_,
            StatsManager_,
            workerId);

        workerTls->H2AlpnFreq = NSrvKernel::TWatchedState<double>(
            H2AlpnFreqDefault_, H2AlpnFreqFile_, *process->SharedFiles());

        workerTls->H2AlpnRandMode = NSrvKernel::TWatchedState<EH2AlpnRandMode>(
            H2AlpnRandModeDefault_, H2AlpnRandModeFile_, *process->SharedFiles());

        return workerTls;
    }

    // Always reports true: this module declares that it can work without HTTP.
    bool DoCanWorkWithoutHTTP() const noexcept override {
        return true;
    }

    // Logs cipher statistics on the ssl item whose context for this worker is
    // the one `io` ended up with. Does nothing when `io` has no raw context.
    // In every section at most the first matching item is logged.
    void UpdateStats(const TConnDescr& descr, const TSslIo& io) const {
        const auto activeCtx = io.CtxRaw();
        if (activeCtx == nullptr) {
            return;
        }

        for (const auto& [sectionId, sectionItems] : Items_) {
            for (const auto& sslItem : sectionItems) {
                if (sslItem.Ctx(descr.Process()).Ctx() != activeCtx) {
                    continue;
                }

                sslItem.LogCipherStat(descr, io);
                break;
            }
        }
    }

    // Decides whether h2 is offered to this connection.
    //
    // addr       - peer address; may be null, in which case IpHash mode answers false
    // expEnabled - answer used verbatim in ExpStatic mode
    //
    // Always false when no http2 submodule is configured. Otherwise the
    // per-worker H2AlpnRandMode selects the strategy:
    //   Rand      - a random number in [0, 1) is compared with H2AlpnFreq
    //   IpHash    - the address hash is compared with H2AlpnFreq scaled to the ui64 range
    //   ExpStatic - expEnabled
    bool HasH2(const NAddr::IRemoteAddr* addr, bool expEnabled) const noexcept {
        Y_ASSERT(FastTlsSingleton<NSrvKernel::NProcessCore::TThreadInfo>()->WorkerProcess);
        auto& tls = GetTls(FastTlsSingleton<NSrvKernel::NProcessCore::TThreadInfo>()->WorkerProcess);

        if (!HaveH2_) {
            return false;
        }

        switch (tls.H2AlpnRandMode.Get()) {
            case EH2AlpnRandMode::Rand:
                return RandomNumber<double>() < tls.H2AlpnFreq.Get();
            case EH2AlpnRandMode::IpHash:
                return addr && (static_cast<double>(IpHash(*addr)) < (tls.H2AlpnFreq.Get() * MaxCeil<ui64>()));
            case EH2AlpnRandMode::ExpStatic:
                return expEnabled;
        }

        // A mode value outside the enumerators above would otherwise flow off the
        // end of a non-void function, which is undefined behaviour. Do not offer h2.
        return false;
    }
};

// Returns the module handle of the ssl_sni module defined in this file.
IModuleHandle* NModSsl::SniHandle() {
    return TModule::Handle();
}
