#include "account_checker.h"

#include <drive/backend/abstract/notifier.h>

#include <drive/library/cpp/threading/future.h>

// Creates the Blackbox client used by the checker and configures which dbfields it requests.
//
// The endpoint is "https://<host>:443" followed by the configured route (a leading '/' is
// inserted if the route lacks one). When a self TVM id is configured, the matching TVM client
// is taken from `server`; construction fails (Y_ENSURE) if the server cannot provide it.
// Every request asks for "account_info.reg_date.uid" plus one "subscription.suid.<suid>" field
// per configured high- and low-priority SUID.
TAccountChecker::TAccountChecker(const TAccountCheckerConfig& config, const NDrive::IServer* server)
    : Config(config)
    , Server(server)
{
    const auto& route = config.GetRoute();
    TString endpointUrl = "https://" + config.GetHost() + ":443";
    if (!route.StartsWith("/")) {
        endpointUrl += "/";
    }
    endpointUrl += route;

    TAtomicSharedPtr<NTvmAuth::TTvmClient> tvmClient;
    if (auto tvmId = config.GetSelfTvmId()) {
        tvmClient = server->GetTvmClient(tvmId);
        Y_ENSURE(tvmClient, "cannot get TVM client " << tvmId);
    }
    BlackboxClient.Reset(new NDrive::TBlackboxClient(endpointUrl, tvmClient));

    TString dbfields = "account_info.reg_date.uid";
    // Both SUID lists contribute fields in the same format; high-priority ones come first.
    const auto appendSuidFields = [&dbfields](const auto& suids) {
        for (auto&& suid : suids) {
            dbfields += ",subscription.suid." + suid;
        }
    };
    appendSuidFields(Config.GetHighPrioritySuids());
    appendSuidFields(Config.GetLowPrioritySuids());
    BlackboxClient->SetExternalOptions(NBlackbox2::TOptions(NBlackbox2::TOption("dbfields", dbfields)));
}

bool TAccountChecker::IsPassed(const ui64 passportUid, const TSet<TString>& countries, const TString& notifierName) const {
    auto scoringLevel = GetScoringLevel(countries);

    NUtil::THttpReply reply;
    bool isAccountOldEnough = false;
    bool isEnoughSuids = false;

    try {
        auto infoResponseFuture = BlackboxClient->UidInfoRequest(ToString(passportUid), "8.8.8.8");
        infoResponseFuture.Wait();
        if (!infoResponseFuture.Initialized() || !infoResponseFuture.HasValue()) {
            if (notifierName) {
                NDrive::INotifier::Notify(Server->GetNotifier(notifierName), "empty blackbox response for userinfo on uid " + ToString(passportUid));
            }
            ERROR_LOG << "empty blackbox response for userinfo on uid " << passportUid << Endl;
            return scoringLevel == EScoringLevel::NoScoring;
        }
        if (infoResponseFuture.HasException()) {
            auto excMessage = NThreading::GetExceptionMessage(infoResponseFuture);
            if (notifierName) {
                NDrive::INotifier::Notify(Server->GetNotifier(notifierName), "exception on blackbox quiery for userinfo on uid " + ToString(passportUid) + " " + excMessage);
            }
            ERROR_LOG << "exception on blackbox quiery for userinfo on uid " << passportUid << " " << excMessage << Endl;
            return scoringLevel == EScoringLevel::NoScoring;
        }
        const auto& infoResponse = infoResponseFuture.GetValue();
        if (!infoResponse) {
            if (notifierName) {
                NDrive::INotifier::Notify(Server->GetNotifier(notifierName), "bad blackbox reply code for userinfo on uid " + ToString(passportUid));
            }
            ERROR_LOG << "bad blackbox reply code for userinfo on uid " << passportUid << Endl;
            return scoringLevel == EScoringLevel::NoScoring;
        }

        NBlackbox2::TDBFields dbFields(infoResponse.Get());
        isAccountOldEnough = IsOldEnough(passportUid, dbFields);
        isEnoughSuids = IsEnoughSuids(dbFields);
    } catch (const std::exception& e) {
        NDrive::INotifier::Notify(Server->GetNotifier(notifierName), "generic exception: " + FormatExc(e));
        return scoringLevel == EScoringLevel::NoScoring;
    }
    if (scoringLevel == EScoringLevel::NoScoring) {
        return true;
    }

    if (scoringLevel == EScoringLevel::Full) {
        return isAccountOldEnough & isEnoughSuids;
    }
    return isAccountOldEnough | isEnoughSuids;
}

// Maps a set of country codes to the scoring level to apply.
// Countries are upper-cased (UTF-8) before being looked up in the configured sets.
// Returns Full as soon as any country is a full-scoring country. Otherwise it returns Weak if
// at least one country is not a no-scoring country, and NoScoring when every country is a
// no-scoring country. An empty set also yields NoScoring.
TAccountChecker::EScoringLevel TAccountChecker::GetScoringLevel(const TSet<TString>& countries) const {
    const auto& fullScoringCountries = Config.GetFullScoringCountries();
    const auto& noScoringCountries = Config.GetNoScoringCountries();

    bool hasScoredCountry = false;
    for (const auto& country : countries) {
        const auto upperCountry = ToUpperUTF8(country);
        if (fullScoringCountries.contains(upperCountry)) {
            return EScoringLevel::Full;
        }
        if (!noScoringCountries.contains(upperCountry)) {
            hasScoredCountry = true;
        }
    }
    return hasScoredCountry ? EScoringLevel::Weak : EScoringLevel::NoScoring;
}

// Returns true when the weighted number of the account's subscriptions reaches the configured
// minimum. Each configured high-priority SUID whose "subscription.suid.<suid>" field equals "1"
// weighs 2; each such low-priority SUID weighs 1.
bool TAccountChecker::IsEnoughSuids(const NBlackbox2::TDBFields& dbFields) const {
    const auto countActive = [&dbFields](const auto& suids) {
        ui32 active = 0;
        for (auto&& suid : suids) {
            if (dbFields.Get("subscription.suid." + suid) == "1") {
                ++active;
            }
        }
        return active;
    };
    const ui32 score = 2 * countActive(Config.GetHighPrioritySuids()) + countActive(Config.GetLowPrioritySuids());
    return score >= Config.GetMinSuidWeight();
}

// Parses the account registration date from the "account_info.reg_date.uid" dbfield.
// On success, caches it for `passportUid` and returns true if at least the configured minimum
// account age has passed since registration. Returns false if the field is too short or does
// not parse.
bool TAccountChecker::IsOldEnough(const ui64 passportUid, const NBlackbox2::TDBFields& dbFields) const {
    auto isoDate = dbFields.Get("account_info.reg_date.uid");
    // Character 10 separates the date from the time. It is rewritten to 'T' and a 'Z' suffix is
    // added to form an ISO-8601 string (presumably "YYYY-MM-DD HH:MM:SS" on input).
    // NOTE(review): 'Z' treats the value as UTC — confirm Blackbox's time zone.
    constexpr size_t dateTimeSeparatorPos = 10;
    if (isoDate.size() <= dateTimeSeparatorPos) {
        return false;
    }
    isoDate[dateTimeSeparatorPos] = 'T';
    isoDate += 'Z';

    TInstant registrationDate;
    if (!TInstant::TryParseIso8601(isoDate, registrationDate)) {
        return false;
    }
    CachedCreationDate[passportUid] = registrationDate;
    return Now() > registrationDate + Config.GetMinAccountAge();
}

// Returns the registration date last cached by IsOldEnough for `passportUid`, or
// TInstant::Zero() if none has been cached.
TInstant TAccountChecker::GetCachedCreationDate(const ui64 passportUid) const {
    const auto cached = CachedCreationDate.find(passportUid);
    return cached != CachedCreationDate.end() ? cached->second : TInstant::Zero();
}
