#include "tvm_service.h"

#include <library/cpp/logger/global/global.h>

#include <util/folder/path.h>

namespace {
    // Prefix put in front of log messages written by TTvmService methods in this file.
    const TString LogPrefix = "TvmService: ";

    // Logger passed to the TVM client: every message is forwarded to the global log,
    // tagged with "TvmInternal: ", with the numeric level cast to the log priority.
    class TTvmLogger: public NTvmAuth::ILogger {
    private:
        void Log(int lvl, const TString& msg) override {
            const auto priority = static_cast<ELogPriority>(lvl);
            auto&& globalLog = TLoggerOperator<TGlobalLog>::Log();
            globalLog << priority << "TvmInternal: " << msg << Endl;
        }
    };

    // Builds a TVM client from the service config.
    // Returns an empty pointer when TVM is disabled in the config.
    // Throws (via Y_ENSURE) when destination aliases are configured without a TVM secret.
    std::unique_ptr<NTvmAuth::TTvmClient> BuildTvmClient(const NTravelProto::NAppConfig::TTvmServiceConfig& config) {
        using TSettings = NTvmAuth::NTvmApi::TClientSettings;

        if (!config.GetEnabled()) {
            return nullptr;
        }

        // Create the cache directory before the client settings are pointed at it.
        const auto& cacheDir = config.GetCacheDir();
        TFsPath(cacheDir).MkDirs();

        TSettings clientSettings;
        clientSettings.SetSelfTvmId(config.GetSelfClientId());
        clientSettings.SetDiskCacheDir(cacheDir);

        // Service ticket checking is switched on only when at least one allowed source is configured.
        const bool hasAllowedSources = config.AllowedSourceClientIdSize() != 0;
        if (hasAllowedSources) {
            clientSettings.EnableServiceTicketChecking();
        }

        // Service ticket fetching is configured only when destination aliases are present.
        const bool hasDestinations = config.DestinationClientIdAliasesSize() != 0;
        if (hasDestinations) {
            Y_ENSURE(config.HasTvmSecret(), "TvmSecret is required if DestinationClientIdAliases is set");
            TSettings::TDstMap destinations;
            for (const auto& entry: config.GetDestinationClientIdAliases()) {
                destinations.emplace(entry.GetName(), entry.GetClientId());
            }
            clientSettings.EnableServiceTicketsFetchOptions(config.GetTvmSecret(), std::move(destinations));
        }

        return std::make_unique<NTvmAuth::TTvmClient>(clientSettings, MakeIntrusive<TTvmLogger>());
    }
}

namespace NTravel::NTvm {
    // Constructs the service from config: stores the enabled flag, builds the TVM client
    // (BuildTvmClient returns an empty pointer when TVM is disabled), copies the allowed
    // source client ids and hands TvmClient to Counters.
    // NOTE(review): Counters is initialized from TvmClient, so TvmClient presumably has to be
    // declared before Counters in the class — confirm against the header.
    TTvmService::TTvmService(const NTravelProto::NAppConfig::TTvmServiceConfig& config)
        : Enabled(config.GetEnabled())
        , TvmClient(BuildTvmClient(config))
        , AllowedSources(config.GetAllowedSourceClientId().begin(), config.GetAllowedSourceClientId().end())
        , Counters(TvmClient)
    {
    }

    // Registers this service's counters in the given source under the name "TvmService".
    // Nothing is registered when TVM is disabled.
    void TTvmService::RegisterCounters(NMonitor::TCounterSource& counters) const {
        if (Enabled) {
            counters.RegisterSource(&Counters, "TvmService");
        }
    }

    // Checks an incoming service ticket and updates the request/allowed/refused counters.
    // When TVM is disabled every ticket is accepted and no counters are touched.
    bool TTvmService::IsAllowedServiceTicket(const TStringBuf& ticket) const {
        if (!Enabled) {
            return true;
        }

        Counters.NRequests.Inc();

        const bool allowed = IsAllowedServiceTicketImpl(ticket);
        if (allowed) {
            Counters.NAllowedTickets.Inc();
        } else {
            Counters.NRefusedTickets.Inc();
        }
        return allowed;
    }

    // Checks the ticket with the TVM client, then checks its source id against AllowedSources.
    // Either failure is logged as a warning and yields false.
    bool TTvmService::IsAllowedServiceTicketImpl(const TStringBuf& ticket) const {
        auto checked = TvmClient->CheckServiceTicket(ticket);

        const bool isValid = static_cast<bool>(checked);
        if (!isValid) {
            WARNING_LOG << LogPrefix << "ServiceTicket is not allowed with status: " << checked.GetStatus() << Endl;
            return false;
        }

        // The source id is read only after the ticket has passed the check above.
        const auto sourceId = checked.GetSrc();
        if (AllowedSources.contains(sourceId)) {
            return true;
        }

        WARNING_LOG << LogPrefix << "ServiceTicket is not allowed by source client id: " << sourceId << Endl;
        return false;
    }

    // Returns a service ticket for the destination identified by dstAlias.
    // Throws if TVM is disabled; if the client fails, the failure is counted and logged
    // and the original exception is rethrown.
    TString TTvmService::GetServiceTicketFor(const TStringBuf& dstAlias) const {
        using TAlias = NTvmAuth::NTvmApi::TClientSettings::TAlias;

        Y_ENSURE(Enabled, "TvmService is not enabled in config");
        try {
            const TAlias alias(dstAlias);
            auto ticket = TvmClient->GetServiceTicketFor(alias);
            Counters.NFetchedServiceTicket.Inc();
            return ticket;
        } catch (...) {
            Counters.NFailedToFetchServiceTicket.Inc();
            ERROR_LOG << LogPrefix << "Failed to get service ticket for " << dstAlias << ": " << CurrentExceptionMessage() << Endl;
            throw;
        }
    }

    // Returns the Enabled flag, which the constructor takes from config.GetEnabled().
    bool TTvmService::IsEnabled() const {
        return Enabled;
    }

    // Initializes the TvmClient member from the owner's client pointer; QueryCounters checks it
    // for emptiness before use.
    // NOTE(review): std::unique_ptr is not copyable, so the TvmClient member is presumably a
    // reference to the owner's pointer and the owner must outlive this object — confirm in the header.
    TTvmService::TCounters::TCounters(const std::unique_ptr<NTvmAuth::TTvmClient>& tvmClient)
        : TvmClient(tvmClient)
    {
    }

    // Refreshes the status and public-key gauges from the TVM client and inserts them,
    // together with the request and ticket counters, into the given counter table.
    // Leaves the table untouched when there is no TVM client.
    void TTvmService::TCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
        if (!TvmClient) {
            return;
        }

        // Reset all status flags; exactly one of them is raised by the switch below.
        StatusOk = 0;
        StatusWarning = 0;
        StatusError = 0;
        StatusUnknown = 0;

        // Take one status snapshot and use it both for the switch and for the log message,
        // so the reported flag and the logged status cannot disagree.
        const auto status = TvmClient->GetStatus();
        switch (status.GetCode()) {
            case NTvmAuth::TClientStatus::Ok:
                StatusOk = 1;
                break;
            case NTvmAuth::TClientStatus::Warning:
                StatusWarning = 1;
                break;
            case NTvmAuth::TClientStatus::Error:
                StatusError = 1;
                break;
            default:
                ERROR_LOG << "Unknown tvm updater status: " << status << Endl;
                StatusUnknown = 1;
                break;
        }

        // Both key gauges are computed against the same "now".
        auto now = Now();
        PublicKeysAgeSec = (now - TvmClient->GetUpdateTimeOfPublicKeys()).Seconds();
        PublicKeysTimeToInvalidationSec = (TvmClient->GetInvalidationTimeOfPublicKeys() - now).Seconds();

        ct->insert(MAKE_COUNTER_PAIR(StatusOk));
        ct->insert(MAKE_COUNTER_PAIR(StatusWarning));
        ct->insert(MAKE_COUNTER_PAIR(StatusError));
        ct->insert(MAKE_COUNTER_PAIR(StatusUnknown));

        ct->insert(MAKE_COUNTER_PAIR(PublicKeysAgeSec));
        ct->insert(MAKE_COUNTER_PAIR(PublicKeysTimeToInvalidationSec));

        ct->insert(MAKE_COUNTER_PAIR(NRequests));
        ct->insert(MAKE_COUNTER_PAIR(NAllowedTickets));
        ct->insert(MAKE_COUNTER_PAIR(NRefusedTickets));

        ct->insert(MAKE_COUNTER_PAIR(NFetchedServiceTicket));
        ct->insert(MAKE_COUNTER_PAIR(NFailedToFetchServiceTicket));
    }

}
