#include "multi_fetcher.h"

#include "config.h"
#include "token_info.h"

#include <passport/infra/daemons/blackbox/src/misc/db_types.h>
#include <passport/infra/daemons/blackbox/src/misc/exception.h>
#include <passport/infra/daemons/blackbox/src/misc/shards_map.h>

namespace NPassport::NBb {
    /// Constructs the multi-shard OAuth fetcher; the configuration is handed
    /// unchanged to the TOAuthBaseFetcher base class.
    TOAuthMultiFetcher::TOAuthMultiFetcher(const TOAuthConfig& config)
        : TOAuthBaseFetcher(config) {
    }

    /// Fires the token-attributes query at every OAuth shard without waiting for replies.
    ///
    /// @param uid       user id; passed to BuildShardQuery() as the p.uid filter
    /// @param clientId  optional p.client_id filter
    /// @param deviceId  optional p.device_id filter (an empty value adds no filter)
    /// @return one in-flight handle per shard, in shard iteration order; feed them to WaitTokenAttrs()
    TOAuthMultiFetcher::TDbHandles TOAuthMultiFetcher::SendRequestForTokenAttrs(const TStringBuf uid,
                                                                                const std::optional<TClientId> clientId,
                                                                                const TStringBuf deviceId) const {
        TDbHandles handles;
        handles.reserve(Config_.OauthShards_->GetShardCount());

        // Every shard receives the same SQL text, so it is composed only once (escaping
        // via the first shard's handle) and then reused for the remaining shards.
        TString shardQuery;
        for (NDbPool::TDbPool& pool : Config_.OauthShards_->GetShards()) {
            NDbPool::TNonBlockingHandle handle(pool);

            if (shardQuery.empty()) {
                shardQuery = BuildShardQuery(uid, clientId, deviceId, handle);
            }
            handle.SendQuery(shardQuery);

            handles.push_back(std::move(handle));
        }

        return handles;
    }

    /// Waits for the per-shard replies started by SendRequestForTokenAttrs(), merges all
    /// rows into one token map and post-processes it.
    ///
    /// @param handles  in-flight shard handles, consumed in order
    /// @return tokens after FinishProcessingForTokenAttributes()
    TOAuthMultiFetcher::TTokens TOAuthMultiFetcher::WaitTokenAttrs(TDbHandles handles) const {
        TTokens merged;

        for (NDbPool::TNonBlockingHandle& shardHandle : handles) {
            for (const NDbPool::TRow& row : shardHandle.WaitResult()->ExctractTable()) {
                ReadTokenAttrRow(row, merged);
            }

            // Reset the handle as soon as its rows are consumed (presumably to give the
            // shard connection back before waiting on the next shard).
            shardHandle = NDbPool::TNonBlockingHandle();
        }

        return FinishProcessingForTokenAttributes(std::move(merged));
    }

    /// Enriches already-fetched tokens with the attributes of their OAuth clients, which
    /// are read by a single blocking query to the central OAuth database.
    ///
    /// Tokens whose client is referenced in the index but missing from the response get
    /// TOAuthError::ClientNotFound in FinishProcessingForClientAttributes().
    ///
    /// @param tokens  tokens produced by WaitTokenAttrs()
    /// @return the same tokens with client attributes applied
    TOAuthMultiFetcher::TTokens TOAuthMultiFetcher::FetchClientAttrs(TTokens tokens) const {
        TClientIdIndex idx = BuildIndex(tokens);

        // When no token carries a usable client_id (including the "no tokens at all" case),
        // BuildCentralQuery() would emit "... WHERE id IN()", which is a SQL syntax error.
        // There is nothing to look up, so skip the round-trip and only finish processing.
        if (idx.empty()) {
            return FinishProcessingForClientAttributes(std::move(tokens), idx);
        }

        std::unique_ptr<NDbPool::TResult> result =
            NDbPool::TBlockingHandle(*Config_.OauthCentral_)
                .Query(BuildCentralQuery(idx));

        for (const NDbPool::TRow& row : result->ExctractTable()) {
            const TClientId id = ReadClientAttrRow(row, idx, tokens);
            // ReadClientAttrRow() has already ensured that `id` is present in idx.
            idx[id].WasFound = true;
        }

        return FinishProcessingForClientAttributes(std::move(tokens), idx);
    }

    /// Composes the SELECT that reads (id, type, value) token attributes of one user from an
    /// OAuth shard, optionally narrowed to a client_id and/or a device_id.
    ///
    /// @param handle  used only to escape deviceId before it is embedded into the SQL text
    /// @return the complete SQL query
    TString TOAuthMultiFetcher::BuildShardQuery(const TStringBuf uid,
                                                const std::optional<TClientId> clientId,
                                                const TStringBuf deviceId,
                                                NDbPool::TNonBlockingHandle& handle) {
        // NOTE(review): uid is embedded neither escaped nor quoted; this is only safe if every
        // caller guarantees it is a validated numeric id — confirm at the call sites.
        TString query = NUtils::CreateStr(
            "SELECT a.id, a.type, a.value "
            "FROM token_attributes AS a, token_by_params AS p "
            "WHERE a.id=p.id AND p.uid=",
            uid);

        if (clientId.has_value()) {
            NUtils::Append(query, " AND p.client_id=", *clientId);
        }

        if (!deviceId.empty()) {
            // device_id is a free-form string: escape it and wrap it in quotes.
            NUtils::Append(query, " AND p.device_id='", handle.EscapeQueryParam(deviceId), "'");
        }

        return query;
    }

    /// Groups token ids by the numeric client id stored in each token's CLIENT_ID attribute.
    ///
    /// Tokens with an empty CLIENT_ID are skipped with a warning. A CLIENT_ID that does not
    /// parse as a base-10 ui32 aborts processing with TBlackboxError.
    ///
    /// @return map from client id to the ids of tokens issued for it (WasFound left at its default)
    TOAuthMultiFetcher::TClientIdIndex TOAuthMultiFetcher::BuildIndex(const TTokens& tokens) {
        TClientIdIndex index;

        for (const auto& entry : tokens) {
            const auto& tokenId = entry.first;
            const TString& rawClientId = entry.second.Info->GetTokenAttr(TOAuthTokenAttr::CLIENT_ID);

            if (rawClientId.empty()) {
                TLog::Warning() << "OAuthMultiFetcher: there is no client_id for token_id=" << tokenId;
                continue;
            }

            ui32 parsedClientId = 0;
            if (!TryIntFromString<10>(rawClientId, parsedClientId)) {
                throw TBlackboxError(TBlackboxError::EType::Unknown)
                    << "Invalid client_id in token_id=" << tokenId << ": " << rawClientId;
            }

            // operator[] default-creates the group the first time a client id is seen.
            index[parsedClientId].Ids.push_back(tokenId);
        }

        return index;
    }

    /// Composes the query to the central OAuth database that reads the attributes of every
    /// client present in idx.
    ///
    /// Unless AllClientAttrs_ is set, only attribute types from DEFAULT_CLIENT_ATTRS and
    /// AddClientAttrs_ are requested.
    ///
    /// @param idx  client index from BuildIndex(); its keys form the IN(...) list
    /// @return the complete SQL query
    TString TOAuthMultiFetcher::BuildCentralQuery(const TClientIdIndex& idx) const {
        TString ids;
        for (const auto& entry : idx) {
            NUtils::Append(ids, ids.empty() ? "" : ",", entry.first);
        }

        // ProcessClientAttribute() expects column #0 to be `type` and column #1 to be `value`;
        // ReadClientAttrRow() takes the client id from column #2.
        TString query = NUtils::CreateStrExt(
            64,
            "SELECT type,value,id FROM client_attributes WHERE id IN(",
            ids,
            ")");

        if (!AllClientAttrs_) {
            TString types;
            for (const TString& attr : DEFAULT_CLIENT_ATTRS) {
                NUtils::AppendSeparated(types, ',', attr);
            }
            for (const TString& attr : AddClientAttrs_) {
                NUtils::AppendSeparated(types, ',', attr);
            }

            NUtils::Append(query, " AND type IN (", types, ")");
        }

        return query;
    }

    /// Merges one (id, type, value) token_attributes row into `out`. The token entry and its
    /// TOAuthTokenInfo are created the first time the row's id is seen.
    void TOAuthMultiFetcher::ReadTokenAttrRow(const NDbPool::TRow& row, TTokens& out) const {
        const auto [slot, inserted] = out.emplace(row[0].As<ui32>(), TToken{});

        auto& token = slot->second;
        if (inserted) {
            token.Info = std::make_unique<TOAuthTokenInfo>();
        }

        ProcessTokenAttribute(row, *token.Info);
    }

    /// Runs post-processing on every merged token and then CheckTokenInfo(), which receives
    /// the token's Error field.
    ///
    /// @return the processed tokens
    TOAuthMultiFetcher::TTokens TOAuthMultiFetcher::FinishProcessingForTokenAttributes(TTokens tokens) const {
        for (auto& entry : tokens) {
            auto& token = entry.second;
            TOAuthTokenInfo& info = *token.Info;

            PostProcessTokenAttributes(info);
            CheckTokenInfo(info, token.Error);
        }

        return tokens;
    }

    /// Applies one client_attributes row to every token that belongs to the row's client.
    ///
    /// @return the client id read from column #2 of the row
    /// Fails via Y_ENSURE if the client id was not part of the query, or if idx points at a
    /// token id that is missing from `out`.
    TOAuthMultiFetcher::TClientId TOAuthMultiFetcher::ReadClientAttrRow(const NDbPool::TRow& row,
                                                                        const TClientIdIndex& idx,
                                                                        TTokens& out) const {
        const TClientId clientId = row[2].As<TClientId>();

        const auto group = idx.find(clientId);
        Y_ENSURE(group != idx.end(),
                 "got client_id=" << clientId << " in db response, which was not in query");

        const std::vector<TTokenId>& tokenIds = group->second.Ids;
        for (const TTokenId tokenId : tokenIds) {
            const auto token = out.find(tokenId);
            Y_ENSURE(token != out.end(),
                     "idx is broken: token info was not found for tokid=" << tokenId);

            ProcessClientAttribute(row, *token->second.Info, token->second.Error);
        }

        return clientId;
    }

    /// Post-processes client attributes on every token. Then, for each client in idx that
    /// never appeared in the central DB response, marks all of its tokens with
    /// TOAuthError::ClientNotFound.
    ///
    /// @return the processed tokens
    TOAuthMultiFetcher::TTokens TOAuthMultiFetcher::FinishProcessingForClientAttributes(TTokens tokens,
                                                                                        const TClientIdIndex& idx) {
        for (auto& entry : tokens) {
            PostProcessClientAttributes(*entry.second.Info);
        }

        for (const auto& entry : idx) {
            const auto& group = entry.second;
            if (!group.WasFound) {
                for (const TTokenId tokenId : group.Ids) {
                    const auto token = tokens.find(tokenId);
                    Y_ENSURE(token != tokens.end(),
                             "idx is broken: token info was not found for tokid=" << tokenId);

                    token->second.Error.SetError(TOAuthError::ClientNotFound);
                }
            }
        }

        return tokens;
    }

    /// Deliberately does nothing: this fetcher does not perform an OAuth token check, so
    /// there is nothing to log.
    void TOAuthMultiFetcher::LogOAuth(const TOAuthTokenInfo&, const TString&, const TString&) const {
    }

    /// Memberwise equality: same found flag and same token ids in the same order.
    bool TOAuthMultiFetcher::TTokenIdsForClientId::operator==(const TTokenIdsForClientId& o) const {
        return WasFound == o.WasFound && Ids == o.Ids;
    }

}
