#include "authenticator.h"

#include <library/cpp/tvmauth/client/exception.h>

#include <library/cpp/http/misc/http_headers.h>

#include <util/string/split.h>

namespace NSolomon::NAuth {

namespace {

// For every supported token type: the scheme name expected in the Authorization header
// (OAuth, IAM — parsed by ParseAuthorizationHeader) or the name of the dedicated header
// carrying the ticket (TVM — read by ExtractTokenFromHeaders).
// The table is accessed with .at(), so a type without an entry (e.g. EAuthType::Unknown)
// makes the lookup throw.
const THashMap<EAuthType, TStringBuf> KnownTokenKeys{
        // Authorization-header schemes
        {EAuthType::Iam, "Bearer"},
        {EAuthType::OAuth, "OAuth"},
        // dedicated TVM ticket headers
        {EAuthType::TvmUser, "X-Ya-User-Ticket"},
        {EAuthType::TvmService, "X-Ya-Service-Ticket"}
};

/**
 * Parses the value of an Authorization header of the form "<scheme> <token>".
 *
 * The value is split on spaces and tabs, ignoring empty pieces; exactly two pieces are expected.
 * The scheme expected for authType is taken from KnownTokenKeys (KnownTokenKeys.at throws if
 * authType has no entry there).
 *
 * @param authHeader  raw Authorization header value
 * @param authType    token type whose scheme must appear as the first piece
 * @return the token of type authType on success;
 *         InvalidFormat if the value does not consist of exactly two pieces;
 *         UnknownAuthType if the first piece differs from the expected scheme
 *         (the comparison is case-sensitive).
 */
TTokenParseResult ParseAuthorizationHeader(TStringBuf authHeader, EAuthType authType) {
    const TStringBuf tokenKey = KnownTokenKeys.at(authType);

    // Kept in a named local: the splitter is lazy and iterated below.
    auto splitObj = StringSplitter(authHeader)
            .SplitBySet(" \t")
            .SkipEmpty();

    TStringBuf scheme;
    TStringBuf tokenValue;
    size_t partsCnt = 0;
    for (TStringBuf headerPart: splitObj) {
        ++partsCnt;

        if (partsCnt == 1) {
            scheme = headerPart;
        } else if (partsCnt == 2) {
            tokenValue = headerPart;
        } else {
            // A third piece already makes the header malformed; no need to scan the rest.
            break;
        }
    }

    if (partsCnt != 2) {
        return TTokenParseResult::FromError(ETokenParseError::InvalidFormat);
    }
    if (scheme != tokenKey) {
        return TTokenParseResult::FromError(ETokenParseError::UnknownAuthType);
    }
    // Pass the temporary directly: wrapping a prvalue in std::move is redundant and defeats copy elision.
    return TTokenParseResult::FromValue(authType, TString(tokenValue));
}

/**
 * Looks up the standard Authorization header and parses it as a token of the given type.
 *
 * @return NoAuthHeader when the request carries no Authorization header,
 *         otherwise whatever ParseAuthorizationHeader yields for its value.
 */
TTokenParseResult ParseAuthToken(const IHeaders* httpHeaders, EAuthType authType) {
    if (const auto headerValue = httpHeaders->Find(NHttpHeaders::AUTHORIZATION)) {
        return ParseAuthorizationHeader(headerValue.GetRef(), authType);
    }
    return TTokenParseResult::FromError(ETokenParseError::NoAuthHeader);
}

/**
 * Reads a token that is carried verbatim in its own header (the header name comes from KnownTokenKeys).
 *
 * @return the whole header value as a token of type authType, or NoAuthHeader if the header is absent.
 */
TTokenParseResult ExtractTokenFromHeaders(const IHeaders* httpHeaders, EAuthType authType) {
    const TStringBuf headerName = KnownTokenKeys.at(authType);
    if (auto headerValue = httpHeaders->Find(headerName)) {
        return TTokenParseResult::FromValue(authType, TString(headerValue.GetRef()));
    }
    return TTokenParseResult::FromError(ETokenParseError::NoAuthHeader);
}

/**
 * TIamAuthenticator ---------------------------------------------------------------------------------------------------
 */

/**
 * Authenticates IAM tokens (taken from the Authorization header) through an access service client.
 */
class TIamAuthenticator : public IAuthenticator {
public:
    explicit TIamAuthenticator(IAccessServiceClientPtr iamServiceClient)
        : IamServiceClient_(std::move(iamServiceClient))
    {
    }

    TTokenParseResult GetToken(const IHeaders* httpHeaders) override {
        return ParseAuthToken(httpHeaders, EAuthType::Iam);
    }

    /**
     * Sends the token to the access service and converts its reply into a TAuthResult.
     * Throws (Y_ENSURE) if the token is not an IAM token.
     */
    TAsyncAuthResult Authenticate(const TAuthToken& token) override {
        Y_ENSURE(token.Type == EAuthType::Iam, "IAM authenticator got not IAM token of type " << token.Type);
        auto accessFuture = IamServiceClient_->Authenticate(token.Value);
        return accessFuture.Apply([](const TAsyncAuthenticateResponse& completed) {
            // CastAuthResponse takes the future by value: a copy is made so the value can be extracted.
            return CastAuthResponse(completed);
        });
    }

    const TAuthTypeSet& GetTypes() const override {
        return AuthTypes_;
    }

private:
    /// Maps access service error categories onto the generic auth error categories (NonRetriable by default).
    static EAuthErrorType CastAuthErrorType(EAccessServiceAuthErrorType accessServiceErrorType) {
        EAuthErrorType mapped = EAuthErrorType::NonRetriable;
        switch (accessServiceErrorType) {
            case EAccessServiceAuthErrorType::InternalRetriable:
                mapped = EAuthErrorType::Retriable;
                break;
            case EAccessServiceAuthErrorType::InternalNonRetriable:
                mapped = EAuthErrorType::NonRetriable;
                break;
            case EAccessServiceAuthErrorType::FailedAuth:
                mapped = EAuthErrorType::FailedAuth;
                break;
        }
        return mapped;
    }

    /// Converts the access service result (NSolomon::TAuthError on failure) into NAuth's result type.
    static TErrorOr<TAuthSubject, TAuthError> CastAuthResponse(TAsyncAuthenticateResponse response) {
        using TResult = TErrorOr<TAuthSubject, NSolomon::NAuth::TAuthError>;

        TErrorOr<TIamAccount, NSolomon::TAuthError> accessResult = response.ExtractValueSync();
        if (!accessResult.Success()) {
            auto accessError = accessResult.ExtractError();
            return TResult::FromError(NSolomon::NAuth::TAuthError{
                .Type = CastAuthErrorType(accessError.Type),
                .Message = std::move(accessError.Message)
            });
        }

        auto account = accessResult.Extract();
        TAuthSubject subject{
            .Subject = TIamSubject{std::move(account)}
        };
        return TResult::FromValue(std::move(subject));
    }

private:
    IAccessServiceClientPtr IamServiceClient_;
    static const TAuthTypeSet AuthTypes_;
};

const TAuthTypeSet TIamAuthenticator::AuthTypes_{EAuthType::Iam};

/**
 * TTvmAuthenticator ---------------------------------------------------------------------------------------------------
 */

/**
 * Authenticates TVM tickets of one kind (service or user, fixed at construction) via an NTvm client.
 * Tickets are checked synchronously; the result is returned as an already-completed future.
 */
class TTvmAuthenticator : public IAuthenticator {
public:
    /**
     * @param tvmClient    client used to check tickets
     * @param type         must be EAuthType::TvmService or EAuthType::TvmUser (Y_VERIFY otherwise)
     * @param blackboxEnv  passed to CheckUserTicket for user tickets; not used for service tickets
     */
    TTvmAuthenticator(NTvm::ITvmClientPtr tvmClient, EAuthType type, TMaybe<::NTvmAuth::EBlackboxEnv> blackboxEnv = {})
        : TvmClient_(std::move(tvmClient))
        , AuthTypes_{type}
        , BlackboxEnv_{blackboxEnv}
    {
        Y_VERIFY(type == EAuthType::TvmService || type == EAuthType::TvmUser);
        Y_VERIFY(AuthTypes_.size() == 1);
    }

    /// Reads the ticket from the dedicated header registered for this authenticator's type.
    TTokenParseResult GetToken(const IHeaders* httpHeaders) override {
        return ExtractTokenFromHeaders(httpHeaders, *AuthTypes_.begin());
    }

    /**
     * Checks the ticket and wraps the outcome into a ready future.
     * The TVM exception types handled in DoAuthenticateSync become TAuthError values; any other
     * exception (e.g. from the Y_ENSURE token-type checks) propagates out of this call.
     */
    TAsyncAuthResult Authenticate(const TAuthToken& token) override {
        auto res = DoAuthenticateSync(token);
        return NThreading::MakeFuture(std::move(res));
    }

    const TAuthTypeSet& GetTypes() const override {
        return AuthTypes_;
    }

private:
    TAuthResult DoAuthenticateSync(const TAuthToken& token) {
        try {
            return IsServiceAuthenticator() ? CheckServiceToken(token) : CheckUserToken(token);
        } catch (const ::NTvmAuth::TRetriableException& ex) {
            return CreateAuthError(ex, EAuthErrorType::Retriable);
        } catch (const ::NTvmAuth::TNonRetriableException& ex) {
            return CreateAuthError(ex, EAuthErrorType::NonRetriable);
        } catch (const ::NTvmAuth::TNotAllowedException& ex) {
            return CreateAuthError(ex, EAuthErrorType::FailedAuth);
        } catch (const ::NTvmAuth::TTvmException& ex) {
            // Fallback for any remaining TVM exception not matched by the handlers above.
            return CreateAuthError(ex, EAuthErrorType::NonRetriable);
        }
    }

    bool IsServiceAuthenticator() const {
        return AuthTypes_.contains(EAuthType::TvmService);
    }

    TAuthResult CheckServiceToken(const TAuthToken& token) {
        Y_ENSURE(token.Type == EAuthType::TvmService, "service TVM authenticator got token of type " << token.Type);
        auto res = TvmClient_->CheckServiceTicket(token.Value);
        const auto status = res.GetStatus();
        if (status != NTvmAuth::ETicketStatus::Ok) {
            // Separate the status name from the debug info; previously they were glued together.
            return TAuthResult::FromError(TAuthError{
                    .Type = EAuthErrorType::FailedAuth,
                    .Message = TStringBuilder()
                            << "Invalid status of TVM service ticket " << NTvmAuth::StatusToString(status)
                            << ": " << res.DebugInfo()});
        } else {
            return TAuthResult::FromValue(TAuthSubject{
                    .Subject = TTvmSubject{std::move(res)}
            });
        }
    }

    TAuthResult CheckUserToken(const TAuthToken& token) {
        Y_ENSURE(token.Type == EAuthType::TvmUser, "user TVM authenticator got token of type " << token.Type);
        auto res = TvmClient_->CheckUserTicket(token.Value, BlackboxEnv_);
        const auto status = res.GetStatus();
        if (status != NTvmAuth::ETicketStatus::Ok) {
            return TAuthResult::FromError(TAuthError{
                    .Type = EAuthErrorType::FailedAuth,
                    .Message = TStringBuilder() << "Invalid status of TVM user ticket " << NTvmAuth::StatusToString(status)});
        } else {
            return TAuthResult::FromValue(TAuthSubject{
                    .Subject = TTvmSubject{std::move(res)}
            });
        }
    }

    /// Builds an error result of the given category carrying the exception's message.
    static TAuthResult CreateAuthError(const yexception& ex, EAuthErrorType type) {
        TString message(ex.AsStrBuf());
        TAuthError err {
                .Type = type,
                .Message = std::move(message)
        };
        return TAuthResult::FromError(std::move(err));
    }

private:
    NTvm::ITvmClientPtr TvmClient_;
    const TAuthTypeSet AuthTypes_;
    const TMaybe<::NTvmAuth::EBlackboxEnv> BlackboxEnv_;
};

/**
 * TMuxAuthenticator ---------------------------------------------------------------------------------------------------
 *
 * Authenticator multiplexer. Finds appropriate delegate authenticator implementation.
 */

class TMuxAuthenticator: public IAuthenticator {
public:
    explicit TMuxAuthenticator(const TVector<IAuthenticatorPtr>& authenticators) {
        for (const auto& authenticator: authenticators) {
            const auto& authTypes = authenticator->GetTypes();
            for (EAuthType type: authTypes) {
                Y_ENSURE(!AuthTypes_.contains(type), "multiple authenticators of type " << type);
                AuthTypes_.insert(type);
                Authenticators_[type] = authenticator;
            }
        }
    }

    TTokenParseResult GetToken(const IHeaders* httpHeaders) override {
        for (const auto& [_, authenticator]: Authenticators_) {
            auto token = authenticator->GetToken(httpHeaders);
            if (token.Success()) {
                return token;
            }
            if (token.Error().Type == ETokenParseError::InvalidFormat) {
                return token;
            }
        }
        return TTokenParseResult::FromError(ETokenParseError::NoAuthHeader);
    }

    TAsyncAuthResult Authenticate(const TAuthToken& token) override {
        const auto it = Authenticators_.find(token.Type);
        Y_ENSURE(it != Authenticators_.end(), "authenticator multiplexer got token of unsupported type " << token.Type);
        const auto& authenticator = it->second;
        return authenticator->Authenticate(token);
    }

    const TAuthTypeSet& GetTypes() const override {
        return AuthTypes_;
    }

private:
    // Priority of tokens should be according the order in enum EAuthType, that's why ordered
    std::map<EAuthType, IAuthenticatorPtr> Authenticators_;
    TAuthTypeSet AuthTypes_;
};

/**
 * TFakeAuthenticator --------------------------------------------------------------------------------------------------
 *
 * For unit tests
 */

class TFakeAuthenticator: public IAuthenticator {
public:
    explicit TFakeAuthenticator(EAuthType type)
        : AuthTypes_{type}
    {
    }

    /**
     * Extracts a token of the configured type the same way the real authenticators do:
     * from the Authorization header for OAuth/IAM, from the dedicated header for TVM tickets.
     * Throws yexception for types that have no extraction rule.
     */
    TTokenParseResult GetToken(const IHeaders* httpHeaders) override {
        const auto authType = *AuthTypes_.begin();
        switch (authType) {
            case EAuthType::Unknown:
                // Explicit throw (what Y_ENSURE(false, msg) expands to) instead of relying on the macro
                // to stop a textual fall-through into the next case label.
                ythrow yexception() << "Authentication type " << authType << " has no token key";
            case EAuthType::OAuth:
            case EAuthType::Iam:
                return ParseAuthToken(httpHeaders, authType);
            case EAuthType::TvmService:
            case EAuthType::TvmUser:
                return ExtractTokenFromHeaders(httpHeaders, authType);
            default:
                ythrow yexception() << "Parsing of token of type " << authType << " is not implemented yet";
        }
    }

    /**
     * Accepts any token of the configured type without checking it and returns a TFakeAuthSubject.
     * Throws (Y_ENSURE) on a token of a different type.
     */
    TAsyncAuthResult Authenticate(const TAuthToken& token) override {
        const EAuthType authType = *AuthTypes_.begin();
        Y_ENSURE(
                authType == token.Type,
                "fake authenticator got token of type " << token.Type << ", expected type " << authType);
        TAuthSubject authSubject{.Subject = TFakeAuthSubject{.AuthType = authType}};
        auto authResult = TAuthResult::FromValue(std::move(authSubject));
        return NThreading::MakeFuture(std::move(authResult));
    }

    const TAuthTypeSet& GetTypes() const override {
        return AuthTypes_;
    }

private:
    const TAuthTypeSet AuthTypes_;
};

} // namespace

/// Human-readable description of the parse error; empty for an unrecognized error type.
TString TTokenParseError::GetMessage() const {
    TStringBuf description;
    switch (Type) {
        case ETokenParseError::NoAuthHeader:
            description = "No authorization header found";
            break;
        case ETokenParseError::InvalidFormat:
            description = "Invalid format of authorization header";
            break;
        case ETokenParseError::UnknownAuthType:
            description = "Authorization header contains unknown type of token";
            break;
    }
    return TString(description);
}

/// Authenticator for IAM tokens passed as "Authorization: Bearer <token>".
IAuthenticatorPtr CreateIamAuthenticator(IAccessServiceClientPtr iamServiceClient) {
    IAuthenticatorPtr iam = std::make_shared<TIamAuthenticator>(std::move(iamServiceClient));
    return iam;
}

/// Authenticator for TVM service tickets (no blackbox environment is needed).
IAuthenticatorPtr CreateServiceTvmAuthenticator(NTvm::ITvmClientPtr tvmClient) {
    IAuthenticatorPtr serviceTvm = std::make_shared<TTvmAuthenticator>(std::move(tvmClient), EAuthType::TvmService);
    return serviceTvm;
}

/// Authenticator for TVM user tickets checked against the given blackbox environment.
IAuthenticatorPtr CreateUserTvmAuthenticator(NTvm::ITvmClientPtr tvmClient, ::NTvmAuth::EBlackboxEnv blackboxEnv) {
    IAuthenticatorPtr userTvm = std::make_shared<TTvmAuthenticator>(std::move(tvmClient), EAuthType::TvmUser, blackboxEnv);
    return userTvm;
}

/// Test-only authenticator that accepts any token of the given type.
IAuthenticatorPtr CreateFakeAuthenticator(EAuthType authType) {
    IAuthenticatorPtr fake = std::make_shared<TFakeAuthenticator>(authType);
    return fake;
}

/// Combines several authenticators (each type may be handled by at most one of them) into one.
IAuthenticatorPtr CreateAuthenticatorMultiplexer(const TVector<IAuthenticatorPtr>& authenticators) {
    IAuthenticatorPtr mux = std::make_shared<TMuxAuthenticator>(authenticators);
    return mux;
}

}
