#include "tvm_grants.h"

#include <passport/infra/daemons/blackbox/src/ip/ipacl_map.h>
#include <passport/infra/daemons/blackbox/src/misc/exception.h>

#include <passport/infra/libs/cpp/json/reader.h>
#include <passport/infra/libs/cpp/utils/file.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <regex>

namespace NPassport::NBb {
    // Matches any single character outside [a-zA-Z0-9._-]; a regex_search hit means a
    // "dbfield:" grant name contains a forbidden character.
    static const std::regex ILLEGAL_DBFIELD_REGEX = std::regex("[^a-zA-Z._0-9-]");

    /// Remembers the grants file path and starts a file loader that re-reads it every `period`.
    /// Each file body the loader delivers is handed to Update().
    TTvmGrants::TTvmGrants(const TString& path, TDuration period)
        : Path_(path)
    {
        // The callback captures `this`; the loader that stores it is owned by this object (FileLoader_).
        auto onFileRead = [this](const TStringBuf body, time_t modified) {
            Update(body, modified);
        };
        FileLoader_ = std::make_unique<NUtils::TFileLoader>(path, std::move(onFileRead), period);
    }

    // Defaulted out of line rather than in the header. NOTE(review): presumably so that the
    // FileLoader_ smart pointer (holding NUtils::TFileLoader) is destroyed in a translation unit
    // where TFileLoader is a complete type — confirm against tvm_grants.h.
    TTvmGrants::~TTvmGrants() = default;

    /// Looks up the consumer registered for a TVM client id and checks the caller's IP against
    /// that consumer's networks.
    /// @param clientId  TVM client id of the caller
    /// @param ip        caller's IP address
    /// @return {Consumer, Err}. Err is empty on success.
    ///         - unknown clientId (or no grants snapshot stored yet): Consumer is empty, Err is set;
    ///         - known clientId but IP not allowed: Consumer is filled AND Err is set, so callers can
    ///           still attribute the rejection (logs, metrics) to the consumer.
    TTvmGrants::TFindResult TTvmGrants::FindConsumer(const TClientId clientId, const NUtils::TIpAddr& ip) const {
        const std::shared_ptr<TStorage> consMap = ConsumerMap_.Get();

        // GetMTime() treats an empty snapshot as a possible state; never dereference it blindly.
        if (!consMap) {
            return {
                .Consumer = {},
                .Err = NUtils::CreateStr("tvm_id=", clientId, " is not allowed"),
            };
        }

        const auto consIt = consMap->Map.find(clientId);
        if (consIt == consMap->Map.end()) {
            return {
                .Consumer = {},
                .Err = NUtils::CreateStr("tvm_id=", clientId, " is not allowed"),
            };
        }

        TString err;
        if (!consIt->second.IpMap->Find(ip)) {
            // do not throw exception: we detected consumer, so we can write name and tvm_id to logs and yasm
            err = NUtils::CreateStr("IP=", ip.ToString(), " is not allowed for tvm_id=", clientId);
        }
        return {consIt->second.Consumer, err};
    }

    /// Modification time recorded in the current grants snapshot, or 0 when there is no snapshot.
    time_t TTvmGrants::GetMTime() const {
        if (const std::shared_ptr<TStorage> snapshot = ConsumerMap_.Get()) {
            return snapshot->Mtime;
        }
        return 0;
    }

    /// Number of Update() calls that ended with an exception (the counter is bumped in Update()'s catch).
    ui64 TTvmGrants::GetParsingErrors() const {
        const auto failures = ParsingErrors_.GetValue();
        return failures;
    }

    std::shared_ptr<TTvmGrants::TStorage> TTvmGrants::Parse(const TStringBuf fileBody, const time_t mtime) {
        rapidjson::Document doc;
        if (!NJson::TReader::DocumentAsObject(fileBody, doc)) {
            throw yexception() << "Tvm grants: broken json: " << fileBody;
        }

        std::shared_ptr<TStorage> grants = std::make_shared<TStorage>();
        grants->Mtime = mtime;
        grants->Map.reserve(doc.MemberCount());
        for (auto consumerIt = doc.MemberBegin(); consumerIt != doc.MemberEnd(); ++consumerIt) {
            TString consumerName(consumerIt->name.GetString(), consumerIt->name.GetStringLength());

            rapidjson::Value& consumerJson = consumerIt->value;
            if (!consumerJson.IsObject()) {
                TLog::Error("Tvm grants: Failed to load grants for consumer (object): %s", consumerName.c_str());
                continue;
            }

            const rapidjson::Value* clientJson = nullptr;
            if (!NJson::TReader::MemberAsObject(consumerJson, "client", clientJson)) {
                TLog::Error("Tvm grants: Failed to load grants for consumer ('client'): %s", consumerName.c_str());
                continue;
            }

            ui32 clientId = 0;
            if (!NJson::TReader::MemberAsUInt(*clientJson, "client_id", clientId)) {
                TLog::Error("Tvm grants: Failed to load grants for consumer ('client_id'): %s", consumerName.c_str());
                continue;
            }

            const rapidjson::Value* grantsJson = nullptr;
            if (!NJson::TReader::MemberAsObject(consumerJson, "grants", grantsJson)) {
                TLog::Error("Tvm grants: Failed to load grants for consumer ('grants'): %s", consumerName.c_str());
                continue;
            }

            std::shared_ptr<TConsumer> consumer = ParseConsumer(*grantsJson, consumerName, clientId);
            std::shared_ptr<TIpAclMap<TConsumer>> ipacl = ParseIp(consumer, consumerJson);
            if (!ipacl) {
                TLog::Error("Tvm grants: Failed to create ipacl for consumer: %s", consumerName.c_str());
                continue;
            }

            grants->Map.emplace(clientId, TConsumerInfo{std::move(ipacl), std::move(consumer)});
        }

        return grants;
    }

    void TTvmGrants::Update(const TStringBuf fileBody, const time_t mtime) {
        try {
            ConsumerMap_.Set(Parse(fileBody, mtime));
            TLog::Info("Tvm grants: refresh() succed: %s", Path_.c_str());
        } catch (const std::exception&) {
            ++ParsingErrors_;
            throw;
        }
    }

    /// Builds a TConsumer for one TVM client from its "grants" JSON object.
    /// Only the keys listed below are read; any other key in `obj` is ignored.
    /// A flag-style grant (read via GetBool) is on only when its value is an array holding
    /// exactly one non-empty string. Non-string entries in list-style grants are logged and skipped.
    /// @param obj       the consumer's "grants" object
    /// @param name      consumer name, stored on the consumer via SetName
    /// @param clientId  TVM client id passed to the TConsumer constructor
    /// @return          a newly allocated consumer (make_shared, never null)
    std::shared_ptr<TConsumer> TTvmGrants::ParseConsumer(const rapidjson::Value& obj,
                                                         const TString& name,
                                                         NTvmAuth::TTvmId clientId) {
        std::shared_ptr<TConsumer> consumer = std::make_shared<TConsumer>(clientId);
        consumer->SetName(name);

        // Add cache options
        // The mere presence of an array "__allow_cache_for_<method>" marks the method as cachable;
        // the array contents are not inspected.
        for (const TString& method : {"userinfo", "oauth", "sessionid", "hosted_domains"}) {
            const rapidjson::Value* arr = nullptr;
            if (NJson::TReader::MemberAsArray(obj, ("__allow_cache_for_" + method).c_str(), arr)) {
                consumer->AddCachableMethod(method);
                TLog::Info() << "Tvm grants: enable cache for method=" << method
                             << ": " << consumer->GetName();
            }
        }

        // login/password checking policy
        consumer->SetCaptchaCap(GetBool(obj, "can_captcha"));
        consumer->SetDelayCap(GetBool(obj, "can_delay"));
        consumer->SetAllowPinTest(GetBool(obj, "allow_pin_test"));

        // `arr` is reused below for every list-style grant; it is read only after a successful MemberAsArray.
        const rapidjson::Value* arr = nullptr;
        // "allow_login" present as an array enables the Login method. If it holds a single string,
        // "weak" selects login_Weak and any other non-empty string selects login_Strict; otherwise
        // SetLoginSafety is not called at all.
        if (NJson::TReader::MemberAsArray(obj, "allow_login", arr)) {
            consumer->SetAllow(TBlackboxMethods::Login, true);
            TString allowLogin = GetString(*arr);
            if (!allowLogin.empty()) {
                consumer->SetLoginSafety(allowLogin == "weak" ? TConsumer::login_Weak : TConsumer::login_Strict);
            }
        }

        // methods
        consumer->SetAllow(TBlackboxMethods::Cookie, GetBool(obj, "allow_parse_cookie"));
        consumer->SetAllow(TBlackboxMethods::OAuth, GetBool(obj, "allow_oauth"));
        consumer->SetAllow(TBlackboxMethods::UserInfo, GetBool(obj, "allow_user_info"));
        consumer->SetAllow(TBlackboxMethods::CheckIp, GetBool(obj, "allow_check_ip"));
        consumer->SetAllow(TBlackboxMethods::LoginOccupation, GetBool(obj, "allow_login_occupation"));
        consumer->SetAllow(TBlackboxMethods::PwdHistory, GetBool(obj, "allow_pwd_history"));
        consumer->SetAllow(TBlackboxMethods::CreateSession, GetBool(obj, "allow_create_session"));
        consumer->SetAllow(TBlackboxMethods::HostedDomains, GetBool(obj, "allow_hosted_domains"));
        consumer->SetAllow(TBlackboxMethods::FindPddAccounts, GetBool(obj, "allow_find_pdd_accounts"));
        consumer->SetAllow(TBlackboxMethods::LCookie, GetBool(obj, "allow_l_cookie"));
        consumer->SetAllow(TBlackboxMethods::PhoneBindings, GetBool(obj, "allow_phone_bindings"));
        consumer->SetAllow(TBlackboxMethods::PhoneOperations, GetBool(obj, "allow_phone_operations"));
        consumer->SetAllow(TBlackboxMethods::TestPwdHashes, GetBool(obj, "allow_test_pwd_hashes"));
        consumer->SetAllow(TBlackboxMethods::GetTrack, GetBool(obj, "allow_get_track"));
        consumer->SetAllow(TBlackboxMethods::ProveKeyDiag, GetBool(obj, "allow_prove_key_diag"));
        consumer->SetAllow(TBlackboxMethods::EditTotp, GetBool(obj, "allow_edit_totp"));
        consumer->SetAllow(TBlackboxMethods::EmailBindings, GetBool(obj, "allow_email_bindings"));
        consumer->SetAllow(TBlackboxMethods::GetAllTracks, GetBool(obj, "allow_get_all_tracks"));
        consumer->SetAllow(TBlackboxMethods::YakeyBackup, GetBool(obj, "allow_yakey_backup"));
        consumer->SetAllow(TBlackboxMethods::CreatePwdHash, GetBool(obj, "allow_create_pwd_hash"));
        consumer->SetAllow(TBlackboxMethods::DeletionOperations, GetBool(obj, "allow_deletion_operations"));
        consumer->SetAllow(TBlackboxMethods::CreateOAuthToken, GetBool(obj, "allow_create_oauth_token"));
        consumer->SetAllow(TBlackboxMethods::GetRecoveryKeys, GetBool(obj, "allow_get_recovery_keys"));
        consumer->SetAllow(TBlackboxMethods::CheckRfcTotp, GetBool(obj, "allow_check_rfc_totp"));
        consumer->SetAllow(TBlackboxMethods::UserTicket, GetBool(obj, "allow_user_ticket"));
        consumer->SetAllow(TBlackboxMethods::CheckHasPlus, GetBool(obj, "allow_check_has_plus"));
        consumer->SetAllow(TBlackboxMethods::GetDevicePublicKey, GetBool(obj, "allow_get_device_public_key"));
        consumer->SetAllow(TBlackboxMethods::FamilyInfo, GetBool(obj, "allow_family_info"));
        consumer->SetAllow(TBlackboxMethods::FindByPhoneNumbers, GetBool(obj, "allow_find_by_phone_numbers"));
        consumer->SetAllow(TBlackboxMethods::GeneratePublicId, GetBool(obj, "allow_generate_public_id"));
        consumer->SetAllow(TBlackboxMethods::GetMaxUid, GetBool(obj, "allow_get_max_uid"));
        consumer->SetAllow(TBlackboxMethods::WebauthnCredentials, GetBool(obj, "allow_method_webauthn_credentials"));
        consumer->SetAllow(TBlackboxMethods::GetOAuthTokens, GetBool(obj, "allow_get_oauth_tokens"));

        // flags
        consumer->SetAllow(TBlackboxFlags::LoginByUid, GetBool(obj, "allow_login_by_uid"));
        consumer->SetAllow(TBlackboxFlags::OAuthAttributes, GetBool(obj, "allow_oauth_attributes"));
        consumer->SetAllow(TBlackboxFlags::ResignCookie, GetBool(obj, "allow_resign_cookie"));
        consumer->SetAllow(TBlackboxFlags::CreateStressSession, GetBool(obj, "allow_create_stress_session"));
        consumer->SetAllow(TBlackboxFlags::FindByPhoneAlias, GetBool(obj, "allow_find_by_phone_alias"));
        consumer->SetAllow(TBlackboxFlags::GetHiddenAliases, GetBool(obj, "allow_to_get_hidden_aliases"));
        consumer->SetAllow(TBlackboxFlags::GetPublicName, GetBool(obj, "allow_get_public_name"));
        consumer->SetAllow(TBlackboxFlags::ResignForDomains, GetBool(obj, "allow_resign_for_domains"));
        consumer->SetAllow(TBlackboxFlags::CreateGuard, GetBool(obj, "allow_create_guard"));
        consumer->SetAllow(TBlackboxFlags::GetBillingFeatures, GetBool(obj, "allow_get_billing_features"));
        consumer->SetAllow(TBlackboxFlags::CheckDeviceSignatureWithPublicKey, GetBool(obj, "allow_check_device_signature_with_public_key"));
        consumer->SetAllow(TBlackboxFlags::CheckDeviceSignatureWithDebugMode, GetBool(obj, "allow_check_device_signature_with_debug_mode"));
        consumer->SetAllow(TBlackboxFlags::FamilyInfoPlace, GetBool(obj, "allow_family_info_get_place"));
        consumer->SetAllow(TBlackboxFlags::ForceShowMailSubscription, GetBool(obj, "allow_force_show_mail_subscription"));
        consumer->SetAllow(TBlackboxFlags::GetWebauthnCredentials, GetBool(obj, "allow_get_webauthn_credentials"));
        consumer->SetAllow(TBlackboxFlags::ScholarLogin, GetBool(obj, "allow_scholar_login"));
        consumer->SetAllow(TBlackboxFlags::ScholarSession, GetBool(obj, "allow_scholar_session"));
        consumer->SetAllow(TBlackboxFlags::FederalAccounts, GetBool(obj, "allow_federal_accounts"));
        consumer->SetAllow(TBlackboxFlags::FullInfo, GetBool(obj, "allow_fullinfo"));

        // Allowed DB fields, attributes, phone attributes
        ParseAttrs(*consumer, obj, "no_cred", TConsumer::ERank::NoCred);
        ParseAttrs(*consumer, obj, "has_cred", TConsumer::ERank::HasCred);

        // Allowed sign/checksign signspaces
        // For each of the three signature grants below, the presence of the array alone enables the
        // method (even if the array is empty); every string element is registered as a signspace.
        if (NJson::TReader::MemberAsArray(obj, "allow_sign", arr)) {
            consumer->SetAllow(TBlackboxMethods::Sign, true);
            for (auto it = arr->Begin(); it != arr->End(); ++it) {
                if (!it->IsString()) {
                    TLog::Error("Tvm grants: sign_space name can't be not string: %s.",
                                consumer->GetName().c_str());
                    continue;
                }
                consumer->AddSignSignspace(TString(it->GetString(), it->GetStringLength()));
            }
        }

        if (NJson::TReader::MemberAsArray(obj, "allow_check_sign", arr)) {
            consumer->SetAllow(TBlackboxMethods::CheckSign, true);
            for (auto it = arr->Begin(); it != arr->End(); ++it) {
                if (!it->IsString()) {
                    TLog::Error("Tvm grants: sign_space name can't be not string: %s.",
                                consumer->GetName().c_str());
                    continue;
                }
                consumer->AddChecksignSignspace(TString(it->GetString(), it->GetStringLength()));
            }
        }

        if (NJson::TReader::MemberAsArray(obj, "allow_check_device_signature", arr)) {
            consumer->SetAllow(TBlackboxMethods::CheckDeviceSignature, true);
            for (auto it = arr->Begin(); it != arr->End(); ++it) {
                if (!it->IsString()) {
                    TLog::Error("Tvm grants: allow_check_device_signature name can't be not string: %s.",
                                consumer->GetName().c_str());
                    continue;
                }
                consumer->AddCheckDeviceSignatureSignspace(TString(it->GetString(), it->GetStringLength()));
            }
        }

        // Unlike the grants above, "allowed_partitions" does not toggle any method; it only lists partitions.
        if (NJson::TReader::MemberAsArray(obj, "allowed_partitions", arr)) {
            for (auto it = arr->Begin(); it != arr->End(); ++it) {
                if (!it->IsString()) {
                    TLog::Error("Tvm grants: allowed_partitions name should be string: %s.",
                                consumer->GetName().c_str());
                    continue;
                }
                consumer->AddPartition(TString(it->GetString(), it->GetStringLength()));
            }
        }

        return consumer;
    }

    /// Reads a flag-style grant: true only when obj[key] is an array with exactly one non-empty string.
    /// A missing key, a non-array value, a different element count or a non-string element all yield false.
    bool TTvmGrants::GetBool(const rapidjson::Value& obj, const TString& key) {
        const rapidjson::Value* values = nullptr;
        const bool isArray = NJson::TReader::MemberAsArray(obj, key.c_str(), values);
        if (isArray && values->Size() == 1) {
            const rapidjson::Value& single = *values->Begin();
            return single.IsString() && single.GetStringLength() > 0;
        }
        return false;
    }

    /// Returns the only element of a one-element array when that element is a string; otherwise
    /// returns an empty string. Expects `arr` to be a JSON array (Size() is called on it directly).
    TString TTvmGrants::GetString(const rapidjson::Value& arr) {
        if (arr.Size() == 1) {
            const rapidjson::Value& only = *arr.Begin();
            if (only.IsString()) {
                return TString(only.GetString(), only.GetStringLength());
            }
        }
        return {};
    }

    /// Reads attribute grants from the array obj[section] and registers them on `consumer` with `rank`.
    /// Each entry is a string "<category>:<name>", where category is one of:
    ///   attr       -> consumer.AddAttr(name, rank)
    ///   dbfield    -> consumer.AddField(name, rank); name may contain only [a-zA-Z0-9._-]
    ///   phone_attr -> consumer.AddPhoneAttr(name, rank)
    /// A missing section is silently accepted. A non-array section, non-string entries, entries without ':',
    /// unknown categories and invalid dbfield names are logged and skipped.
    void TTvmGrants::ParseAttrs(TConsumer& consumer, const rapidjson::Value& obj, const char* section, TConsumer::ERank rank) {
        if (!obj.HasMember(section)) {
            return;
        }

        const rapidjson::Value* arr = nullptr;
        if (!NJson::TReader::MemberAsArray(obj, section, arr)) {
            TLog::Error("Tvm grants: attributes are not array: %s",
                        consumer.GetName().c_str());
            return;
        }

        for (auto it = arr->Begin(); it != arr->End(); ++it) {
            if (!it->IsString()) {
                TLog::Error("Tvm grants: attribute grant can't be not string: %s. Section: %s",
                            consumer.GetName().c_str(),
                            section);
                continue;
            }

            TString value(it->GetString(), it->GetStringLength());
            const auto colPos = value.find(':');
            if (colPos == TString::npos) {
                LogAttrError(consumer, section, value, "invalid attribute grant string");
                continue;
            }

            // After a category matches, `value` is trimmed to the part after the first ':'.
            if (value.compare(0, colPos, "attr") == 0) {
                value.erase(0, colPos + 1);
                consumer.AddAttr(value, rank);
            } else if (value.compare(0, colPos, "dbfield") == 0) {
                value.erase(0, colPos + 1);
                if (std::regex_search(value.cbegin(), value.cend(), ILLEGAL_DBFIELD_REGEX)) {
                    // The category is valid here; it is the field name that contains forbidden characters.
                    LogAttrError(consumer, section, value, "invalid dbfield name");
                    continue;
                }
                consumer.AddField(value, rank);
            } else if (value.compare(0, colPos, "phone_attr") == 0) {
                value.erase(0, colPos + 1);
                consumer.AddPhoneAttr(value, rank);
            } else {
                LogAttrError(consumer, section, value, "unknown attribute grant category");
            }
        }
    }

    void TTvmGrants::LogAttrError(const TConsumer& consumer, const char* section, const TString& value, const char* msg) {
        TLog::Error("Tvm grants: %s : %s. Value: %s. Section: %s",
                    msg,
                    consumer.GetName().c_str(),
                    value.c_str(),
                    section);
    }

    /// Builds the IP ACL for one consumer from the "networks" array of its JSON entry.
    /// @param consumer  consumer the ACL entries point to; also used in log messages
    /// @param doc       the consumer's JSON object
    /// @return          the ACL, which may be empty if every entry was rejected, or an empty pointer when
    ///                  "networks" cannot be read as an array. Bad entries are logged and skipped.
    std::shared_ptr<TIpAclMap<TConsumer>> TTvmGrants::ParseIp(std::shared_ptr<TConsumer> consumer, rapidjson::Value& doc) {
        const rapidjson::Value* networks = nullptr;
        if (!NJson::TReader::MemberAsArray(doc, "networks", networks)) {
            TLog::Error("Tvm grants: networks is not array: %s", consumer->GetName().c_str());
            return {};
        }

        std::shared_ptr<TIpAclMap<TConsumer>> acl = std::make_shared<TIpAclMap<TConsumer>>();
        // Now loop through IP addresses and/or ranges to which this grants set belongs
        for (auto ipit = networks->Begin(); ipit != networks->End(); ++ipit) {
            // rapidjson string accessors require a string value: on anything else they trip
            // RAPIDJSON_ASSERT, or read the wrong union member when asserts are compiled out.
            if (!ipit->IsString()) {
                TLog::Error("Tvm grants: ip is not a string; ignoring this range. %s", consumer->GetName().c_str());
                continue;
            }

            if (ipit->GetStringLength() == 0) {
                TLog::Error("Tvm grants: ip is empty; ignoring this range. %s", consumer->GetName().c_str());
                continue;
            }

            try {
                TStringBuf val(ipit->GetString(), ipit->GetStringLength());
                acl->ParseAndAddEntry(val, consumer);
            } catch (const std::exception& e) {
                TLog::Error("Tvm grants: error: %s in consumer %s; ignoring this range", e.what(), consumer->GetName().c_str());
            } catch (...) {
                TLog::Error("Tvm grants: unknown error in consumer %s; ignoring this range", consumer->GetName().c_str());
            }
        }

        return acl;
    }
}
