#pragma once

#include "response.h"

#include <string>
#include <type_traits>
#include <utility>

#include <ymod_tvm/module.h>
#include <yplatform/find.h>

namespace collectors::web::internal {

template <typename Method>
class check_tvm
{
public:
    check_tvm(const Method& method, settings_ptr settings) : method_(method), settings_(settings)
    {
        if (settings_->tvm_module.size())
        {
            tvm = yplatform::find<ymod_tvm::tvm2_module>(settings_->tvm_module);
        }
    }

    template <typename... Args>
    void operator()(ymod_webserver::http::stream_ptr stream, Args&&... args)
    {
        // Empty tvm pointer means checking is disabled.
        if (tvm && (settings_->protect_localhost || !stream->remote_addr().is_loopback()))
        {
            auto& headers = stream->request()->headers;
            auto& log_data = stream->ctx()->custom_log_data;
            if (auto it = headers.find("x-ya-service-ticket"); it != headers.end())
            {
                if (auto ticket = tvm->get_native_service_ticket(stream->ctx(), it->second))
                {
                    if (!*ticket)
                    {
                        ymod_tvm::error_code err = ticket->GetStatus();
                        YLOG_CTX_GLOBAL(stream->ctx(), info)
                            << "tvm check failed: " << err.message();
                        respond(stream, err, ymod_webserver::codes::unauthorized);
                        return;
                    }
                    auto src = ticket->GetSrc();
                    // Log tvm ssh user uid, if any.
                    if (auto uid = ticket->GetIssuerUid())
                    {
                        log_data["tvm_uid"] = std::to_string(uid);
                    }
                    if (auto it = settings_->allowed_sources.find(src);
                        it != settings_->allowed_sources.end())
                    {
                        // Log service by friendly name if available.
                        log_data["tvm_src"] = it->second;
                    }
                    else
                    {
                        log_data["tvm_src"] = std::to_string(src);
                        respond(
                            stream,
                            code::not_allowed_tvm_source,
                            ymod_webserver::codes::unauthorized);
                        return;
                    }
                }
                else
                {
                    respond(stream, code::tvm_error, ymod_webserver::codes::unauthorized);
                    return;
                }
            }
            else
            {
                respond(stream, code::no_auth_token, ymod_webserver::codes::unauthorized);
                return;
            }
        }
        method_(stream, std::forward<Args>(args)...);
    }

private:
    typename std::decay<Method>::type method_;
    settings_ptr settings_;
    boost::shared_ptr<ymod_tvm::tvm2_module> tvm;
};

}
