#include "jwt.h"

#include "auth.h"

#include <library/cpp/http/cookies/cookies.h>

// Only the module name is forwarded to the base config; the key list is
// populated separately by Init().
TJwtAuthConfig::TJwtAuthConfig(const TString& name)
    : IAuthModuleConfig(name)
{
}

// Creates the runtime JWT auth module from this configuration.
// The server handle is not needed: the module is built from *this alone.
THolder<IAuthModule> TJwtAuthConfig::ConstructAuthModule(const IServerBase* server) const {
    Y_UNUSED(server);
    THolder<IAuthModule> authModule = MakeHolder<TJwtAuthModule>(*this);
    return authModule;
}

// Reads every <jwk> child section of `section` and appends one key
// description per section to Keys.
// Each directive lookup passes the field's current (default-constructed) value
// as the fallback used when the directive is absent.
// Throws (via Yensured) if `section` or a child section pointer is null.
void TJwtAuthConfig::Init(const TYandexConfig::Section* section) {
    const auto& children = Yensured(section)->GetAllChildren();
    const auto jwkRange = children.equal_range("jwk");
    for (auto it = jwkRange.first; it != jwkRange.second; ++it) {
        const auto& directives = Yensured(it->second)->GetDirectives();

        NDrive::TJwkDescription description;
        description.Algorithm = directives.Value("Algorithm", description.Algorithm);
        description.KeyId = directives.Value("KeyId", description.KeyId);
        description.PEM = directives.Value("PEM", description.PEM);

        Keys.push_back(std::move(description));
    }
}

// Serializes the configured keys as <jwk> sections with the same directive
// names (Algorithm, KeyId, PEM) that Init() reads.
void TJwtAuthConfig::ToString(IOutputStream& os) const {
    for (const auto& key : Keys) {
        os << "<jwk>" << Endl
           << "Algorithm: " << key.Algorithm << Endl
           << "KeyId: " << key.KeyId << Endl
           << "PEM: " << key.PEM << Endl
           << "</jwk>" << Endl;
    }
}

// Builds one runtime key entry per configured key description, preserving
// the configuration order.
TJwtAuthModule::TJwtAuthModule(const TJwtAuthConfig& config) {
    const auto& descriptions = config.GetKeys();
    for (const auto& description : descriptions) {
        Keys.emplace_back(description);
    }
}

// Extracts a JWT from the request and turns it into auth info.
//
// Token sources, in order:
//   1. "Authorization: Bearer <token>" header. A non-empty header without the
//      "Bearer " prefix is rejected.
//   2. The "vidtoken" cookie, consulted only when step 1 yielded no token.
// If Keys is non-empty, the token is validated against them before it is
// parsed as a Cognito ID token.
//
// Returns a TJwtAuthInfo carrying an error message when there is no request
// context, the Authorization header is malformed, or no token was supplied.
// Decoding, validation and parsing failures propagate from jwt::decode,
// NDrive::ValidateToken and NDrive::ParseCognitoIdToken.
IAuthInfo::TPtr TJwtAuthModule::RestoreAuthInfo(IReplyContext::TPtr requestContext) const {
    if (!requestContext) {
        return MakeAtomicShared<TJwtAuthInfo>("null RequestContext");
    }

    const TServerRequestData& rd = requestContext->GetRequestData();

    // Owning copy of the raw token. The cookie value comes from a THttpCookies
    // object that is destroyed at the end of its scope, so a non-owning view
    // into it must not be kept until jwt::decode runs.
    std::string auth;

    auto authorizationHeader = rd.HeaderInOrEmpty("Authorization");
    if (authorizationHeader) {
        constexpr auto bearerPrefix = "Bearer "sv;
        if (!authorizationHeader.SkipPrefix(bearerPrefix)) {
            return MakeAtomicShared<TJwtAuthInfo>("incorrect Authorization header");
        }
        auth.assign(authorizationHeader.data(), authorizationHeader.size());
    }

    // Fall back to the cookie when the header supplied nothing
    // (absent header, or "Bearer " followed by an empty token).
    auto cookieHeader = rd.HeaderInOrEmpty("Cookie");
    if (auth.empty() && cookieHeader) {
        THttpCookies cookies(cookieHeader);
        if (const auto cookieToken = cookies.Get("vidtoken"); !cookieToken.empty()) {
            auth.assign(cookieToken.data(), cookieToken.size());
        }
    }

    // Report a missing token the same way as the other request-level failures
    // above, instead of letting jwt::decode throw on an empty string.
    if (auth.empty()) {
        return MakeAtomicShared<TJwtAuthInfo>("missing token");
    }

    auto token = jwt::decode(auth);
    // NOTE(review): with no configured keys the signature is NOT verified, so any
    // well-formed token is accepted. Confirm this is intended only for non-production setups.
    if (!Keys.empty()) {
        NDrive::ValidateToken(token, Keys);
    }

    auto parsed = NDrive::ParseCognitoIdToken(token);
    return MakeAtomicShared<TJwtAuthInfo>(std::move(parsed));
}

// Static registration of TJwtAuthConfig in IAuthModuleConfig::TFactory under the
// key "jwt"; the registrator object is constructed during static initialization.
IAuthModuleConfig::TFactory::TRegistrator<TJwtAuthConfig> TJwtAuthConfig::Registrator("jwt");
