#include "oauth.h"

#include <passport/infra/daemons/blackbox/src/blackbox_impl.h>
#include <passport/infra/daemons/blackbox/src/grants/consumer.h>
#include <passport/infra/daemons/blackbox/src/grants/grants_checker.h>
#include <passport/infra/daemons/blackbox/src/helpers/base_result_helper.h>
#include <passport/infra/daemons/blackbox/src/helpers/oauth_attrs_helper.h>
#include <passport/infra/daemons/blackbox/src/helpers/partitions_helper.h>
#include <passport/infra/daemons/blackbox/src/helpers/strong_pwd_helper.h>
#include <passport/infra/daemons/blackbox/src/helpers/uid_helper.h>
#include <passport/infra/daemons/blackbox/src/loggers/authlog.h>
#include <passport/infra/daemons/blackbox/src/misc/attributes.h>
#include <passport/infra/daemons/blackbox/src/misc/db_fetcher.h>
#include <passport/infra/daemons/blackbox/src/misc/db_types.h>
#include <passport/infra/daemons/blackbox/src/misc/dbfields_converter.h>
#include <passport/infra/daemons/blackbox/src/misc/exception.h>
#include <passport/infra/daemons/blackbox/src/misc/experiment.h>
#include <passport/infra/daemons/blackbox/src/misc/session_utils.h>
#include <passport/infra/daemons/blackbox/src/misc/strings.h>
#include <passport/infra/daemons/blackbox/src/misc/utils.h>
#include <passport/infra/daemons/blackbox/src/oauth/config.h>
#include <passport/infra/daemons/blackbox/src/oauth/error.h>
#include <passport/infra/daemons/blackbox/src/oauth/fetcher.h>
#include <passport/infra/daemons/blackbox/src/oauth/status.h>
#include <passport/infra/daemons/blackbox/src/oauth/token_embedded_info.h>
#include <passport/infra/daemons/blackbox/src/oauth/token_info.h>
#include <passport/infra/daemons/blackbox/src/output/attributes_chunk.h>
#include <passport/infra/daemons/blackbox/src/output/oauth_result.h>
#include <passport/infra/daemons/blackbox/src/output/typed_value_result.h>
#include <passport/infra/daemons/blackbox/src/output/uid_chunk.h>
#include <passport/infra/daemons/blackbox/src/staff/staff_info.h>

#include <passport/infra/libs/cpp/auth_core/oauth_token.h>
#include <passport/infra/libs/cpp/auth_core/oauth_token_parser.h>
#include <passport/infra/libs/cpp/request/request.h>
#include <passport/infra/libs/cpp/tvm/common/private_key.h>
#include <passport/infra/libs/cpp/tvm/signer/signer.h>
#include <passport/infra/libs/cpp/utils/ipaddr.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>
#include <passport/infra/libs/cpp/xml/xml_utils.h>

namespace NPassport::NBb {
    // HTTP header consulted by ProcessImpl when no token is passed as a request argument.
    static const TString AUTHORIZATION("Authorization");

    // Binds the processor to the blackbox instance (settings, fetchers, loggers)
    // and to the request whose arguments/headers all later calls read.
    TOAuthProcessor::TOAuthProcessor(const TBlackboxImpl& impl, const NCommon::TRequest& request)
        : Blackbox_(impl), Request_(request)
    {
    }

    // Runs every grant check relevant to method=oauth for `consumer`:
    // the method itself, the ALLOW_FEDERAL argument, partitions, base-result
    // db fields/attributes (at HasCred rank) and OAuth token/client attributes.
    // `throwOnError` is forwarded to TGrantsChecker; presumably it makes the first
    // failed check throw instead of only being recorded (TODO confirm in grants_checker.h).
    // The checker is returned so callers can inspect its state.
    TGrantsChecker TOAuthProcessor::CheckGrants(const TConsumer& consumer, bool throwOnError) {
        TGrantsChecker grants(Request_, consumer, throwOnError);

        // method-level grants
        grants.CheckMethodAllowed(TBlackboxMethods::OAuth);
        grants.CheckHasArgAllowed(TStrings::ALLOW_FEDERAL, TBlackboxFlags::FederalAccounts);

        // grants shared with other user-data methods
        TPartitionsHelper::CheckGrants(Blackbox_.PartitionsSettings(), grants);
        TBaseResultHelper::CheckGrants(Blackbox_.DbFieldSettings(), Blackbox_.AttributeSettings(), TConsumer::ERank::HasCred, grants);

        // token/client attribute grants
        TOAuthAttrsHelper::CheckGrants(grants);

        return grants;
    }

    // Public entry point of method=oauth.
    // Validates grants and request preconditions, then delegates to ProcessImpl
    // with every optional part of the response (base result, OAuth attributes) enabled.
    std::unique_ptr<TOAuthResult> TOAuthProcessor::Process(const TConsumer& consumer) {
        // preconditions: any of these may throw and abort the request
        CheckGrants(consumer);
        TUtils::CheckUserTicketAllowed(Request_);
        TUtils::CheckUserIpArg(Request_);

        // arguments are read in declaration order of TOptions
        const TOptions requestOptions{
            .UserIp = TUtils::GetUserIpArg(Request_),
            .UserPort = TUtils::GetUserPortArg(Request_),
            .ScopesRequested = Request_.GetArg(TStrings::SCOPES),
            .GetUserTicket = TUtils::GetBoolArg(Request_, TStrings::GET_USER_TICKET),
            .GetLoginId = TUtils::GetBoolArg(Request_, TStrings::GET_LOGIN_ID),
            .FillBaseResult = true,
            .FillOAuthAttrs = true,
        };

        return ProcessImpl(consumer, requestOptions);
    }

    /**
     * Validates an OAuth token and builds the method=oauth response.
     *
     * The token is taken from the TStrings::OAUTH_TOKEN request argument or, when that
     * is empty, from the Authorization header in the form "<scheme> <token>" (the scheme
     * itself is not checked here). Token info and account data are fetched (as much as
     * possible in parallel), then the account state is checked: availability, password
     * expiration, global logout, token revocation, domain logout, forced password
     * change/creation, federal restrictions and external-network robots.
     *
     * @param consumer  requesting consumer: used for the user ticket entry point,
     *                  the auth.log comment and debug logs.
     * @param options   user ip/port, requested scopes and which optional parts of the
     *                  response to fill (base result, OAuth attributes, ticket, login_id).
     * @return a result with OauthChunk.Status == Valid on success; otherwise the
     *         result built by GetErrorOAuth() describing why the token was rejected.
     * @throws TBlackboxError(InvalidParams) if no token is supplied (no argument, and the
     *         Authorization header is missing or has nothing after the scheme).
     *         Fetcher errors (e.g. db pool errors) may also propagate.
     */
    std::unique_ptr<TOAuthResult> TOAuthProcessor::ProcessImpl(const TConsumer& consumer,
                                                               const TOAuthProcessor::TOptions& options) {
        UserIp_ = options.UserIp;

        TString token = Request_.GetArg(TStrings::OAUTH_TOKEN);
        if (token.empty()) {
            // parse token from header: "<scheme> <token>", any number of spaces in between
            const TString hdr = Request_.GetHeader(AUTHORIZATION);
            const size_t delim = hdr.find(' ');
            // npos when there is no space at all (incl. empty header) or when only
            // spaces follow the scheme: in both cases the header carries no token
            const size_t tokenStart = (delim == TString::npos)
                                          ? TString::npos
                                          : hdr.find_first_not_of(' ', delim);
            if (tokenStart == TString::npos) {
                throw TBlackboxError(TBlackboxError::EType::InvalidParams) << "Missing or empty oauth token";
            }
            token.assign(hdr, tokenStart, hdr.size());
        }

        // 'token' can be one of the 3 different types
        //
        // 1. old legacy token - [a-f0-9]+
        //    we have no info by token and need to search in oauth shard and then in oauth central
        //
        // 2. token with embedded info
        //        - base64url(uid, clid, shard, random), [a-zA-Z0-9_-]{39}
        //        - y(env)_base64url(shard, uid, clientid, token_id, random, crc32), y[0-4]_[a-zA-Z0-9_-]{55}
        //    we know uid, client_id and shard and can launch nonblocking queries in parallel
        //
        // 3. stateless token - ver.uid.client_id.expires.token_id.key_id.iv.data.tag, [a-zA-Z0-9_-.]+
        //    all info is parsed from token body but we may need to check xtoken data from db
        //
        // We try to guess type by token format and do as much async querying as possible

        TOAuthError error;
        TString uid;
        std::unique_ptr<NAuth::TOAuthToken> pStatelessToken;
        std::unique_ptr<TOAuthTokenEmbeddedInfo> embeddedInfo;

        TOAuthSingleFetcher oauthFetcher(Blackbox_.OauthConfig(), UserIp_, Request_.GetConsumerFormattedName(), Request_.GetRemoteAddr());
        std::optional<TOAuthAttrsHelper> attrsHelper;
        if (options.FillOAuthAttrs) {
            attrsHelper.emplace(oauthFetcher, Request_);
        }

        // Check if it is type 3 (stateless token)
        if (token.size() > 70 && token.StartsWith("1.")) {
            pStatelessToken = std::make_unique<NAuth::TOAuthToken>(
                Blackbox_.OauthParser().ParseToken(token));
            TExperiment::Get().RunTokenCheck();

            if (pStatelessToken->Status() != NAuth::TOAuthToken::VALID) {
                TLog::Debug("Error checking stateless token: status %d, message %s",
                            pStatelessToken->Status(),
                            pStatelessToken->ErrMsg().c_str());
                error.SetMsg("Error parsing token: " + pStatelessToken->ErrMsg(), pStatelessToken->Status());
                return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
            }
            uid = pStatelessToken->Uid();

            if (pStatelessToken->Scopes().empty()) { // take scopes from client
                oauthFetcher.AddClientAttr(TOAuthClientAttr::SCOPE_IDS);
            }

            // may throw DbpoolError and show exception in result
            // if no xtoken_id in token - no shard request will be done, only to central for client attrs

            if (!oauthFetcher.SendNonBlockingRequestByAccess(pStatelessToken->XtokenId(), pStatelessToken->XtokenShard(), pStatelessToken->ClientId(), true)) {
                error.SetError(TOAuthError::TokenExpired);
                return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
            }
        } else {
            // Check if it is type 2 (with embedded info)
            embeddedInfo = std::make_unique<TOAuthTokenEmbeddedInfo>(TOAuthTokenEmbeddedInfo::Parse(token));
            if (!embeddedInfo->IsOk()) {
                TLog::Debug() << "Error checking token with embedded info: " << embeddedInfo->ErrMsg();
                error.SetMsg("Error parsing token: malformed embedded info", TOAuthError::EError::OAuthReject);
                return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
            }

            if (embeddedInfo->HasInfo()) {
                if (const auto& env = embeddedInfo->Environment(); env.has_value() && *env != Blackbox_.OAuthEnv()) {
                    error.SetMsg(
                        NUtils::CreateStr(
                            "Token was got in wrong environment: current environment is ",
                            Blackbox_.OAuthEnv(),
                            ", while token came from ",
                            ToString(*env)),
                        TOAuthError::EError::WrongEnvironment);
                    return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
                }

                if (embeddedInfo->Uid() > 0) {
                    uid = IntToString<10>(embeddedInfo->Uid());
                }

                // for v2 tokens we know shard and client_id, let's start non-blocking request
                if (!oauthFetcher.SendNonBlockingRequestByAccess(token, embeddedInfo->Shard(), IntToString<10>(embeddedInfo->ClientId()), false)) {
                    error.SetError(TOAuthError::TokenExpired);
                    return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
                }
            }
        }

        TPartitionsHelper::ParsePartitionArg(
            Blackbox_.PartitionsSettings(),
            Request_,
            TPartitionsHelper::TSettings{
                .Method = "oauth",
                .ForbidNonDefault = true,
            });

        TDbFetcher fetcher = Blackbox_.CreateDbFetcher();
        TDbFieldsConverter conv(fetcher, Blackbox_.Hosts(), Blackbox_.MailHostId());

        std::optional<TBaseResultHelper> baseResult;
        if (options.FillBaseResult) {
            baseResult.emplace(conv, Blackbox_, Request_);
        } else {
            // just add default aliases to fetcher: it allows to find account
            TAccountHelper(conv, Request_);
        }
        TStrongPwdHelper strongPwd(fetcher);

        const TDbIndex availableIdx = fetcher.AddAttr(TAttr::ACCOUNT_IS_AVAILABLE);
        const TDbIndex glogoutIdx = fetcher.AddAttr(TAttr::ACCOUNT_GLOBAL_LOGOUT_DATETIME);
        const TDbIndex revokerTokensIdx = fetcher.AddAttr(TAttr::REVOKER_TOKENS);
        const TDbIndex changeReasonIdx = fetcher.AddAttr(TAttr::PASSWORD_FORCED_CHANGING_REASON);
        const TDbIndex createRequiredIdx = fetcher.AddAttr(TAttr::PASSWORD_CREATING_REQUIRED);
        const TDbIndex federalIdx = fetcher.AddAlias(TAlias::FEDERAL);

        // if token contains non-zero uid, let's fetch data from central and shard
        if (!uid.empty()) {
            fetcher.FetchByUid(uid);
        }

        // here we wait for non-blocking request to complete
        std::unique_ptr<TOAuthTokenInfo> tokenInfo;

        if (pStatelessToken) {
            // get xtoken result and fill tokenInfo from pStatelessToken
            tokenInfo = oauthFetcher.CheckTokenByAccess(*pStatelessToken, token, error);
        } else {
            tokenInfo = oauthFetcher.CheckTokenByAccess(*embeddedInfo, token, error);
        }

        if (!tokenInfo) {
            return GetErrorOAuth(error, nullptr, TStrings::EMPTY);
        }
        tokenInfo->ShowIsXTokenTrusted = TUtils::GetBoolArg(Request_, TStrings::GET_IS_XTOKEN_TRUSTED);

        // check token scopes
        if (!options.ScopesRequested.empty() && !tokenInfo->HasScopes(options.ScopesRequested, &error)) {
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::SCOPES_WRONG);
        }

        TString connectionId("t:");
        connectionId.append(tokenInfo->TokenId);

        if (tokenInfo->Uid.empty()) { // valid token without uid - reply without user data
            Blackbox_.OauthConfig().LogOAuth(*tokenInfo, UserIp_, TStrings::OK, TStrings::EMPTY, Request_.GetConsumerFormattedName(), Request_.GetRemoteAddr());

            std::unique_ptr<TAttributesChunk> tokenAttrsChunk;
            std::unique_ptr<TAttributesChunk> clientAttrsChunk;
            if (attrsHelper) {
                tokenAttrsChunk = attrsHelper->TokenAttrsChunk(tokenInfo.get());
                clientAttrsChunk = attrsHelper->ClientAttrsChunk(tokenInfo.get());
            }

            std::unique_ptr<TOAuthResult> oauthResult = std::make_unique<TOAuthResult>();
            oauthResult->OauthChunk.TokenInfo = std::move(tokenInfo);

            oauthResult->OauthChunk.Status = TOAuthStatus::Valid;
            oauthResult->OauthChunk.Comment = "OK";
            oauthResult->ConnectionId = connectionId;
            // No login_id for token with no uid, sorry

            oauthResult->OauthChunk.TokenAttrs = std::move(tokenAttrsChunk);
            oauthResult->OauthChunk.ClientAttrs = std::move(clientAttrsChunk);

            // Do not make ticket without uid

            return oauthResult;
        }

        const ui64 uidInt = TUtils::ToUInt(tokenInfo->Uid, TStrings::UID);

        // if we didn't fetch user data before, let's do it now
        if (uid.empty()) {
            fetcher.FetchByUid(tokenInfo->Uid);
        }

        const TDbProfile* profile = fetcher.NextProfile();

        if (nullptr == profile) {
            error.SetError(TOAuthError::AccountNotFound);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_NOT_FOUND);
        }

        if (!profile->Get(availableIdx)->AsBoolean()) {
            error.SetError(TOAuthError::AccountDisabled);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_DISABLED);
        }

        if (strongPwd.PasswdExpired(profile, Blackbox_.StrongPwdExpireTime())) {
            error.SetError(TOAuthError::ExpiredPassword);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_PWD_EXPIRED);
        }

        // a token issued before any of the logout/revoke moments below is no longer valid
        const time_t issuedTime = tokenInfo->GetMinimumIssueTime();

        time_t glogoutTime = profile->Get(glogoutIdx)->AsTime();
        if (issuedTime < glogoutTime) {
            error.SetError(TOAuthError::GLogout);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_GLOGOUT);
        }

        time_t revokeTokensTime = profile->Get(revokerTokensIdx)->AsTime();
        if (issuedTime < revokeTokensTime) {
            error.SetError(TOAuthError::Revoked);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_REVOKED_TOKENS);
        }

        const TInstant domainGlogout = profile->PddDomItem().Glogout();
        if (issuedTime < domainGlogout.TimeT()) {
            error.SetError(TOAuthError::GLogout);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_GLOGOUT);
        }

        // forced password change: statbox message must match the error kind
        if (profile->Get(changeReasonIdx)->AsBoolean()) {
            error.SetError(TOAuthError::PasswordChangeRequired);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_PWD_CHANGE);
        }

        // password creation required: statbox message must match the error kind
        if (profile->Get(createRequiredIdx)->AsBoolean()) {
            error.SetError(TOAuthError::PasswordCreateRequired);
            return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_PWD_CREATE);
        }

        if (!profile->Get(federalIdx)->Value.empty() && !TUtils::GetBoolArg(Request_, TStrings::ALLOW_FEDERAL)) {
            if (TExperiment::Get().RestrictFederalUsers) {
                error.SetError(TOAuthError::FederalAccount);
                return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_FEDERAL);
            }
            TLog::Debug() << "Federal account not allowed in method=oauth. "
                          << "Uid=" << profile->Uid() << ", "
                          << "consumer=" << consumer.GetName();
        }

        if (IsRobotFromExternalNetwork(uidInt)) {
            if (!Blackbox_.IsExternalRobotAllowed(uidInt)) {
                error.SetError(TOAuthError::ExternalRobot);
                return GetErrorOAuth(error, tokenInfo.get(), TOAuthStatboxMsg::USER_ROBOT);
            }

            LogRobotCheck(*tokenInfo, consumer);
        }

        if (Blackbox_.AuthLogger()) {
            Blackbox_.AuthLogger()->Write(
                tokenInfo->Uid,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::OAUTHCHECK,
                TOAuthStatus::GetAuthLogStatus(TOAuthStatus::Valid),
                tokenInfo->AuthLogComment(options.UserPort, consumer.GetClientId()),
                TStrings::EMPTY,
                false,
                UserIp_,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY);
        }

        Blackbox_.OauthConfig().LogOAuth(*tokenInfo, UserIp_, TStrings::OK, TStrings::EMPTY, Request_.GetConsumerFormattedName(), Request_.GetRemoteAddr());

        std::unique_ptr<TAttributesChunk> tokenAttrsChunk;
        std::unique_ptr<TAttributesChunk> clientAttrsChunk;
        if (attrsHelper) {
            tokenAttrsChunk = attrsHelper->TokenAttrsChunk(tokenInfo.get());
            clientAttrsChunk = attrsHelper->ClientAttrsChunk(tokenInfo.get());
        }

        std::unique_ptr<TOAuthResult> oauthResult = std::make_unique<TOAuthResult>();

        // the ticket must be built before tokenInfo is moved into the result
        if (options.GetUserTicket) {
            oauthResult->UserTicket = BuildUserTicket(
                TUtils::ToUInt(profile->Uid(), TStrings::UID),
                consumer,
                tokenInfo->GetScopeCollection());
        }
        oauthResult->OauthChunk.TokenInfo = std::move(tokenInfo);

        oauthResult->OauthChunk.Status = TOAuthStatus::Valid;
        oauthResult->OauthChunk.Comment = "OK";
        oauthResult->ConnectionId = connectionId;
        if (options.GetLoginId) {
            oauthResult->OauthChunk.LoginId = oauthResult->OauthChunk.TokenInfo->GetLoginId();
        }

        oauthResult->OauthChunk.TokenAttrs = std::move(tokenAttrsChunk);
        oauthResult->OauthChunk.ClientAttrs = std::move(clientAttrsChunk);

        oauthResult->Uid = TUidHelper::Result(profile, false, profile->PddDomItem());
        if (baseResult) {
            baseResult->FillResults(*oauthResult, profile);
        }

        return oauthResult;
    }

    std::unique_ptr<TOAuthResult> TOAuthProcessor::GetErrorOAuth(const TOAuthError& error,
                                                                 const TOAuthTokenInfo* tokenInfo,
                                                                 const TString& statboxMsg) {
        if (tokenInfo) {
            TLog::Debug("Error checking OAuth token: %s. uid=%s tokid=%s",
                        error.Msg().c_str(),
                        tokenInfo->Uid.c_str(),
                        tokenInfo->TokenId.c_str());
            Blackbox_.OauthConfig().LogOAuth(*tokenInfo, UserIp_, TOAuthStatboxMsg::ERROR, statboxMsg, Request_.GetConsumerFormattedName(), Request_.GetRemoteAddr());
        }

        const TOAuthStatus status = error.ConvertToStatus();

        std::unique_ptr<TOAuthResult> oauthResult = std::make_unique<TOAuthResult>();
        oauthResult->OauthChunk.Status = status;
        oauthResult->OauthChunk.Comment = error.Msg();

        // write to auth.log all tokens with uid
        if (Blackbox_.AuthLogger() && tokenInfo && !tokenInfo->Uid.empty()) {
            Blackbox_.AuthLogger()->Write(
                tokenInfo->Uid,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::OAUTHCHECK,
                status.GetAuthLogStatus(),
                error.AuthLogErrorComment(tokenInfo->TokenId),
                TStrings::EMPTY,
                false,
                UserIp_,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY,
                TStrings::EMPTY);
        }

        return oauthResult;
    }

    // True iff `uid` is a robot according to staff info and the current user ip
    // (UserIp_) is classified as an external network.
    // Returns false without staff info, and for loopback user ips (which are legal
    // for background activity).
    bool TOAuthProcessor::IsRobotFromExternalNetwork(ui64 uid) const {
        // short-circuit order matters: staff info must exist before IsRobot is asked,
        // and the network lookup is done only for robots
        const bool isRobotOutsideLoopback = Blackbox_.StaffInfo()
                                            && !NUtils::TIpAddr(UserIp_).IsLoopback()
                                            && Blackbox_.StaffInfo()->IsRobot(uid);
        if (!isRobotOutsideLoopback) {
            return false;
        }

        return Blackbox_.CheckYandexIp(UserIp_, Request_, ToString(uid)) == ENetworkKind::External;
    }

    // Writes a debug trace for a robot token that passed the external-network check:
    // robot uid, user ip, token id, consumer name and consumer ip.
    void TOAuthProcessor::LogRobotCheck(const TOAuthTokenInfo& tokenInfo, const TConsumer& consumer) const {
        const auto& robotUid = tokenInfo.Uid;
        const auto& tokenId = tokenInfo.TokenId;
        const auto& consumerIp = Request_.GetRemoteAddr();

        TLog::Debug() << "OAuthProcessor: token of robot is valid: uid=" << robotUid
                      << "; userip=" << UserIp_
                      << "; tokenid=" << tokenId
                      << "; consumer=" << consumer.GetName()
                      << "; consumer_ip=" << consumerIp << ";";
    }

    // Returns a user ticket for `uid` with `scopes`, entry point = consumer's client id.
    // A ticket already present in the OAuth user-ticket cache (if that cache is
    // configured) is returned as is; otherwise a new V3 ticket is signed with the
    // current TVM private key, expiring at now + UserTicketTtl + UserTicketCacheTtl,
    // and stored in the cache.
    TString TOAuthProcessor::BuildUserTicket(ui64 uid,
                                             const TConsumer& consumer,
                                             const std::unordered_set<TString>& scopes) const {
        NTicketSigner::TUserSigner signer;

        // single-user ticket: the only uid is also the default one
        signer.AddUid(uid);
        signer.SetDefaultUid(uid);
        signer.SetEntryPoint(consumer.GetClientId());
        signer.SetEnv(Blackbox_.UserTicketEnv());
        for (const auto& scope : scopes) {
            signer.AddScope(scope);
        }

        // stays default-constructed when no cache is configured
        TUserTicketCache::TCacheHolder ticketCache;
        if (Blackbox_.UserTicketCacheForOAuth()) {
            ticketCache = Blackbox_.UserTicketCacheForOAuth()->GetCache();
        }

        NCache::EStatus cacheStatus;
        if (TString cached = ticketCache.GetValue(signer, cacheStatus)) {
            return cached;
        }

        // lifetime covers the cache ttl so a cached ticket still has UserTicketTtl left
        const time_t expiresAt = time(nullptr) + Blackbox_.UserTicketTtl() + Blackbox_.UserTicketCacheTtl();
        TString ticket = signer.SerializeV3(*Blackbox_.TvmPrivateKeys().GetKey(), expiresAt);
        ticketCache.PutValue(std::move(signer), TString(ticket));

        return ticket;
    }
}
