#include "auth_gatekeeper.h"

#include <solomon/services/fetcher/lib/clients/access_service.h>
#include <solomon/services/fetcher/lib/config_updater/config_updater.h>

#include <solomon/libs/cpp/logging/logging.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

#include <util/digest/multi.h>

#include <functional>
#include <utility>

namespace NSolomon::NFetcher {
namespace {

using namespace NActors;
using namespace NSolomon::NAuth;
using yandex::solomon::common::UrlStatusType;

/**
 * Key of the check cache (TAuthGatekeeper::Checks_): one final admit/reject decision is kept per
 * unique combination of provider, credentials, target folder and source url.
 */
struct TCheckKey {
    TString Provider;
    /**
     * IAM token or TVM Service Ticket
     */
    TString Credentials;
    /**
     * Target folder; filled from the shard's cluster name when a TExamine request arrives.
     */
    TString Folder;
    TString Url;

    bool operator==(const TCheckKey&) const noexcept = default;
};

/**
 * Key of the authorization cache: whether a service account (on behalf of a provider) may write
 * into a folder.
 */
struct TAuthorizationKey {
    TString Provider;
    TString ServiceAccount;
    TString Folder;

    bool operator==(const TAuthorizationKey&) const noexcept = default;
};

} // namespace
} // namespace NSolomon::NFetcher

/**
 * Hash for TCheckKey, combining every field compared by TCheckKey::operator==.
 * The call operator is const (as the Hash named requirement demands): hash containers may
 * invoke the hasher through a const reference.
 */
template <>
struct std::hash<NSolomon::NFetcher::TCheckKey> {
    size_t operator()(const NSolomon::NFetcher::TCheckKey& key) const noexcept {
        return MultiHash(key.Provider, key.Credentials, key.Folder, key.Url);
    }
};

/**
 * Hash for TAuthorizationKey, combining every field compared by TAuthorizationKey::operator==.
 * The call operator is const (as the Hash named requirement demands): hash containers may
 * invoke the hasher through a const reference.
 */
template <>
struct std::hash<NSolomon::NFetcher::TAuthorizationKey> {
    size_t operator()(const NSolomon::NFetcher::TAuthorizationKey& key) const noexcept {
        return MultiHash(key.Provider, key.ServiceAccount, key.Folder);
    }
};

namespace NSolomon::NFetcher {
namespace {

// Error text of the AUTH_ERROR status sent when new data arrives for a check that is still in flight.
constexpr TStringBuf AUTH_OVERFLOW = "auth from previous data is still in progress. Dropping old data";

/**
 * Outcome of a single check step.
 *
 * For authentication: does a service account belong to a provider
 * For authorization: does a service account have the write permission
 */
enum ECheckState {
    Success,
    Fail,
};

/**
 * Final decision of a whole check (authentication + authorization) for one TCheckKey.
 */
struct TCheckDecision {
    ECheckState CheckState;
    // failure reason that is reported to the sender; left empty on success
    TString Error;
};

/**
 * Context of a global check. All checks start from here. It also contains a cache of the latest decision.
 * Different checks could fail on different steps (authentication or authorization), so every step will also
 * propagate an error to this object to create a shortcut for next checks:
 *
 *    CheckContext -> Authentication -> Authorization ->    Ok
 *        ^            fail               fail           success
 *        |\            v                  v                v
 *        | \__________/                   |                |
 *        |                                |                |
 *        |\______________________________/                /
 *        |                                               /
 *         \_____________________________________________/
 */
struct TCheckContext {
    TCheckKey Key{};
    // refreshed on every access; records idle longer than the TTL are garbage-collected (unless in flight)
    TInstant LastAccessedAt{};
    // true while an authentication/authorization chain is running for this key
    bool IsInflight{false};
    // cached decision; reused while it is younger than the cache duration
    std::optional<TCheckDecision> LastDecision{};
    TInstant DecidedAt{};
    // sender of the pending data: receives the error status, or is the sender of the sink write
    TActorId ReplyToWithStatus{};
    // pending data; released once the decision has been delivered
    std::unique_ptr<TShardData> ShardData{};
};

/**
 * Cached result of an authentication step.
 */
struct TAuthenticationResponse {
    ECheckState CheckState;
    /**
     * Set only on success and only by the IAM gatekeeper (the authenticated service account id);
     * stays empty for TVM.
     */
    TString ServiceAccount;
    /**
     * failure reason; set when CheckState == Fail
     */
    TString Error;
};

/**
 * State of authentication for one set of credentials on behalf of one provider: cached response,
 * in-flight flag and the check contexts waiting for the result.
 */
struct TAuthentication {
    /**
     * IAM token or TVM Service Ticket
     */
    TString Credentials;
    TString Provider;
    std::optional<TAuthenticationResponse> LastResponse;
    TInstant ReceivedResponseAt;
    TInstant LastAccessedAt;
    // while true, the address of this object travels as a cookie of the outstanding request
    bool IsInflight{false};

    // non-owning; entries stay alive because in-flight check contexts are not garbage-collected
    TVector<TCheckContext*> Subscribers;
};

/**
 * Cached result of an authorization (write permission) step.
 */
struct TAuthorizationResponse {
    ECheckState CheckState;
    TString Error;
};

/**
 * State of authorization for one TAuthorizationKey: cached response, in-flight flag and the
 * check contexts waiting for the result.
 */
struct TAuthorization {
    TAuthorizationKey Key;
    TInstant LastAccessedAt;
    // while true, the address of this object travels as a cookie of the outstanding request
    bool IsInflight{false};
    std::optional<TAuthorizationResponse> LastResponse;
    TInstant ReceivedResponseAt;
    // non-owning; entries stay alive because in-flight check contexts are not garbage-collected
    TVector<TCheckContext*> Subscribers;
};

struct TSingleProviderMetrics {
    TSingleProviderMetrics(const TString& provider, NMonitoring::TMetricRegistry& registry)
        : AuthenticationSuccess{registry.Rate(
            {
                {"sensor", "multishardPulling.authentication"},
                {"service_provider", provider },
                {"status", "OK"},
            })}
        , AuthenticationFailure{registry.Rate(
            {
                {"sensor", "multishardPulling.authentication"},
                {"service_provider", provider },
                {"status", "ERROR"},
            })}
        , AuthenticationNotFound{registry.Rate(
            {
                {"sensor", "multishardPulling.authentication"},
                {"service_provider", provider },
                {"status", "NOT_FOUND"},
            })}
        , AuthorizationSuccess{registry.Rate(
            {
                {"sensor", "multishardPulling.authorization"},
                {"service_provider", provider },
                {"status", "OK"},
            })}
        , AuthorizationFailure{registry.Rate(
            {
                {"sensor", "multishardPulling.authorization"},
                {"service_provider", provider },
                {"status", "ERROR"},
            })}
    {
    }

    NMonitoring::TRate* AuthenticationSuccess;
    NMonitoring::TRate* AuthenticationFailure;
    NMonitoring::TRate* AuthenticationNotFound;

    NMonitoring::TRate* AuthorizationSuccess;
    NMonitoring::TRate* AuthorizationFailure;

    TInstant LastAccessedAt{};
};

class TProviderMetrics {
public:
    TProviderMetrics(std::shared_ptr<NMonitoring::TMetricRegistry> registry)
        : Registry_{std::move(registry)}
        , MetricsTotal_{"total", *Registry_}
    {
    }

    void DeleteUnusedMetrics(TInstant now, TDuration unusedMetricsTtl) {
        for (auto it = ProviderMetrics_.begin(); it != ProviderMetrics_.end();) {
            if (now - it->second.LastAccessedAt > unusedMetricsTtl) {
                ProviderMetrics_.erase(it++);
            } else {
                ++it;
            }
        }
    }

    inline TSingleProviderMetrics& GetOrCreate(const TString& provider, TInstant now) {
        auto it = ProviderMetrics_.try_emplace(provider, provider, *Registry_).first;

        auto& metrics = it->second;
        metrics.LastAccessedAt = now;

        return metrics;
    }

    void OnAuthenticationSuccess(const TString& provider, TInstant now) {
        GetOrCreate(provider, now).AuthenticationSuccess->Inc();
        MetricsTotal_.AuthenticationSuccess->Inc();
    }

    void OnAuthenticationFailure(const TString& provider, TInstant now) {
        GetOrCreate(provider, now).AuthenticationFailure->Inc();
        MetricsTotal_.AuthenticationFailure->Inc();
    }

    void OnAuthenticationNotFound(const TString& provider, TInstant now) {
        GetOrCreate(provider, now).AuthenticationNotFound->Inc();
        MetricsTotal_.AuthenticationNotFound->Inc();
    }

    void OnAuthorizationSuccess(const TString& provider, TInstant now) {
        GetOrCreate(provider, now).AuthorizationSuccess->Inc();
        MetricsTotal_.AuthorizationSuccess->Inc();
    }

    void OnAuthorizationFailure(const TString& provider, TInstant now) {
        GetOrCreate(provider, now).AuthorizationFailure->Inc();
        MetricsTotal_.AuthorizationFailure->Inc();
    }

private:
    std::shared_ptr<NMonitoring::TMetricRegistry> Registry_;
    absl::flat_hash_map<TString, NSolomon::NFetcher::TSingleProviderMetrics> ProviderMetrics_;
    TSingleProviderMetrics MetricsTotal_;
};

/**
 * Hit/miss rates of the three cache levels (final decision, authentication, authorization).
 */
struct TCacheMetrics {
    explicit TCacheMetrics(NMonitoring::TMetricRegistry& registry)
        : ExamineHit{MakeRate(registry, "auth.cache.examineHit")}
        , ExamineMiss{MakeRate(registry, "auth.cache.examineMiss")}
        , AuthenticationHit{MakeRate(registry, "auth.cache.authenticationHit")}
        , AuthenticationMiss{MakeRate(registry, "auth.cache.authenticationMiss")}
        , AuthorizationHit{MakeRate(registry, "auth.cache.authorizationHit")}
        , AuthorizationMiss{MakeRate(registry, "auth.cache.authorizationMiss")}
    {
    }

    NMonitoring::TRate* ExamineHit;
    NMonitoring::TRate* ExamineMiss;
    NMonitoring::TRate* AuthenticationHit;
    NMonitoring::TRate* AuthenticationMiss;
    NMonitoring::TRate* AuthorizationHit;
    NMonitoring::TRate* AuthorizationMiss;

private:
    // Registers (or looks up) a rate with a single `sensor` label.
    static NMonitoring::TRate* MakeRate(NMonitoring::TMetricRegistry& registry, const char* sensor) {
        return registry.Rate({ {"sensor", sensor} });
    }
};

/**
 * Events used only inside this file: the periodic GC tick and the results of asynchronous TVM
 * ticket checks, which are sent back to the gatekeeper from the authenticator's future callback.
 */
class TPrivateEvents: private NSolomon::TPrivateEvents {
    enum {
        EvDeleteUnusedRecords = SpaceBegin,
        EvTvmAuthenticationSuccess,
        EvTvmAuthenticationFailure,
        End,
    };
    // NOTE(review): `SpaceBegin < End` holds by construction, so this assert can never fire; the usual
    // intent is to check that End stays below the end of the private event space -- confirm and tighten.
    static_assert(SpaceBegin < End, "too many event types");

public:
    // periodic tick that triggers garbage collection of idle cache records
    struct TDeleteUnusedRecords: NActors::TEventLocal<TDeleteUnusedRecords, EvDeleteUnusedRecords> {
    };

    // a TVM service ticket was validated; carries the authenticated subject
    class TTvmAuthenticationSuccess: public TEventLocal<TTvmAuthenticationSuccess, EvTvmAuthenticationSuccess> {
    public:
        explicit TTvmAuthenticationSuccess(TAuthSubject subj)
            : Subject_{std::move(subj)}
        {}

        // assumes the subject was produced by a TVM authenticator (AsTvm() accessor)
        const TTvmSubject& TvmSubject() const {
            return Subject_.AsTvm();
        }

    private:
        TAuthSubject Subject_;
    };

    // a TVM service ticket was rejected; Message is the reason reported to the sender
    struct TTvmAuthenticationFailure: public TEventLocal<TTvmAuthenticationFailure, EvTvmAuthenticationFailure> {
        TString Message;

        explicit TTvmAuthenticationFailure(TString msg)
            : Message{std::move(msg)}
        {}
    };

};

class TIamAuthGatekeeper;
class TTvmAuthGatekeeper;

/**
 * Common part of the IAM and TVM gatekeepers (CRTP: TDerived supplies the credential-specific steps).
 *
 * An incoming TExamine request passes through three cached stages:
 *   1. TCheckContext   -- final decision per (provider, credentials, folder, url);
 *   2. TAuthentication -- authentication result per (credentials, provider);
 *   3. TAuthorization  -- write permission per (provider, service account, folder), driven by TDerived.
 * Admitted data is forwarded to the sink; rejected data is answered with an AUTH_ERROR status.
 * Records idle for longer than UnusedRecordsTtl_ (and not in flight) are dropped every GcInterval_.
 */
template <class TDerived>
class TAuthGatekeeper: public TActorBootstrapped<TAuthGatekeeper<TDerived>> {
public:
    TAuthGatekeeper(
            TActorId configUpdater,
            TActorId sink,
            TDuration unusedRecordsTtl,
            TDuration cacheDuration,
            TDuration gcInterval,
            std::shared_ptr<NMonitoring::TMetricRegistry> registry)
        : ConfigUpdater_{configUpdater}
        , Sink_{sink}
        , UnusedRecordsTtl_{unusedRecordsTtl}
        , CacheDuration_{cacheDuration}
        , GcInterval_{gcInterval}
        , CacheMetrics_{*registry}
        , ProviderMetrics_{std::move(registry)}
    {}

    /**
     * Enters TDerived::Main, schedules the first GC round and subscribes to the config updater
     * (provider changes arrive as TEvProvidersChanged).
     */
    void Bootstrap() {
        this->Become(&TDerived::Main);

        this->Schedule(GcInterval_, new TPrivateEvents::TDeleteUnusedRecords);
        this->Send(ConfigUpdater_, new TEvents::TEvSubscribe);
    }

protected:
    /**
     * Authentication cache key: (credentials, provider). The provider must be part of the key because the
     * authentication step also verifies that the subject belongs to the provider; a result obtained for
     * one provider must never be reused for another provider presenting the same credentials.
     */
    using TAuthenticationKey = std::pair<TString, TString>;

    TCheckContext& GetOrCreateCheckContext(TCheckKey key) {
        auto& req = Checks_[key];
        if (!req) {
            req = std::make_unique<TCheckContext>();
            req->Key = std::move(key);
        }

        req->LastAccessedAt = TActivationContext::Now();

        return *req;
    }

    /**
     * Reports the cached failure of `checkCtx` to the pending sender and releases the pending data.
     */
    void ReplyWithError(TCheckContext& checkCtx) {
        auto errEv = std::make_unique<TEvMetricDataWritten>(
                checkCtx.ShardData->ShardId.StrId(),
                UrlStatusType::AUTH_ERROR,
                0,
                checkCtx.LastDecision->Error);
        this->Send(checkCtx.ReplyToWithStatus, errEv.release());

        checkCtx.ReplyToWithStatus = {};
        checkCtx.ShardData = {};
    }

    /**
     * Forwards the pending data of `checkCtx` to the sink on behalf of the original sender.
     */
    void PassDataToSink(TCheckContext& checkCtx) {
        auto ev = std::make_unique<TEvSinkWrite>(std::move(*checkCtx.ShardData));
        TActivationContext::Send(new IEventHandle{Sink_, checkCtx.ReplyToWithStatus, ev.release()});

        checkCtx.ReplyToWithStatus = {};
        checkCtx.ShardData = {};
    }

    /**
     * Entry point of a check: serves it from the decision cache or starts authentication.
     */
    void OnExamine(const TAuthGatekeeperEvents::TExamine::TPtr& evPtr) {
        auto& ev = *evPtr->Get();
        Y_VERIFY(!ev.ShardData->ClusterName.empty(), "cluster name cannot be empty");

        auto& checkCtx = GetOrCreateCheckContext(TCheckKey{
                .Provider = ev.Provider,
                .Credentials = ev.Credentials,
                .Folder = ev.ShardData->ClusterName,
                .Url = ev.Url,
        });

        if (checkCtx.IsInflight) {
            // A check for this key is still running: the data it holds is dropped and replaced by the new
            // data, which will be delivered once the running check completes. The error is sent to whoever
            // sent the dropped data *before* it is overwritten, so every request gets exactly one reply.
            if (checkCtx.ShardData) {
                auto errEv = std::make_unique<TEvMetricDataWritten>(
                        checkCtx.ShardData->ShardId.StrId(),
                        UrlStatusType::AUTH_ERROR,
                        0,
                        TString{AUTH_OVERFLOW});
                this->Send(checkCtx.ReplyToWithStatus, errEv.release());
            }

            checkCtx.ReplyToWithStatus = evPtr->Sender;
            checkCtx.ShardData = std::move(ev.ShardData);
            return;
        }

        checkCtx.ReplyToWithStatus = evPtr->Sender;
        checkCtx.ShardData = std::move(ev.ShardData);

        bool isFresh = (TActivationContext::Now() - checkCtx.DecidedAt) <= CacheDuration_;
        if (checkCtx.LastDecision && isFresh) {
            CacheMetrics_.ExamineHit->Inc();

            if (checkCtx.LastDecision->CheckState == ECheckState::Fail) {
                ReplyWithError(checkCtx);
                return;
            }

            PassDataToSink(checkCtx);
            return;
        } // else -- no cache or it became stale

        CacheMetrics_.ExamineMiss->Inc();
        checkCtx.LastDecision = {};

        checkCtx.IsInflight = true;
        CheckAuthentication(checkCtx);
    }

    TAuthentication& GetOrCreateAuthentication(const TString& token, const TString& provider) {
        auto& authentication = AuthenticationRequests_[TAuthenticationKey{token, provider}];
        if (!authentication) {
            authentication = std::make_unique<TAuthentication>();
            authentication->Credentials = token;
            authentication->Provider = provider;
        }

        authentication->LastAccessedAt = TActivationContext::Now();

        return *authentication;
    }

    /**
     * Fails every check waiting on `auth` with its cached error and clears the subscriber list.
     */
    void AuthenticationPropagateError(TAuthentication& auth) {
        if constexpr (std::is_same_v<TDerived, TIamAuthGatekeeper>) {
            MON_DEBUG(
                    AuthGatekeeper,
                    "failed to authenticate an IAM token for provider " << auth.Provider
                                                                       << ". reason: " << auth.LastResponse->Error);
        } else if constexpr (std::is_same_v<TDerived, TTvmAuthGatekeeper>) {
            MON_DEBUG(
                    AuthGatekeeper,
                    "failed to authenticate a TVM ticket for provider " << auth.Provider
                                                                        << ". reason: " << auth.LastResponse->Error);
        } else {
            static_assert(TDependentFalse<TDerived>, "unknown auth type");
        }

        auto now = TActivationContext::Now();

        for (auto* checkCtx: auth.Subscribers) {
            checkCtx->LastDecision = {
                    .CheckState = ECheckState::Fail,
                    .Error = auth.LastResponse->Error,
            };
            checkCtx->DecidedAt = now;
            checkCtx->IsInflight = false;

            ReplyWithError(*checkCtx);
        }

        auth.Subscribers.clear();
    }

    /**
     * Authentication stage: subscribes `checkCtx` to the (credentials, provider) record and either joins an
     * in-flight request, serves the cached response, or sends a new request via TDerived.
     */
    void CheckAuthentication(TCheckContext& checkCtx) {
        auto& auth = GetOrCreateAuthentication(checkCtx.Key.Credentials, checkCtx.Key.Provider);
        auth.Subscribers.emplace_back(&checkCtx);

        if (auth.IsInflight) {
            return;
        }

        auto* self = static_cast<TDerived*>(this);

        bool isFresh = (TActivationContext::Now() - auth.ReceivedResponseAt) <= CacheDuration_;
        if (auth.LastResponse && isFresh) {
            CacheMetrics_.AuthenticationHit->Inc();

            if (auth.LastResponse->CheckState == ECheckState::Fail) {
                AuthenticationPropagateError(auth);
                return;
            }

            if constexpr (std::is_same_v<TDerived, TTvmAuthGatekeeper>) {
                // TODO(SOLOMON-8818): instead of AuthorizationPropagateSuccess() use:
                // self->CheckThatServiceHasWritePermission(checkCtx, auth.LastResponse->ServiceAccount);
                self->AuthorizationPropagateSuccess(auth);
            } else if constexpr (std::is_same_v<TDerived, TIamAuthGatekeeper>) {
                self->CheckThatServiceHasWritePermission(checkCtx, auth.LastResponse->ServiceAccount);
            } else {
                static_assert(TDependentFalse<TDerived>, "unknown auth type");
            }
            auth.Subscribers.clear();
            return;
        } // else -- no cache or it became stale

        CacheMetrics_.AuthenticationMiss->Inc();
        auth.LastResponse = {};

        auth.IsInflight = true;
        self->SendAuthenticationRequest(auth);
    }

    template <typename TEvent>
    void OnAuthenticationFail(const typename TEvent::TPtr& evPtr) {
        // the cookie is the address of the TAuthentication record; it is not GC'ed while in flight
        auto& auth = *reinterpret_cast<TAuthentication*>(evPtr->Cookie);

        auth.IsInflight = false;
        auth.LastResponse = {
                .CheckState = ECheckState::Fail,
                .Error = std::move(evPtr->Get()->Message),
        };
        auth.ReceivedResponseAt = TActivationContext::Now();

        ProviderMetrics_.OnAuthenticationFailure(auth.Provider, TActivationContext::Now());
        AuthenticationPropagateError(auth);
    }

    template <typename TEvent>
    void OnAuthenticationSuccess(const typename TEvent::TPtr& evPtr) {
        auto* self = static_cast<TDerived*>(this);

        // the cookie is the address of the TAuthentication record; it is not GC'ed while in flight
        auto& auth = *reinterpret_cast<TAuthentication*>(evPtr->Cookie);
        auto& subject = self->RetrieveSubject(evPtr);
        auth.IsInflight = false;
        auth.ReceivedResponseAt = TActivationContext::Now();

        if (!self->IsSubjectValid(subject, auth)) {
            return;
        }

        const auto& serviceId = self->RetrieveServiceId(evPtr);

        // this method will update provider metrics by itself
        if (auto err = self->DoesServiceBelongToProvider(serviceId, auth.Provider)) {
            auth.LastResponse = {
                    .CheckState = ECheckState::Fail,
                    .Error = std::move(err),
            };

            AuthenticationPropagateError(auth);
            return;
        }

        auth.LastResponse = TAuthenticationResponse{
                .CheckState = ECheckState::Success,
        };
        if constexpr (std::is_same_v<TDerived, TIamAuthGatekeeper>) {
            auth.LastResponse->ServiceAccount = serviceId;
        }

        if constexpr (std::is_same_v<TDerived, TIamAuthGatekeeper>) {
            for (auto* checkCtx: auth.Subscribers) {
                self->CheckThatServiceHasWritePermission(*checkCtx, serviceId);
            }
        } else if constexpr (std::is_same_v<TDerived, TTvmAuthGatekeeper>) {
            // TODO(SOLOMON-8818): instead, for every subscriber use:
            // self->CheckThatServiceHasWritePermission(*checkCtx, serviceId);
            //
            // AuthorizationPropagateSuccess() takes the whole `auth` (as on the cache-hit path in
            // CheckAuthentication()), so it is called exactly once: calling it per subscriber while
            // iterating auth.Subscribers re-visits subscribers whose data was already delivered.
            self->AuthorizationPropagateSuccess(auth);
        } else {
            static_assert(TDependentFalse<TDerived>, "unknown auth type");
        }
        auth.Subscribers.clear();
    }

    TAuthorization& GetOrCreateAuthorization(TAuthorizationKey key) {
        auto& authorization = AuthorizationRequests_[key];
        if (!authorization) {
            authorization = std::make_unique<TAuthorization>();
            authorization->Key = std::move(key);
        }

        authorization->LastAccessedAt = TActivationContext::Now();

        return *authorization;
    }

    /**
     * Fails every check waiting on `auth` with its cached error and clears the subscriber list.
     */
    void AuthorizationPropagateError(TAuthorization& auth) {
        auto now = TActivationContext::Now();

        for (auto* checkCtx: auth.Subscribers) {
            checkCtx->LastDecision = {
                    .CheckState = ECheckState::Fail,
                    .Error = auth.LastResponse->Error,
            };
            checkCtx->DecidedAt = now;
            checkCtx->IsInflight = false;

            ReplyWithError(*checkCtx);
        }

        auth.Subscribers.clear();
    }

    /**
     * Drops records idle for longer than `unusedRecordsTtl`. In-flight records are kept because their
     * addresses are referenced by outstanding request cookies and subscriber lists.
     */
    template <typename TContainer>
    void DeleteUnusedRecords(TContainer& reqs, TInstant now, TDuration unusedRecordsTtl) {
        for (auto it = reqs.begin(); it != reqs.end();) {
            if (now - it->second->LastAccessedAt > unusedRecordsTtl && !it->second->IsInflight) {
                reqs.erase(it++);
            } else {
                ++it;
            }
        }
    }

    void OnDeleteUnusedRecords() {
        auto ev = std::make_unique<TPrivateEvents::TDeleteUnusedRecords>();
        this->Schedule(GcInterval_, ev.release());

        auto now = TActivationContext::Now();

        DeleteUnusedRecords(Checks_, now, UnusedRecordsTtl_);
        DeleteUnusedRecords(AuthenticationRequests_, now, UnusedRecordsTtl_);
        DeleteUnusedRecords(AuthorizationRequests_, now, UnusedRecordsTtl_);
        ProviderMetrics_.DeleteUnusedMetrics(now, UnusedRecordsTtl_);
    }

    void OnProvidersUpdate(const TEvProvidersChanged::TPtr& evPtr) {
        const auto& ev = *(evPtr->Get());

        for (const auto& provider: ev.Added) {
            Providers_[provider->Id] = provider;
        }

        for (const auto& provider: ev.Changed) {
            Providers_[provider->Id] = provider;
        }

        for (const auto& providerId: ev.Removed) {
            Providers_.erase(providerId);
        }
    }

protected:
    TActorId ConfigUpdater_;
    TActorId Sink_;
    TDuration UnusedRecordsTtl_;
    TDuration CacheDuration_;
    TDuration GcInterval_;
    absl::flat_hash_map<TCheckKey, std::unique_ptr<TCheckContext>> Checks_;
    absl::flat_hash_map<TAuthenticationKey, std::unique_ptr<TAuthentication>> AuthenticationRequests_;
    absl::flat_hash_map<TAuthorizationKey, std::unique_ptr<TAuthorization>> AuthorizationRequests_;
    THashMap<TProviderId, TProviderConfigPtr> Providers_;
    TCacheMetrics CacheMetrics_;
    TProviderMetrics ProviderMetrics_;
};

/**
 * Gatekeeper for IAM-authenticated data: the IAM token is authenticated by the access service, the
 * resulting service account must be listed for the provider, and then the access service is asked
 * whether the account may write into the target folder.
 */
class TIamAuthGatekeeper: public TAuthGatekeeper<TIamAuthGatekeeper> {
    using TBase = TAuthGatekeeper<TIamAuthGatekeeper>;
public:
    TIamAuthGatekeeper(
            TActorId accessService,
            TActorId configUpdater,
            TActorId sink,
            TDuration unusedRecordsTtl,
            TDuration cacheDuration,
            TDuration gcInterval,
            bool dontTakeAuthorizationIntoAccount,
            std::shared_ptr<NMonitoring::TMetricRegistry> registry)
        : TBase(configUpdater, sink, unusedRecordsTtl, cacheDuration, gcInterval, std::move(registry))
        , AccessService_{accessService}
        , DontTakeAuthorizationIntoAccount_{dontTakeAuthorizationIntoAccount}
    {
    }

    STATEFN(Main) {
        switch (ev->GetTypeRewrite()) {
            // incoming checks
            hFunc(TAuthGatekeeperEvents::TExamine, OnExamine);

            // access service: authentication replies
            hFunc(TAccessServiceEvents::TAuthenticationSuccess,
                  OnAuthenticationSuccess<TAccessServiceEvents::TAuthenticationSuccess>);
            hFunc(TAccessServiceEvents::TAuthenticationFailure,
                  OnAuthenticationFail<TAccessServiceEvents::TAuthenticationFailure>);

            // access service: authorization replies
            hFunc(TAccessServiceEvents::TAuthorizationSuccess, OnAuthorizationSuccess);
            hFunc(TAccessServiceEvents::TAuthorizationFailure, OnAuthorizationFail);

            // housekeeping
            hFunc(TEvProvidersChanged, OnProvidersUpdate);
            sFunc(TPrivateEvents::TDeleteUnusedRecords, OnDeleteUnusedRecords);
        }
    }

    /**
     * Asks the access service to authenticate `auth.Credentials` as an IAM token.
     * The record's address is passed as the cookie and comes back with the reply.
     */
    void SendAuthenticationRequest(TAuthentication& auth) {
        const ui64 cookie = reinterpret_cast<ui64>(&auth);

        auto request = std::make_unique<TAccessServiceEvents::TAuthenticate>();
        request->IamToken = auth.Credentials;

        Send(AccessService_, request.release(), 0, cookie);
    }

    const TIamAccount& RetrieveSubject(const TAccessServiceEvents::TAuthenticationSuccess::TPtr& evPtr) {
        return evPtr->Get()->Account;
    }

    const TString& RetrieveServiceId(const TAccessServiceEvents::TAuthenticationSuccess::TPtr& evPtr) {
        return evPtr->Get()->Account.Id;
    }

    /**
     * Checks that `serviceAccount` is one of the IAM service accounts configured for `provider`,
     * and bumps the matching provider metric.
     *
     * @return an empty string on success, otherwise a non-empty error description
     */
    TString DoesServiceBelongToProvider(const TString& serviceAccount, const TString& provider) {
        const auto providerIt = Providers_.find(provider);
        if (providerIt == Providers_.end()) {
            ProviderMetrics_.OnAuthenticationNotFound(provider, TActivationContext::Now());

            return TStringBuilder() << "no provider with id: \"" << provider << "\"";
        }

        const auto& allowedIds = providerIt->second->IamServiceAccountIds;
        if (Find(allowedIds, serviceAccount) != allowedIds.end()) {
            ProviderMetrics_.OnAuthenticationSuccess(provider, TActivationContext::Now());
            return {};
        }

        ProviderMetrics_.OnAuthenticationFailure(provider, TActivationContext::Now());

        TStringBuilder message{};
        message << "failed to match a provider id \"" << provider << "\" with"
                << " an iam service account. Account from a header: \"" << serviceAccount << "\";"
                << " in memory: \"[";

        bool isFirst = true;
        for (const auto& id: allowedIds) {
            if (!isFirst) {
                message << ", ";
            }
            isFirst = false;

            message << '"' << id << '"';
        }
        message << ']';

        return message;
    }

    /**
     * Only service accounts may push data; any other subject type fails the authentication of `auth`.
     */
    bool IsSubjectValid(const TIamAccount& subject, TAuthentication& auth) {
        if (subject.Type == EIamAccountType::Service) {
            return true;
        }

        TString reason = TStringBuilder()
                << "wrong subject type. expected: " << EIamAccountType::Service << ", "
                << "got: " << subject.Type
                ;

        auth.LastResponse = {
                .CheckState = ECheckState::Fail,
                .Error = std::move(reason),
        };

        ProviderMetrics_.OnAuthenticationFailure(auth.Provider, TActivationContext::Now());
        AuthenticationPropagateError(auth);
        return false;
    }

    /**
     * Admits every check waiting on `auth`: caches a successful decision and forwards the data to the sink.
     */
    void AuthorizationPropagateSuccess(TAuthorization& auth) {
        const auto decidedAt = TActivationContext::Now();

        for (TCheckContext* subscriber: auth.Subscribers) {
            subscriber->LastDecision = TCheckDecision{
                .CheckState = ECheckState::Success,
                .Error = {},
            };
            subscriber->DecidedAt = decidedAt;
            subscriber->IsInflight = false;

            PassDataToSink(*subscriber);
        }

        auth.Subscribers.clear();
    }

    /**
     * Authorization stage: checks that `serviceAccount` may write into the folder of `checkCtx`,
     * using the cached answer while it is fresh.
     */
    void CheckThatServiceHasWritePermission(TCheckContext& checkCtx, const TString& serviceAccount) {
        TAuthorizationKey key{
            .Provider = checkCtx.Key.Provider,
            .ServiceAccount = serviceAccount,
            .Folder = checkCtx.ShardData->ClusterName,
        };
        auto& authorization = GetOrCreateAuthorization(std::move(key));
        authorization.Subscribers.emplace_back(&checkCtx);

        if (authorization.IsInflight) {
            // the running request will notify this subscriber as well
            return;
        }

        const bool hasFreshResponse = authorization.LastResponse
                && (TActivationContext::Now() - authorization.ReceivedResponseAt) <= CacheDuration_;

        if (!hasFreshResponse) {
            CacheMetrics_.AuthorizationMiss->Inc();
            authorization.LastResponse = {};

            authorization.IsInflight = true;
            SendAuthorizationRequest(authorization);
            return;
        }

        CacheMetrics_.AuthorizationHit->Inc();

        const bool denied = authorization.LastResponse->CheckState == ECheckState::Fail;
        if (denied && !DontTakeAuthorizationIntoAccount_) {
            AuthorizationPropagateError(authorization);
        } else {
            AuthorizationPropagateSuccess(authorization);
        }
    }

    /**
     * Asks the access service whether the key's service account may write into the key's folder.
     * The record's address is passed as the cookie and comes back with the reply.
     */
    void SendAuthorizationRequest(TAuthorization& auth) {
        const ui64 cookie = reinterpret_cast<ui64>(&auth);

        auto request = std::make_unique<TAccessServiceEvents::TAuthorize>();
        request->ServiceAccountId = auth.Key.ServiceAccount;
        request->FolderId = auth.Key.Folder;

        Send(AccessService_, request.release(), 0, cookie);
    }

    void OnAuthorizationFail(const TAccessServiceEvents::TAuthorizationFailure::TPtr& evPtr) {
        auto& auth = *reinterpret_cast<TAuthorization*>(evPtr->Cookie);
        const auto now = TActivationContext::Now();

        auth.IsInflight = false;
        auth.ReceivedResponseAt = now;
        auth.LastResponse = {
                .CheckState = ECheckState::Fail,
                .Error = std::move(evPtr->Get()->Message),
        };

        ProviderMetrics_.OnAuthorizationFailure(auth.Key.Provider, now);

        if (!DontTakeAuthorizationIntoAccount_) {
            AuthorizationPropagateError(auth);
        } else {
            AuthorizationPropagateSuccess(auth);
        }
    }

    void OnAuthorizationSuccess(const TAccessServiceEvents::TAuthorizationSuccess::TPtr& evPtr) {
        auto& auth = *reinterpret_cast<TAuthorization*>(evPtr->Cookie);
        const auto now = TActivationContext::Now();

        auth.IsInflight = false;
        auth.ReceivedResponseAt = now;
        auth.LastResponse = {
                .CheckState = ECheckState::Success,
                .Error = {},
        };

        ProviderMetrics_.OnAuthorizationSuccess(auth.Key.Provider, now);
        AuthorizationPropagateSuccess(auth);
    }

private:
    TActorId AccessService_;
    /**
     * Write metrics anyway so we don't break multishard pulling until every ServiceProvider switches to a proper authorization
     *
     * For more info, look at https://nda.ya.ru/t/aKLQFSLx48PU3n
     */
    bool DontTakeAuthorizationIntoAccount_;
};

/**
 * Auth gatekeeper for credentials that are TVM service tickets (EAuthType::TvmService).
 *
 * Tickets are validated by the injected TVM authenticator. The source service id of a valid ticket
 * (RetrieveServiceId) must be listed in the provider's TvmServiceIds (DoesServiceBelongToProvider).
 *
 * Authorization is not implemented for TVM yet (SOLOMON-8818). As a result, AuthorizationPropagateSuccess
 * works on a TAuthentication and simply marks every waiting subscriber as successful.
 */
class TTvmAuthGatekeeper: public TAuthGatekeeper<TTvmAuthGatekeeper> {
    using TBase = TAuthGatekeeper<TTvmAuthGatekeeper>;
public:
    TTvmAuthGatekeeper(
            NAuth::IAuthenticatorPtr tvmAuthenticator,
            TActorId configUpdater,
            TActorId sink,
            TDuration unusedRecordsTtl,
            TDuration cacheDuration,
            TDuration gcInterval,
            std::shared_ptr<NMonitoring::TMetricRegistry> registry)
        : TBase(configUpdater, sink, unusedRecordsTtl, cacheDuration, gcInterval, std::move(registry))
        , TvmAuthenticator_{std::move(tvmAuthenticator)}
    {
    }

    STATEFN(Main) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TAuthGatekeeperEvents::TExamine, OnExamine);
            hFunc(TPrivateEvents::TTvmAuthenticationFailure,
                  OnAuthenticationFail<TPrivateEvents::TTvmAuthenticationFailure>);
            hFunc(TPrivateEvents::TTvmAuthenticationSuccess,
                  OnAuthenticationSuccess<TPrivateEvents::TTvmAuthenticationSuccess>);

            hFunc(TEvProvidersChanged, OnProvidersUpdate);
            sFunc(TPrivateEvents::TDeleteUnusedRecords, OnDeleteUnusedRecords);
        }
    }

    /**
     * Starts asynchronous validation of `auth.Credentials` as a TVM service ticket.
     *
     * The future callback does not touch `this` or `auth`. It only sends a TTvmAuthenticationSuccess or
     * TTvmAuthenticationFailure event back to this actor through the actor system, carrying the address of `auth`
     * as the cookie. This presumably exists because the callback may run outside this actor's execution context.
     * Exceptions thrown while extracting the result are turned into a TTvmAuthenticationFailure.
     */
    void SendAuthenticationRequest(TAuthentication& auth) {
        auto* actorSystem = TActivationContext::ActorSystem();
        const auto self = SelfId();
        const ui64 cookie = reinterpret_cast<ui64>(&auth);

        TvmAuthenticator_->Authenticate(TAuthToken{EAuthType::TvmService, auth.Credentials})
            .Subscribe([actorSystem, self, cookie](auto f) {
                // Delivers `event` to this actor, tagged with the authentication record's cookie.
                auto reply = [&](IEventBase* event) {
                    actorSystem->Send(new IEventHandle{self, self, event, 0, cookie});
                };

                try {
                    TAuthResult result = f.ExtractValueSync();

                    if (result.Success()) {
                        reply(new TPrivateEvents::TTvmAuthenticationSuccess{std::move(result.Value())});
                    } else {
                        // TODO: retries
                        reply(new TPrivateEvents::TTvmAuthenticationFailure{std::move(result.ExtractError().Message)});
                    }
                } catch (...) {
                    reply(new TPrivateEvents::TTvmAuthenticationFailure{CurrentExceptionMessage()});
                }
            });
    }

    /// Subject carried by a successful TVM authentication event.
    const TTvmSubject& RetrieveSubject(const TPrivateEvents::TTvmAuthenticationSuccess::TPtr& evPtr) {
        return evPtr->Get()->TvmSubject();
    }

    /// Source (sender) TVM id taken from the service ticket of a successful authentication event.
    NTvmAuth::TTvmId RetrieveServiceId(const TPrivateEvents::TTvmAuthenticationSuccess::TPtr& evPtr) {
        return evPtr->Get()->TvmSubject().GetServiceTicket()->GetSrc();
    }

    /**
     * Checks that `serviceId` is one of the TVM ids configured for `provider`, and updates the provider metrics
     * (not found / failure / success) accordingly.
     *
     * @return if successful, returns an empty string. Otherwise non-empty error description
     */
    TString DoesServiceBelongToProvider(const NTvmAuth::TTvmId& serviceId, const TString& provider) {
        auto it = Providers_.find(provider);
        if (it == Providers_.end()) {
            ProviderMetrics_.OnAuthenticationNotFound(provider, TActivationContext::Now());

            return TStringBuilder() << "no provider with id: \"" << provider << "\"";
        }

        const auto& ids = it->second->TvmServiceIds;
        if (Find(ids, serviceId) == ids.end()) {
            ProviderMetrics_.OnAuthenticationFailure(provider, TActivationContext::Now());

            // Renders e.g. `... in memory: ["1", "2"]`. Each id is quoted individually, so the list itself
            // is not wrapped in another (previously unbalanced) pair of quotes.
            TStringBuilder sb{};
            sb << "failed to match a provider id \"" << provider << "\" with"
                    << " a TVM service. Service id from a header: \"" << serviceId << "\";"
                    << " in memory: [";

            for (size_t i = 0; i != ids.size(); ++i) {
                if (i > 0) {
                    sb << ", ";
                }

                sb << '"' << ids[i] << '"';
            }
            sb << ']';

            return sb;
        }

        ProviderMetrics_.OnAuthenticationSuccess(provider, TActivationContext::Now());
        return {};
    }

    /**
     * Only service-account subjects are accepted. For any other subject, `auth` gets a Fail decision,
     * the provider's failure metric is bumped, the error is propagated to subscribers, and false is returned.
     */
    bool IsSubjectValid(const TTvmSubject& subject, TAuthentication& auth) {
        if (!subject.IsServiceAccount()) {
            TString errMsg = "wrong subject type. expected a service account, but got a user one";

            auth.LastResponse = {
                    .CheckState = ECheckState::Fail,
                    .Error = std::move(errMsg),
            };

            ProviderMetrics_.OnAuthenticationFailure(auth.Provider, TActivationContext::Now());
            AuthenticationPropagateError(auth);
            return false;
        }

        return true;
    }

    /**
     * Marks every subscriber of `auth` as successfully checked, passes its data to the sink,
     * and then clears the subscriber list.
     */
    // TODO(SOLOMON-8818): use TAuthorization when auth will be supported
    // void AuthorizationPropagateSuccess(TAuthorization& auth) {
    void AuthorizationPropagateSuccess(TAuthentication& auth) {
        auto now = TActivationContext::Now();

        for (auto* checkCtx: auth.Subscribers) {
            checkCtx->LastDecision = {
                .CheckState = ECheckState::Success,
                .Error = {},
            };
            checkCtx->DecidedAt = now;
            checkCtx->IsInflight = false;

            PassDataToSink(*checkCtx);
        }

        auth.Subscribers.clear();
    }

private:
    NAuth::IAuthenticatorPtr TvmAuthenticator_;
};

} // namespace

/**
 * Builds a TIamAuthGatekeeper actor.
 *
 * Authorization requests go to `accessService`. When `dontTakeAuthorizationIntoAccount` is true,
 * authorization failures are recorded but not enforced. `registry` is moved into the actor.
 */
std::unique_ptr<NActors::IActor> CreateIamAuthGatekeeper(
        NActors::TActorId accessService,
        NActors::TActorId configUpdater,
        NActors::TActorId sink,
        TDuration unusedRecordsTtl,
        TDuration cacheDuration,
        TDuration gcInterval,
        bool dontTakeAuthorizationIntoAccount,
        std::shared_ptr<NMonitoring::TMetricRegistry> registry)  // NOLINT(performance-unnecessary-value-param): false positive
{
    auto gatekeeper = std::make_unique<TIamAuthGatekeeper>(
            accessService, configUpdater, sink,
            unusedRecordsTtl, cacheDuration, gcInterval,
            dontTakeAuthorizationIntoAccount,
            std::move(registry));
    return gatekeeper;
}

/**
 * Builds a TTvmAuthGatekeeper actor that validates TVM service tickets with `tvmAuthenticator`.
 *
 * Both `tvmAuthenticator` and `registry` are moved into the actor.
 */
std::unique_ptr<NActors::IActor> CreateTvmAuthGatekeeper(
        NAuth::IAuthenticatorPtr tvmAuthenticator,
        NActors::TActorId configUpdater,
        NActors::TActorId sink,
        TDuration unusedRecordsTtl,
        TDuration cacheDuration,
        TDuration gcInterval,
        std::shared_ptr<NMonitoring::TMetricRegistry> registry)
{
    auto gatekeeper = std::make_unique<TTvmAuthGatekeeper>(
            std::move(tvmAuthenticator),
            configUpdater, sink,
            unusedRecordsTtl, cacheDuration, gcInterval,
            std::move(registry));
    return gatekeeper;
}

} // namespace NSolomon::NFetcher
