#pragma once

#include <yxiva/core/types.h>
#include <ymod_tvm/module.h>
#include <ymod_tvm/settings.h>
#include <ymod_webserver/response.h>
#include <yplatform/find.h>
#include <yplatform/ptree.h>
#include <unordered_map>

namespace yxiva {
/// Configuration of the tvm request checker (used by check_tvm).
struct tvm_checker_settings
{
    /// Reads settings from `conf`.
    ///
    /// Keys read:
    ///  - "module": name of the tvm2 module to look up; current value is kept
    ///    when the key is absent. An empty name disables checking in check_tvm.
    ///  - "allowed_sources" (may repeat): either a plain string alias, resolved
    ///    to a tvm id through ymod_tvm::tvm2::settings::read_mappings(conf), or
    ///    a subtree with "id" and "name" children.
    ///  - "protect_localhost": current value is kept when the key is absent.
    ///
    /// Entries are added to `allowed_sources`; existing entries are not cleared.
    /// An "allowed_sources" entry that is neither a known alias nor has
    /// "id"/"name" children makes the get<>() calls fail.
    /// NOTE(review): presumably this throws like boost ptree's get(path); confirm.
    void parse_ptree(const yplatform::ptree& conf)
    {
        module = conf.get("module", module);
        auto mappings = ymod_tvm::tvm2::settings::read_mappings(conf);
        auto range = conf.equal_range("allowed_sources");
        for (auto it = range.first; it != range.second; ++it)
        {
            const auto& source = it->second;
            auto alias = source.get_value(string{});
            // Single lookup instead of count() followed by operator[].
            if (auto mapping = mappings.find(alias); mapping != mappings.end())
            {
                allowed_sources[mapping->second.id] = alias;
            }
            else
            {
                allowed_sources[source.get<uint32_t>("id")] =
                    source.get<std::string>("name");
            }
        }
        protect_localhost = conf.get("protect_localhost", protect_localhost);
    }

    // Name of the tvm2 module; empty means tvm checking is disabled.
    string module;
    // Allowed sources friendly names by tvm id.
    std::unordered_map<uint32_t, std::string> allowed_sources;
    // When false, requests from loopback addresses bypass the tvm check.
    bool protect_localhost = false;
};

template <typename Handler>
struct check_tvm
{
    check_tvm(const tvm_checker_settings& s, Handler&& h)
        : settings(s), handler(std::forward<Handler>(h))
    {
        if (settings.module.size())
        {
            tvm = yplatform::find<ymod_tvm::tvm2_module, std::shared_ptr>(settings.module);
        }
    }

    template <typename Stream, typename... Args>
    void operator()(Stream stream, Args&&... args)
    {
        // Empty tvm pointer means checking is disabled.
        if (tvm && (settings.protect_localhost || !stream->remote_addr().is_loopback()))
        {
            auto& headers = stream->request()->headers;
            auto& log_data = stream->ctx()->custom_log_data;
            if (auto it = headers.find("x-ya-service-ticket"); it != headers.end())
            {
                if (auto ticket = tvm->get_native_service_ticket(stream->ctx(), it->second))
                {
                    if (!*ticket)
                    {
                        ymod_tvm::error_code err = ticket->GetStatus();
                        YLOG_CTX_GLOBAL(stream->ctx(), info)
                            << "tvm check failed: " << err.message();
                        WEB_RESPONSE_LOG_G(info, stream, forbidden, "invalid tvm ticket");
                        return;
                    }
                    auto src = ticket->GetSrc();
                    // Log tvm ssh user uid, if any.
                    if (auto uid = ticket->GetIssuerUid())
                    {
                        log_data["tvm_uid"] = std::to_string(uid);
                    }
                    if (auto it = settings.allowed_sources.find(src);
                        it != settings.allowed_sources.end())
                    {
                        // Log service by friendly name if available.
                        log_data["tvm_src"] = it->second;
                    }
                    else
                    {
                        log_data["tvm_src"] = std::to_string(src);
                        WEB_RESPONSE_LOG_G(info, stream, forbidden, "tvm src not allowed");
                        return;
                    }
                }
                else
                {
                    // Actual error is logged by tvm module.
                    return;
                }
            }
            else
            {
                WEB_RESPONSE_LOG_G(info, stream, forbidden, "missing tvm header");
                return;
            }
        }
        handler(stream, std::forward<Args>(args)...);
    }

    tvm_checker_settings settings;
    std::shared_ptr<ymod_tvm::tvm2_module> tvm;
    std::decay_t<Handler> handler;
};

}
