#include "staff_fetcher.h"

#include "runtime_context.h"

#include <passport/infra/daemons/tvmapi/src/exception.h>
#include <passport/infra/daemons/tvmapi/src/utils/utils.h>

#include <passport/infra/libs/cpp/dbpool/result.h>
#include <passport/infra/libs/cpp/dbpool/util.h>
#include <passport/infra/libs/cpp/json/reader.h>
#include <passport/infra/libs/cpp/unistat/builder.h>
#include <passport/infra/libs/cpp/utils/file.h>
#include <passport/infra/libs/cpp/utils/string/coder.h>

#include <contrib/libs/openssl/include/openssl/err.h>

#include <library/cpp/http/simple/http_client.h>

#include <util/string/cast.h>

#include <algorithm>
#include <atomic>

namespace NPassport::NTvm {
    // Common prefix for the JSON-structure errors raised by Y_ENSURE checks in this file.
    static const TString CANT_PARSE_JSON("StaffFetcher: can't parse json from staff-api: ");

    // Initial load order: file cache (only when a path is configured), then the Staff HTTP API,
    // then an empty cache. std::exception failures of the first two sources are logged and
    // swallowed, so after construction Cache_ always holds a non-null cache.
    TStaffFetcher::TStaffFetcher(const TRuntimeContext& runtime)
        : Runtime_(runtime)
        , UnistatResponseTime_("responsetime.staff", NUnistat::TTimeStat::CreateBoundsFromMaxValue(Runtime_.Config().Staff.QueryTimeout))
    {
        const bool useFileCache = !Runtime_.Config().Staff.CachePath.empty();
        if (useFileCache) {
            try {
                ReadCacheFromFile();
                TLog::Info() << "StaffFetcher: loaded from file";
            } catch (const std::exception& e) {
                TLog::Error() << "StaffFetcher: failed to get info from file: " << e.what();
            }
        }

        // Nothing usable from the file: ask Staff directly.
        const bool needHttp = !Cache_.Get();
        if (needHttp) {
            try {
                RefreshViaHttp();
                TLog::Info() << "StaffFetcher: loaded from HTTP";
            } catch (const std::exception& e) {
                TLog::Error() << "StaffFetcher: failed to get info from Staff: " << e.what();
            }
        }

        // Last resort: publish an empty cache so lookups never dereference null.
        const bool stillEmpty = !Cache_.Get();
        if (stillEmpty) {
            Cache_.Set(std::make_shared<TCache>());
        }
    }

    TStaffFetcher::~TStaffFetcher() = default;

    void TStaffFetcher::AddUnistat(NUnistat::TBuilder& builder) const {
        builder.Add(UnistatQueryErrors_);
        builder.Add(UnistatParsingErrors_);
        UnistatResponseTime_.AddUnistat(builder);
    }

    // Resolves a Staff login to its uid using the current cache snapshot; 0 if the login is unknown.
    ui64 TStaffFetcher::LoginToUid(const TString& login) const {
        const TCachePtr snapshot = Cache_.Get();

        if (const auto found = snapshot->LoginToUid.find(login); found != snapshot->LoginToUid.end()) {
            return found->second;
        }
        return 0;
    }

    // Resolves a uid to its Staff login using the current cache snapshot; empty string if unknown.
    TString TStaffFetcher::UidToLogin(const ui64 uid) const {
        const TCachePtr snapshot = Cache_.Get();

        if (const auto found = snapshot->UidToLogin.find(uid); found != snapshot->UidToLogin.end()) {
            return found->second;
        }
        return {};
    }

    // Verifies `sign` over `rawString` against every cached ssh key of `uid`.
    // On success the result carries the fingerprint of the matching key. Otherwise it carries
    // an error describing why no key matched.
    TStaffFetcher::TCheckResult TStaffFetcher::CheckSign(ui64 uid, const TStringBuf sign, const TStringBuf rawString) const {
        const TCachePtr snapshot = Cache_.Get();

        const auto entry = snapshot->ByUid.find(uid);
        const bool hasKeys = entry != snapshot->ByUid.end() && !entry->second.empty();
        if (!hasKeys) {
            return {
                .Err = NUtils::CreateStr("No one ssh key found for uid ", uid),
            };
        }
        const TCache::TKeys& keys = entry->second;

        // A non-empty agent signature selects the ssh-agent checker; otherwise `sign` is checked
        // as-is. NOTE(review): presumably GetSshAgentSign() returns "" for non-agent signatures.
        const TString agentSign = NUtils::TSshPublicKey::GetSshAgentSign(sign);
        const bool fromAgent = !agentSign.empty();
        const TStringBuf checker = fromAgent ? "sshagent" : "as-is";

        // The verdict comes only from the actual key check: Staff may omit fingerprints,
        // so they cannot be relied upon to decide anything.
        std::optional<TString> fingerprint;
        if (fromAgent) {
            fingerprint = CheckSignSshAgent(keys, agentSign, rawString);
        } else {
            fingerprint = CheckSignAsIs(keys, sign, rawString);
        }

        if (!fingerprint) {
            TString err = NUtils::CreateStrExt(
                16 * keys.size(),
                "None of the ", keys.size(),
                " sshkeys fit with '", checker, "' checker.",
                " fingerprints: ");
            for (const TKeyInfo& key : keys) {
                NUtils::Append(err, "'", key.Fingerprint, "',");
            }
            // Drop the trailing ','; `keys` is non-empty here (checked above).
            err.pop_back();

            return {
                .Err = std::move(err),
            };
        }

        TLog::Debug() << "Sign check succeed for uid " << uid
                      << ". Keys count: " << keys.size()
                      << ". checker: '" << checker
                      << "'. fingerprint: '" << *fingerprint << "'";
        return {
            .Fingerprint = *fingerprint,
        };
    }

    void TStaffFetcher::Run() {
        RefreshViaHttp();
    }

    // Tries each key first in RSA mode and then in RSA_PSS mode. Returns the fingerprint of
    // the first key that verifies, or nullopt. Accumulated verification errors are logged
    // only when no key matched.
    std::optional<TString> TStaffFetcher::CheckSignAsIs(const TCache::TKeys& keys, const TStringBuf sign, const TStringBuf rawString) {
        TString verifyErr;
        for (const TKeyInfo& candidate : keys) {
            const auto& pub = candidate.PublicKey;
            const bool verified =
                pub.Verify(sign, rawString, NUtils::TSshPublicKey::EMode::RSA, verifyErr) ||
                pub.Verify(sign, rawString, NUtils::TSshPublicKey::EMode::RSA_PSS, verifyErr);
            if (verified) {
                return candidate.Fingerprint;
            }
        }

        if (!verifyErr.empty()) {
            TLog::Error() << "Error on sshkey: " << verifyErr;
        }
        return std::nullopt;
    }

    // Checks an ssh-agent signature against each key in RSA mode only. Returns the fingerprint
    // of the first key that verifies, or nullopt. Errors are logged only when no key matched.
    std::optional<TString> TStaffFetcher::CheckSignSshAgent(const TStaffFetcher::TCache::TKeys& keys, const TStringBuf sign, const TStringBuf rawString) {
        TString verifyErr;
        for (const TKeyInfo& candidate : keys) {
            const bool verified = candidate.PublicKey.Verify(sign, rawString, NUtils::TSshPublicKey::EMode::RSA, verifyErr);
            if (verified) {
                return candidate.Fingerprint;
            }
        }

        if (!verifyErr.empty()) {
            TLog::Error() << "Error on sshkey: " << verifyErr;
        }
        return std::nullopt;
    }

    // Downloads all persons' ssh keys from Staff page by page (keyset pagination on `id`).
    // It builds a fresh cache, publishes it with Cache_.Set() and, when a cache path is
    // configured, persists the raw page bodies for ReadCacheFromFile().
    // Throws on any fetch or parse failure. Cache_ is only updated after every page has been
    // parsed, so a failed refresh leaves the previously published cache in place.
    void TStaffFetcher::RefreshViaHttp() {
        TLog::Info() << "StaffFetcher: start refresh from staff";
        if (Token_.empty()) {
            Token_ = "OAuth " + Runtime_.GetOAuthToken();
        }

        TKeepAliveHttpClient client(Runtime_.Config().Staff.Host,
                                    Runtime_.Config().Staff.Port,
                                    Runtime_.Config().Staff.QueryTimeout,
                                    Runtime_.Config().Staff.ConnectionTimeout);

        // Capacity hints only. The page size is clamped to 1 so that a zero limit in the
        // config cannot cause a division by zero here.
        const size_t expectedLoginCount = 100000;
        const size_t pageLimit = std::max<size_t>(Runtime_.Config().Staff.Limit, 1);

        // Raw bodies of all non-empty pages, in fetch order; written to the file cache below.
        std::vector<TString> vec;
        vec.reserve(expectedLoginCount / pageLimit + 1);

        // Built off to the side: readers keep using the old cache until Cache_.Set() below.
        TCachePtr cache = std::make_shared<TCache>();
        cache->ByUid.reserve(Runtime_.Config().Staff.Limit);
        cache->LoginToUid.reserve(Runtime_.Config().Staff.Limit);
        cache->UidToLogin.reserve(Runtime_.Config().Staff.Limit);

        size_t totalKeys = 0;
        size_t lastEntityId = 0;

        while (true) {
            TResp resp = GetPage(client, lastEntityId);

            // NOTE(review): presumably bumps UnistatParsingErrors_ when it goes out of scope
            // unless err.Err is reset to nullptr, which both success paths below do.
            TErrorIncrementer err{&UnistatParsingErrors_};

            rapidjson::Document doc;
            Y_ENSURE(NJson::TReader::DocumentAsObject(resp.Body, doc),
                     CANT_PARSE_JSON << " http response is not object. " << resp.Body);

            try {
                const TParsePageResult result = ParsePage(doc, *cache);

                if (result.LoginCount == 0) {
                    TLog::Debug() << "StaffFetcher: Got empty list of keys from Staff"
                                  << ". Took: " << resp.Time
                                  << ". Stop fetching";
                    err.Err = nullptr;
                    break;
                }

                // Every page is requested with "id > lastEntityId", so a non-empty page must
                // move the cursor forward. If it does not (the filter was ignored, or the id was
                // truncated by GetPage()'s ui32 parameter), the same page would be requested
                // forever while `vec` keeps growing.
                Y_ENSURE(result.LastId > lastEntityId,
                         "StaffFetcher: pagination did not advance: last id " << result.LastId
                             << " is not greater than " << lastEntityId);

                TLog::Debug() << "StaffFetcher: Got " << result.KeyCount << " keys, " << result.LoginCount << " logins from Staff"
                              << ". Took: " << resp.Time
                              << ". Continue fetching";

                totalKeys += result.KeyCount;
                lastEntityId = result.LastId;
            } catch (const std::exception& e) {
                TLog::Error() << "Exception on parsing staff response: " << e.what() << ".body: " << resp.Body;
                throw;
            }

            vec.push_back(std::move(resp.Body));
            err.Err = nullptr;
        }

        TLog::Info() << "StaffFetcher: fetched " << totalKeys << " keys, "
                     << cache->LoginToUid.size() << " logins";
        Cache_.Set(std::move(cache));

        // The file cache is optional: the constructor only reads it when a path is configured,
        // so do not try to write it without one.
        if (!Runtime_.Config().Staff.CachePath.empty()) {
            TUtils::WriteJsonArrayToFile(vec, Runtime_.Config().Staff.CachePath);
        }
    }

    // URL-encoded prefix of the Staff `_query` filter; the cursor id is appended right after it.
    static const TString STAFF_QUERY = NUtils::Urlencode("id > ");

    // Fetches one page (at most Staff.Limit entries) of non-dismissed persons with
    // id > lastEntityId, sorted by id.
    // The number of retries is added to UnistatQueryErrors_. Throws yexception when
    // FetchWithRetries reports failure.
    // NOTE(review): lastEntityId is ui32 while RefreshViaHttp() tracks the cursor as size_t,
    // so larger ids would be silently truncated here. Widening it requires changing the
    // declaration in staff_fetcher.h.
    TStaffFetcher::TResp TStaffFetcher::GetPage(TKeepAliveHttpClient& client, ui32 lastEntityId) {
        const TString url = NUtils::CreateStr(
            "/v3/persons"
            "?_query=",
            STAFF_QUERY,
            lastEntityId,
            "&_limit=",
            Runtime_.Config().Staff.Limit,
            "&official.is_dismissed=false",
            "&_sort=id",
            "&_nopage=1",
            "&_fields=keys.key,uid,login,keys.fingerprint,id");

        TString output;
        size_t retries = 0;
        TDuration respTime;
        const bool success = TUtils::FetchWithRetries(
            client,
            url,
            UnistatResponseTime_,
            Runtime_.Config().Staff.Retries,
            Token_,
            output,
            retries,
            respTime);
        UnistatQueryErrors_ += retries;
        Y_ENSURE(success, "StaffFetcher: Failed to fetch keys from Staff");

        // `output` is a local that is not used afterwards: move the body instead of copying it.
        return {std::move(output), respTime};
    }

    // Parses one Staff API page ({"result": [person, ...]}) and merges it into `cache`.
    // Each person contributes:
    //   - `id`: the maximum id is returned as LastId (RefreshViaHttp() uses it as the cursor);
    //   - `uid`: a decimal string;
    //   - `login`: optional;
    //   - `keys`: an array of {key, fingerprint}.
    // Keys that fail to parse are skipped. Structural problems throw via Y_ENSURE, and a
    // non-numeric uid throws from IntFromString.
    // Returns the number of persons processed, keys accepted and the maximum id seen.
    TStaffFetcher::TParsePageResult TStaffFetcher::ParsePage(rapidjson::Value& doc, TStaffFetcher::TCache& cache) {
        const rapidjson::Value* jUsersArray = nullptr;
        Y_ENSURE(NJson::TReader::MemberAsArray(doc, "result", jUsersArray),
                 CANT_PARSE_JSON << " page: result.");

        TParsePageResult res;
        for (rapidjson::SizeType userIdx = 0; userIdx < jUsersArray->Size(); ++userIdx) {
            const rapidjson::Value& jUserObj = (*jUsersArray)[userIdx];

            ui64 id = 0;
            Y_ENSURE(NJson::TReader::MemberAsUInt64(jUserObj, "id", id),
                     CANT_PARSE_JSON << " page: result/id.");
            res.LastId = std::max(res.LastId, id);

            const rapidjson::Value* jKeyArray = nullptr;
            Y_ENSURE(NJson::TReader::MemberAsArray(jUserObj, "keys", jKeyArray),
                     CANT_PARSE_JSON << " page: result/keys.");

            TString uidStr;
            Y_ENSURE(NJson::TReader::MemberAsString(jUserObj, "uid", uidStr),
                     CANT_PARSE_JSON << " page: result/uid.");

            const ui64 uid = IntFromString<ui64, 10>(uidStr);
            TString loginStr;
            if (NJson::TReader::MemberAsString(jUserObj, "login", loginStr)) {
                cache.UidToLogin.emplace(uid, loginStr);
                cache.LoginToUid.emplace(std::move(loginStr), uid);
            }

            // emplace() keeps an existing entry, so keys of a repeated uid are appended to it.
            TCache::TKeys& keys = cache.ByUid.emplace(uid, TCache::TKeys()).first->second;

            for (rapidjson::SizeType keyIdx = 0; keyIdx < jKeyArray->Size(); ++keyIdx) {
                const rapidjson::Value& jKeyObj = (*jKeyArray)[keyIdx];

                TStringBuf keyStr;
                Y_ENSURE(NJson::TReader::MemberAsString(jKeyObj, "key", keyStr),
                         CANT_PARSE_JSON << " page: result/keys/key.");

                TString fingerprint;
                NJson::TReader::MemberAsString(jKeyObj, "fingerprint", fingerprint); // TODO Y_ENSURE

                try {
                    NUtils::TSshPublicKey key(keyStr);
                    keys.push_back(TKeyInfo{std::move(key), std::move(fingerprint)});
                    ++res.KeyCount;
                } catch (const NUtils::TSshPublicKey::TUnsupportedException&) {
                    // Unsupported key types are expected: skip them without logging.
                } catch (const NUtils::TSshPublicKey::TMalformedException& e) {
                    TLog::Debug() << "Ssh malformed exception: " << e.what() << ". for uid " << uid;
                } catch (const std::exception& e) {
                    TLog::Warning() << "Ssh exception: " << e.what() << ". for uid " << uid;
                }
            }

            // Counted per person, whether or not a login or any usable key was present.
            ++res.LoginCount;
        }

        return res;
    }

    void TStaffFetcher::ReadCacheFromFile() {
        rapidjson::Document doc;
        if (!NJson::TReader::DocumentAsArray(NUtils::ReadFile(Runtime_.Config().Staff.CachePath), doc)) {
            throw yexception() << "StaffFetcher[ERROR]: failed to read file cache";
        }

        TCachePtr cache = std::make_shared<TCache>();
        cache->ByUid.reserve(10000); // 3500 in production now
        cache->LoginToUid.reserve(10000);
        cache->UidToLogin.reserve(10000);
        size_t totalKeys = 0;
        for (rapidjson::SizeType idx = 0; idx < doc.Size(); ++idx) {
            try {
                totalKeys += ParsePage(doc[idx], *cache).KeyCount;
            } catch (const std::exception& e) {
                TLog::Error() << "Exception on parsing staff cache: " << e.what();
                throw;
            }
        }
        TLog::Info() << "StaffFetcher: fetched " << totalKeys << " keys, "
                     << cache->LoginToUid.size() << " logins."
                     << " From " << Runtime_.Config().Staff.CachePath;

        Cache_.Set(std::move(cache));
    }

}
