#include "auth.h"

#include <drive/backend/common/localization.h>
#include <drive/backend/database/transaction/tx.h>

#include <kernel/daemon/common/time_guard.h>

#include <rtline/util/types/cache_with_age.h>

/// Rejects the request (via R_ENSURE) unless `authInfo` is present and available.
/// Each kind of failure is first written to the event log as "AuthorizationFailure".
void TAuthRequestProcessor::CheckAuthInfo(TJsonReport::TGuard& g, IAuthInfo::TPtr authInfo) {
    // The report guard is part of the interface but is not used by this check.
    Y_UNUSED(g);
    const auto& httpStatuses = BaseServer->GetHttpStatusManagerConfig();

    // Case 1: no auth info at all -> log a null payload and fail with the configured
    // unauthorized status.
    const bool hasAuthInfo = !!authInfo;
    if (!hasAuthInfo) {
        NDrive::TEventLog::Log("AuthorizationFailure", NJson::TMapBuilder
            ("auth_info", NJson::JSON_NULL)
        );
    }
    R_ENSURE(hasAuthInfo, httpStatuses.UnauthorizedStatus, "null AuthInfo", EDriveSessionResult::Unauthenticated);

    // Case 2: auth info exists but reports itself unavailable -> log its details, attach an
    // "auth_is_not_available" error to the session passed to R_ENSURE, and fail with the
    // code taken from the auth info (it is handed the unauthorized status, presumably as a
    // fallback).
    NDrive::TInfoEntitySession failedSession;
    if (!authInfo->IsAvailable()) {
        NDrive::TEventLog::Log("AuthorizationFailure", NJson::TMapBuilder
            ("auth_info", authInfo->GetInfo())
            ("code", authInfo->GetCode())
            ("user_id", authInfo->GetUserId())
        );
        failedSession.SetError(NDrive::MakeError("auth_is_not_available"));
    }
    R_ENSURE(
        authInfo->IsAvailable(),
        authInfo->GetCode(httpStatuses.UnauthorizedStatus),
        authInfo->GetMessage(),
        EDriveSessionResult::Unauthenticated,
        failedSession
    );
}

namespace {
    // NOTE(review): neither of these file-local globals is referenced anywhere else in this
    // file; judging only by the names and types they were meant to hold a
    // user -> handler -> last request time table guarded by the mutex. Rate limiting is
    // currently done through CachesPool, so these look unused — confirm and remove.
    TMutex MutexUserHandlerLastRequests;
    TMap<TString, TMap<TString, TInstant>> UserHandlerLastRequests;
}

/// Mutex-protected key -> value cache whose entries expire LiveTime after they were written.
///
/// Every PutValue appends a (key, write time) record to an insertion-ordered queue; records
/// older than LiveTime are purged from the front of that queue on the next PutValue.
/// GetValue additionally hides entries that are past LiveTime but not purged yet, so the
/// result no longer depends on whether some other key happened to be written recently.
/// Note: GetValue and PutValue each take the lock separately; a GetValue followed by a
/// PutValue is not one atomic operation.
template <class TKey, class TData>
class TCacheWithLiveTime {
private:
    // Eviction record: which key was written and when.
    class TTSKey {
    private:
        R_FIELD(TKey, Key);
        R_FIELD(TInstant, TS);

    public:
        TTSKey(const TKey& key, const TInstant ts)
            : Key(key)
            , TS(ts)
        {
        }
    };

    // Stored value together with its write time.
    class TTSData {
    private:
        R_FIELD(TData, Data);
        R_FIELD(TInstant, TS);

    public:
        TTSData(TData&& data, const TInstant ts)
            : Data(std::move(data))
            , TS(ts)
        {
        }
    };

private:
    TMutex Mutex;
    // Keyed by TKey so the template works for key types other than TString.
    TMap<TKey, TTSData> Cache;
    // Eviction queue in write order. It may hold several records for one key; only the
    // record whose timestamp still matches the cached entry actually erases it.
    std::deque<TTSKey> TSChecker;
    TDuration LiveTime = TDuration::Minutes(1);

    // True when something written at `written` has outlived LiveTime at `now`.
    bool IsExpired(const TInstant written, const TInstant now) const {
        return (now - written) > LiveTime;
    }

public:
    explicit TCacheWithLiveTime(const TDuration liveTime)
        : LiveTime(liveTime)
    {
    }

    /// Returns a copy of the value stored for `key`, or an empty TMaybe when the key is
    /// absent or its entry is older than LiveTime.
    TMaybe<TData> GetValue(const TKey& key) {
        const TInstant current = Now();
        TGuard<TMutex> g(Mutex);
        auto it = Cache.find(key);
        if (it == Cache.end() || IsExpired(it->second.GetTS(), current)) {
            return TMaybe<TData>();
        }
        return it->second.GetData();
    }

    /// Stores `data` for `key` (replacing any previous value) stamped with Now(), after
    /// purging entries that have outlived LiveTime.
    void PutValue(const TKey& key, TData&& data) {
        TGuard<TMutex> g(Mutex);
        const TInstant current = Now();
        while (!TSChecker.empty() && IsExpired(TSChecker.front().GetTS(), current)) {
            const TTSKey& oldest = TSChecker.front();
            auto it = Cache.find(oldest.GetKey());
            // A newer write for the same key has its own, later record: keep that entry.
            if (it != Cache.end() && it->second.GetTS() == oldest.GetTS()) {
                Cache.erase(it);
            }
            TSChecker.pop_front();
        }
        Cache.insert_or_assign(key, TTSData(std::move(data), current));
        TSChecker.emplace_back(key, current);
    }
};

/// Registry of TCacheWithLiveTime instances keyed by a string id. Caches are created lazily
/// and never erased. Every cache uses the pool's LiveTime.
template <class TKey, class TData>
class TCachesPool {
private:
    using TCache = TCacheWithLiveTime<TKey, TData>;

private:
    TMap<TString, TCache> Caches;
    TRWMutex RWMutex;
    TDuration LiveTime = TDuration::Minutes(1);

public:
    explicit TCachesPool(const TDuration liveTime)
        : LiveTime(liveTime)
    {
    }

    /// Returns the cache registered under `cacheId`, creating it on first use.
    /// The reference stays valid after the lock is released: this class never erases
    /// from Caches, and inserting into a std::map does not invalidate references.
    TCache& GetCache(const TString& cacheId) {
        {
            TReadGuard rg(RWMutex);
            auto it = Caches.find(cacheId);
            if (it != Caches.end()) {
                return it->second;
            }
        }
        TWriteGuard wg(RWMutex);
        // Another thread may have inserted the cache between the two locks; try_emplace
        // then returns the existing one without constructing (and discarding) a new cache.
        return Caches.try_emplace(cacheId, LiveTime).first->second;
    }
};

namespace {
    // Process-wide per-user rate-limit state used by TAuthRequestProcessor::DoProcess:
    // one cache per handler name, each mapping a user id to the request start time of
    // that user's last accepted request.
    // Entries live for 10 seconds, so an "rps_limit" below 0.1 (i.e. a required interval
    // longer than 10 s) cannot be enforced reliably with this pool.
    TCachesPool<TString, TInstant> CachesPool(TDuration::Seconds(10));
}

/// Authenticates the request, applies the optional per-user "rps_limit" throttle for this
/// handler, then delegates to DoAuthProcess with the validated auth info.
void TAuthRequestProcessor::DoProcess(TJsonReport::TGuard& g) {
    // Step 1: restore credentials and validate them; CheckAuthInfo rejects (R_ENSURE)
    // missing or unavailable auth info.
    IAuthInfo::TPtr authInfo;
    {
        TEventsGuard restoreGuard(g.MutableReport(), "RestoreAuthInfo");
        authInfo = Auth->RestoreAuthInfo(Context);
        CheckAuthInfo(g, authInfo);
    }

    // Step 2: per-user throttling; a non-positive "rps_limit" disables it.
    // NOTE(review): CachesPool keeps entries for only 10 seconds, so limits below 0.1 rps
    // are not reliably enforced. GetValue and PutValue also lock separately, so concurrent
    // requests from one user can all pass the check.
    const double rpsLimit = GetHandlerSettingDef<double>("rps_limit", 0);
    if (rpsLimit > 0) {
        const double minIntervalSeconds = 1 / rpsLimit;
        auto& perUserLastRequest = CachesPool.GetCache(HandlerName);
        const auto& userId = authInfo->GetUserId();
        if (const TMaybe<TInstant> previous = perUserLastRequest.GetValue(userId)) {
            const auto sincePrevious = Context->GetRequestStartTime() - *previous;
            R_ENSURE(sincePrevious.MicroSeconds() > minIntervalSeconds * 1000000, ConfigHttpStatus.UserErrorState, "rps limit exceeded", NDrive::MakeError("rps_limit_exceeded"));
        }
        // Only requests that passed the check are recorded.
        perUserLastRequest.PutValue(userId, Context->GetRequestStartTime());
    }

    // Step 3: handler-specific processing, traced under its own events guard.
    TEventsGuard processGuard(g.MutableReport(), "TAuthRequestProcessor::DoAuthProcess");
    DoAuthProcess(g, authInfo);
}

/// Builds an auth-aware processor.
/// @param config  processor configuration, also kept as AuthConfig.
/// @param context reply context forwarded to the base processor.
/// @param auth    auth module; taken by value and moved into Auth, avoiding an extra
///                smart-pointer copy.
/// @param server  owning server, forwarded to the base processor.
TAuthRequestProcessor::TAuthRequestProcessor(const IAuthRequestProcessorConfig& config, IReplyContext::TPtr context, IAuthModule::TPtr auth, const IServerBase* server)
    : IRequestProcessor(config, context, server)
    , Auth(std::move(auth))
    , AuthConfig(config)
{
    // AuthConfig is not read in this file; silence the unused-member warning.
    Y_UNUSED(AuthConfig);
}

/// Name of the auth module this processor uses (the "AuthModuleName" directive read in
/// DoInit), returned as an independent copy.
TString IAuthRequestProcessorConfig::GetAuthModuleName() const {
    TString moduleName = AuthModuleName;
    return moduleName;
}

/// Looks up the configured auth module on `server`, constructs it, and hands it to
/// DoConstructAuthProcessor. Throws yexception if the module description is missing or
/// the module cannot be constructed.
IRequestProcessor::TPtr IAuthRequestProcessorConfig::DoConstructProcessor(IReplyContext::TPtr context, const IServerBase* server) const {
    const TString moduleName = GetAuthModuleName();

    auto moduleInfo = server->GetAuthModuleInfo(moduleName);
    Y_ENSURE_EX(moduleInfo, yexception() << "Can't find module for " << Name << "/" << moduleName);

    auto builtModule = moduleInfo->ConstructAuthModule(server);
    Y_ENSURE_EX(builtModule, yexception() << "Can't construct module for " << Name << "/" << moduleName);

    return DoConstructAuthProcessor(context, std::move(builtModule), server);
}

/// Reads the "AuthModuleName" directive. The current AuthModuleName is passed as the
/// default value (presumably used when the directive is absent). The resulting name
/// must be non-empty.
void IAuthRequestProcessorConfig::DoInit(const TYandexConfig::Section* section) {
    const auto& directives = section->GetDirectives();
    AuthModuleName = directives.Value("AuthModuleName", AuthModuleName);
    AssertCorrectConfig(!AuthModuleName.empty(), "Incorrect 'AuthModuleName' field in configuration '%s'", Name.data());
}

/// Writes the base processor settings followed by the "AuthModuleName: ..." line.
void IAuthRequestProcessorConfig::ToString(IOutputStream& os) const {
    IRequestProcessorConfig::ToString(os);
    os << "AuthModuleName: ";
    os << AuthModuleName;
    os << Endl;
}
