#include "dbfetcher.h"

#include "public_keys_builder.h"

#include <passport/infra/daemons/tvmapi/src/proto/disk_cache.pb.h>
#include <passport/infra/daemons/tvmapi/src/utils/utils.h>

#include <passport/infra/libs/cpp/dbpool/db_pool.h>
#include <passport/infra/libs/cpp/dbpool/result.h>
#include <passport/infra/libs/cpp/dbpool/util.h>
#include <passport/infra/libs/cpp/dbpool/value.h>
#include <passport/infra/libs/cpp/json/writer.h>
#include <passport/infra/libs/cpp/tvm/common/private_key.h>
#include <passport/infra/libs/cpp/unistat/builder.h>
#include <passport/infra/libs/cpp/utils/file.h>
#include <passport/infra/libs/cpp/utils/crypto/hash.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <util/stream/file.h>
#include <util/system/fs.h>

namespace NPassport::NTvm {
    // Verifies that this Rabin-Williams key pair is usable: both halves must be
    // non-empty and a probe payload signed by the private half must verify with
    // the public half. Any failure is re-thrown as one yexception that carries
    // the key id and the owning client id.
    void TCryptoKey::ThrowIfInvalid(NTvmAuth::TTvmId clientId) const {
        try {
            Y_ENSURE(PublicKey, "public key is empty");
            Y_ENSURE(PrivateKey, "private key is empty");

            NTvmAuth::NRw::TRwPublicKey verifier(PublicKey);
            NTvmAuth::NRw::TRwPrivateKey signer(PrivateKey, Id);

            const TStringBuf probe = "foobar";
            const auto signature = signer.SignTicket(probe);
            const bool matches = verifier.CheckSign(probe, signature);
            Y_ENSURE(matches, "public/private key mismatch");
        } catch (const std::exception& e) {
            ythrow yexception() << "DbFetcher: failed to check Rabin-Williams key: " << e.what()
                                << ". key_id=" << Id
                                << ". client_id=" << clientId;
        }
    }

    // Captures the decryptor, the shared error counter and the required per-client
    // key minimum, and pre-sizes both hash tables for a bulk load.
    TDbCacheBuilder::TDbCacheBuilder(const NTvmCommon::TDecryptor& decr, NUnistat::TSignalDiff<>& errors, size_t minKeyCount)
        : Decr_(decr)
        , MinKeyCount_(minKeyCount)
        , Errors_(errors)
    {
        // Capacity hints only; they do not limit how much can be loaded.
        constexpr size_t clientCountHint = 10000;
        constexpr size_t clientAttrCount = 13;
        constexpr size_t keyAttrCount = 2;
        constexpr size_t keysPerClient = 2;

        Clients_.reserve(clientAttrCount * clientCountHint);
        Keys_.reserve(keyAttrCount * clientCountHint * keysPerClient);
    }

    // Applies one (id, type, value) row from tvm_client_attributes to the client
    // with this id, creating the client on first sight. Parse/decrypt failures and
    // unknown attribute types are logged and counted in Errors_; they do not throw.
    void TDbCacheBuilder::AddClientAttr(TClientId id, int type, const TString& val) {
        auto clientIt = Clients_.find(id);
        if (clientIt == Clients_.end()) {
            clientIt = Clients_.emplace(id, std::make_shared<TClient>(id)).first;
        }
        TClient& client = *clientIt->second;

        // Attribute type codes handled here (see SELECT_CLIENT_ID).
        constexpr int attrClientName = 1;
        constexpr int attrClientSecret = 5;
        constexpr int attrKeyIds = 6;
        constexpr int attrDeletedTime = 18;
        constexpr int attrClientOldSecret = 19;
        constexpr int attrAbcServiceId = 21;

        try {
            switch (type) {
                case attrClientName:
                    FillClientName(client, val);
                    break;
                case attrClientSecret:
                    FillClientSecret(client, val, Decr_);
                    break;
                case attrKeyIds:
                    FillKeyIds(client, val);
                    break;
                case attrDeletedTime:
                    // Only the presence of the attribute matters; its value is ignored.
                    client.MarkDeleted();
                    break;
                case attrClientOldSecret:
                    FillClientOldSecret(client, val, Decr_);
                    break;
                case attrAbcServiceId:
                    FillAbcId(client, val);
                    break;
                default:
                    ythrow yexception() << "unexpected attribute";
            }
        } catch (const std::exception& e) {
            ++Errors_;
            TLog::Error() << "DbFetcher: Failed to fetch client attr"
                          << ". id=" << id
                          << ". type=" << type
                          << ". value='" << val << "'"
                          << ". error: " << e.what();
        }
    }

    // Applies one (id, type, value) row from tvm_secret_key_attributes to the key
    // entry with this id. Unknown attribute types are ignored. Any failure is logged,
    // counted in Errors_, and the id is remembered in DeletedKeyId_ so that
    // PostProcess() drops the whole key.
    void TDbCacheBuilder::AddKeyAttr(ui32 id, int type, const TString& val) {
        // Attribute type codes handled here (see SELECT_KEYS).
        constexpr int attrPublicKey = 1;
        constexpr int attrPrivateKey = 2;
        constexpr int attrTs = 3;
        constexpr int attrDeletedTime = 4;

        // Returns the entry for this key id, inserting an empty one if absent.
        auto entry = [this, id]() -> TCryptoKey& {
            return Keys_.emplace(id, TCryptoKey(id)).first->second;
        };

        try {
            if (type == attrPublicKey) {
                TString decoded = GetBinaryKey(val, id, Decr_);
                Y_ENSURE(decoded, "public key is empty");
                entry().PublicKey = std::move(decoded);
            } else if (type == attrPrivateKey) {
                TString decoded = GetBinaryKey(val, id, Decr_);
                Y_ENSURE(decoded, "private key is empty");
                entry().PrivateKey = std::move(decoded);
            } else if (type == attrTs) {
                time_t parsedTs = 0;
                Y_ENSURE(TryIntFromString<10>(val, parsedTs), "Invalid ts");
                entry().Ts = parsedTs;
            } else if (type == attrDeletedTime) {
                entry().IsDeleted = true;
            }
        } catch (const std::exception& e) {
            ++Errors_;
            DeletedKeyId_.insert(id);
            TLog::Error() << "DbFetcher: Failed to fetch key attr"
                          << ". id=" << id
                          << ". type=" << type
                          << ". value='" << val << "'"
                          << ". error: " << e.what();
        }
    }

    void TDbCacheBuilder::PostProcess() {
        for (ui32 id : DeletedKeyId_) {
            Keys_.erase(id);
        }

        TDbCache cache;
        cache.Clients = std::make_shared<TClients>(std::move(Clients_));

        for (auto& clIt : *cache.Clients) {
            TClient& cl = *clIt.second;

            for (TKeyId id : cl.KeysIds()) {
                try {
                    auto keyIt = Keys_.find(id);
                    if (keyIt == Keys_.end()) {
                        TLog::Error() << "DbFetcher: Key " << id << " is not found for client " << cl.Id();
                        continue;
                    }

                    const TCryptoKey& key = keyIt->second;
                    if (key.IsDeleted) {
                        continue;
                    }

                    key.ThrowIfInvalid(cl.Id());
                } catch (const std::exception& e) {
                    ++Errors_;
                    TLog::Error() << "Failed to load key " << id << " for client_id=" << cl.Id() << ": " << e.what();
                }
            }
        }

        Result_ = std::move(cache);
    }

    // Derives the remaining cache fields from the parsed keys, in this order:
    // TVM signing key, passport key blobs, secret -> client index.
    // Relies on PostProcess() having filled Result_.Clients; a throw from any step
    // skips the steps after it.
    void TDbCacheBuilder::Serialize(ui32 preferedIdx, const TConfig::TPassportIds& passpIds) {
        TDbCache& target = Result_;
        const TCryptoKeys& parsedKeys = Keys_;

        SetTvmPrivateKey(target, parsedKeys, passpIds.Tvm, preferedIdx);
        SerializePassportKeys(target, parsedKeys, passpIds, MinKeyCount_);
        BuildSecretIndex(target);
    }

    // Hands the built cache to the caller; Result_ is left in a moved-from state.
    TDbCache TDbCacheBuilder::Finalize() {
        TDbCache finished(std::move(Result_));
        return finished;
    }

    // Stores the client_name attribute value on the client verbatim.
    void TDbCacheBuilder::FillClientName(TClient& cl, const TString& str) {
        const TString& name = str;
        cl.SetName(name);
    }

    // Decrypts the client secret attribute, base64url-decodes it and stores the
    // binary secret on the client.
    // Throws (via Y_ENSURE) if decryption yields nothing or base64url decoding fails.
    // The decrypted secret is deliberately NOT included in the error message: the
    // caller logs the message, and the secret must never reach the logs.
    void TDbCacheBuilder::FillClientSecret(TClient& cl, const TString& str, const NTvmCommon::TDecryptor& decr) {
        TString secret = decr.DecryptAes(str, "secret for client: ", cl.Id());
        Y_ENSURE(secret, "failed to decrypt secret");

        TString decodedSecret = NUtils::Base64url2bin(secret);
        Y_ENSURE(decodedSecret, "bad base64 in secret: size=" << secret.size());

        cl.SetSecret(decodedSecret);
    }

    // Parses the '|'-separated list of key ids and appends each one to the client.
    // Empty tokens are skipped. Throws on a token that is not a decimal ui32 or is 0.
    void TDbCacheBuilder::FillKeyIds(TClient& cl, const TString& str) {
        auto addKeyId = [&cl](const TStringBuf token) -> void {
            if (!token) {
                return;
            }

            ui32 parsed = 0;
            const bool valid = TryIntFromString<10>(token, parsed) && parsed != 0;
            Y_ENSURE(valid, "invalid key id: '" << token << "'");
            cl.AddKeyId(parsed);
        };

        NUtils::Transform(str, '|', addKeyId);
    }

    // Decrypts the client's previous secret attribute, base64url-decodes it and
    // stores it on the client.
    // Throws (via Y_ENSURE) if decryption yields nothing or base64url decoding fails.
    // The decrypted secret is deliberately NOT included in the error message: the
    // caller logs the message, and the secret must never reach the logs.
    void TDbCacheBuilder::FillClientOldSecret(TClient& cl, const TString& str, const NTvmCommon::TDecryptor& decr) {
        TString secret = decr.DecryptAes(str, "old secret for client: ", cl.Id());
        Y_ENSURE(secret, "failed to decrypt old secret");

        TString decodedSecret = NUtils::Base64url2bin(secret);
        Y_ENSURE(decodedSecret, "bad base64 in old secret: size=" << secret.size());

        cl.SetOldSecret(decodedSecret);
    }

    // Parses the ABC service id attribute as a decimal i64 and stores it on the client.
    void TDbCacheBuilder::FillAbcId(TClient& cl, const TString& str) {
        i64 abcId = 0;
        const bool parsed = TryIntFromString<10>(str, abcId);
        Y_ENSURE(parsed, "abc service id is not int");
        cl.SetAbcId(abcId);
    }

    // Picks the TVM daemon's own signing key among the keys of client `tvmId`.
    // NTvmCommon::TPrivateKey::ChooseKey decides the candidate order (using
    // preferedIdx); each candidate is accepted only if it exists in `keys` and
    // passes ThrowIfInvalid, in which case it is stored in cache.TvmKey.
    // Throws if `tvmId` is not among the cached clients.
    void TDbCacheBuilder::SetTvmPrivateKey(TDbCache& cache,
                                           const TCryptoKeys& keys,
                                           TClientId tvmId,
                                           ui32 preferedIdx) {
        const auto tvmClient = cache.Clients->find(tvmId);
        Y_ENSURE(tvmClient != cache.Clients->end(), "Tvm id is not found in db: " << tvmId);

        auto tryUseKey = [&keys, &cache, tvmId, this](const TKeyId keyId) -> bool {
            const auto found = keys.find(keyId);
            if (found == keys.end()) {
                TLog::Error() << "Tvm key " << keyId << " is missing!";
                return false;
            }

            const TCryptoKey& candidate = found->second;
            try {
                candidate.ThrowIfInvalid(tvmId);
                cache.TvmKey = std::make_shared<NTvmAuth::NRw::TRwPrivateKey>(
                    candidate.PrivateKey,
                    candidate.Id);
                return true;
            } catch (const std::exception& e) {
                ++Errors_;
                TLog::Error() << "Tvm key " << keyId << " is malformed: " << e.what();
                return false;
            }
        };

        NTvmCommon::TPrivateKey::ChooseKey(
            std::vector<TKeyId>(tvmClient->second->KeysIds()),
            preferedIdx,
            tryUseKey);
    }

    // Builds the serialized key material published to consumers:
    //  - cache.PublicTvmKeys: public TVM keys plus public keys of every BB environment;
    //  - cache.PrivateBbKeys: serialized private keys per BB client_id (Mimino and Load,
    //    when configured, share the keys of BbProd and BbStress respectively);
    //  - cache.ServiceCtx: checking context for passpIds.Tvm built from PublicTvmKeys.
    // Also raises cache.KeysAge to the newest Ts among the accepted keys.
    // Throws if a configured client is absent, has fewer than minimunKeyCount valid
    // keys, or if an aliased BB client's source has no private keys to share.
    void TDbCacheBuilder::SerializePassportKeys(TDbCache& cache,
                                                const TCryptoKeys& keys,
                                                const TConfig::TPassportIds& passpIds,
                                                size_t minimunKeyCount) {
        TPublicKeysBuilder tvmBuilder;
        std::unordered_map<TClientId, TPublicKeysBuilder> bbBuilders;

        // Passes every valid key of client `id` to `addLamda`. Missing or invalid keys
        // are logged, counted in Errors_ and skipped.
        auto keysLamda = [&cache, &keys, minimunKeyCount, this](TClientId id, auto addLamda) {
            // Fewer valid keys than this only deserves a log line, not a failure.
            constexpr size_t suspiciousKeyCount = 14;

            auto clIt = cache.Clients->find(id);
            Y_ENSURE(clIt != cache.Clients->end(), "Client id is not found in db: " << id);

            size_t count = 0;
            for (ui32 keyId : clIt->second->KeysIds()) {
                try {
                    auto it = keys.find(keyId);
                    Y_ENSURE(it != keys.end(), "missing key");

                    it->second.ThrowIfInvalid(id);
                    cache.KeysAge = std::max(cache.KeysAge, it->second.Ts);
                    addLamda(it->second);
                    ++count;
                } catch (const std::exception& e) {
                    ++Errors_;
                    // Report the offending key id, not the client id twice.
                    TLog::Error() << "DbFetcher: bad key " << keyId << " for client_id=" << id
                                  << ". " << e.what();
                }
            }

            if (count < suspiciousKeyCount) {
                TLog::Error() << "Suspicious key count " << count << " for client_id " << id;
            }
            Y_ENSURE(count >= minimunKeyCount,
                     "Found only " << count << " keys for client_id " << id << ". need at least: " << minimunKeyCount);
        };

        // Public half goes to the shared TVM blob, private half to the BB client's own blob.
        auto bbLamda = [&keysLamda, &tvmBuilder, &bbBuilders](TClientId id,
                                                              tvm_keys::BbEnvType keysType) {
            keysLamda(id, [&tvmBuilder, &bbBuilders, keysType, id](const TCryptoKey& k) {
                tvmBuilder.AddBbKey(k.Id, k.Ts, keysType, k.PublicKey);
                bbBuilders[id].AddBbKey(k.Id, k.Ts, keysType, k.PrivateKey);
            });
        };

        bbLamda(passpIds.BbProd, tvm_keys::Prod);
        bbLamda(passpIds.BbProdYateam, tvm_keys::ProdYateam);
        bbLamda(passpIds.BbTest, tvm_keys::Test);
        bbLamda(passpIds.BbTestYateam, tvm_keys::TestYateam);
        bbLamda(passpIds.BbStress, tvm_keys::Stress);
        keysLamda(passpIds.Tvm, [&tvmBuilder](const TCryptoKey& k) {
            tvmBuilder.AddTvmKey(k.Id, k.Ts, k.PublicKey);
        });

        cache.PublicTvmKeys = std::make_shared<const TString>(tvmBuilder.SerializeV1());
        cache.PrivateBbKeys = std::make_shared<std::unordered_map<TClientId, TString>>();
        for (const auto& b : bbBuilders) {
            cache.PrivateBbKeys->emplace(b.first, b.second.SerializeV1());
        }

        // Lets `aliasId` reuse the private keys of `sourceId`; aliasId == 0 means "not configured".
        auto shareBbKeys = [&cache](TClientId aliasId, TClientId sourceId) {
            if (0 == aliasId) {
                return;
            }
            auto src = cache.PrivateBbKeys->find(sourceId);
            // A source with no accepted keys (possible when minimunKeyCount == 0) has no
            // entry; dereferencing end() here would be undefined behaviour.
            Y_ENSURE(src != cache.PrivateBbKeys->end(),
                     "No private keys for client_id " << sourceId << " to share with client_id " << aliasId);
            cache.PrivateBbKeys->emplace(aliasId, src->second);
        };
        shareBbKeys(passpIds.BbMimino, passpIds.BbProd);
        shareBbKeys(passpIds.BbLoad, passpIds.BbStress);

        cache.ServiceCtx = std::make_shared<NTvmAuth::TServiceContext>(
            NTvmAuth::TServiceContext::CheckingFactory(passpIds.Tvm, *cache.PublicTvmKeys));
    }

    // Decrypts a key attribute value (AES, context "key id:" + id) and base64url-decodes
    // it into raw key bytes.
    // Throws (via Y_ENSURE) if decryption yields nothing or decoding fails. The
    // decrypted text is key material (including private keys), so only its size is
    // put into the error message, which the caller logs.
    TString TDbCacheBuilder::GetBinaryKey(const TString& value,
                                          ui32 id,
                                          const NTvmCommon::TDecryptor& decr) {
        TString decrypted = decr.DecryptAes(value, "key id:", id);
        Y_ENSURE(decrypted, "failed to decrypt");

        TString keyStr = NUtils::Base64url2bin(decrypted);
        Y_ENSURE(keyStr, "bad base64url: size=" << decrypted.size());

        return keyStr;
    }

    // Builds cache.SecretIndex: maps each client's current secret and, when present,
    // its old secret to the client. Empty secrets are never indexed. Clients whose
    // secret attribute was absent or failed to decode would otherwise all land under
    // the empty string, and a lookup with an empty secret would resolve to an
    // arbitrary client.
    void TDbCacheBuilder::BuildSecretIndex(TDbCache& cache) {
        cache.SecretIndex = std::make_shared<TSecretIndex>();
        TSecretIndex& idx = *cache.SecretIndex;
        // ~10% headroom over the client count; conversion made explicit.
        idx.reserve(static_cast<size_t>(1.1 * cache.Clients->size()));

        for (const auto& pair : *cache.Clients) {
            const TClient& cl = *pair.second;

            if (!cl.Secret().empty()) {
                idx.insert({cl.Secret(), pair.second});
            }
            if (!cl.OldSecret().empty()) {
                idx.insert({cl.OldSecret(), pair.second});
            }
        }
    }

    // Reads the decryption key from `keyFile` and performs the initial load:
    // the database first, then the disk snapshot at `diskCache` if the database
    // load throws. If the disk fallback fails as well, its exception propagates.
    TDbFetcher::TDbFetcher(const TConfig& config,
                           NDbPool::TDbPool& db,
                           const TString& keyFile,
                           const TString& diskCache,
                           size_t minKeyCount)
        : DiskCache_(diskCache)
        , Db_(db)
        , Decryptor_(NUtils::ReadFile(keyFile))
        , Conf_(config.DbFetcher)
        , PasspIds_(config.PassportIds)
        , MinKeyCount_(minKeyCount)
    {
        bool loadedFromDb = false;
        try {
            Run();
            loadedFromDb = true;
        } catch (const std::exception& e) {
            TLog::Error() << "Failed to load data from db. Let's read from disk: " << e.what();
        }

        if (!loadedFromDb) {
            LoadFromFile();
        }
    }

    // Exposes the fetcher's two counters: DB query errors/retries and
    // parse/validation errors.
    void TDbFetcher::AddUnistat(NUnistat::TBuilder& builder) const {
        NUnistat::TBuilder& out = builder;
        out.Add(UnistatQueryErrors_);
        out.Add(UnistatParsingErrors_);
    }

    // Looks up client `id` in the current cache snapshot. Returns the client together
    // with the TVM private key from the same snapshot, or an empty result if the
    // client is unknown.
    TDbFetcher::TResult TDbFetcher::GetClient(TClientId id) const {
        const TCachePtr snapshot = Cache_.Get();

        const auto found = snapshot->Clients->find(id);
        if (found != snapshot->Clients->end()) {
            return {found->second, snapshot->TvmKey};
        }

        return {};
    }

    std::shared_ptr<const TString> TDbFetcher::PublicTvmKeys() const {
        return Cache_.Get()->PublicTvmKeys;
    }

    TRwPrivateKeyPtr TDbFetcher::PrivateKey() const {
        return Cache_.Get()->TvmKey;
    }

    TPrivateKeysPtr TDbFetcher::PrivateBbKeys() const {
        return Cache_.Get()->PrivateBbKeys;
    }

    TServiceCtxPtr TDbFetcher::ServiceCtx() const {
        return Cache_.Get()->ServiceCtx;
    }

    TSecretIndexPtr TDbFetcher::SecretIndex() const {
        return Cache_.Get()->SecretIndex;
    }

    time_t TDbFetcher::KeysAge() const {
        return Cache_.Get()->KeysAge;
    }

    void TDbFetcher::Run() {
        TCachePtr cache = Load();
        Cache_.Set(cache);

        TLog::Info() << "DbFetcher: cache successfully updated. Disk cache updated. Loaded client count "
                     << cache->Clients->size();
    }

    // Raw attribute rows for the cache builder. The `type` lists must stay in sync
    // with the attribute codes handled by TDbCacheBuilder::AddClientAttr (1,5,6,18,19,21)
    // and TDbCacheBuilder::AddKeyAttr (1,2,3,4).
    static const TString SELECT_CLIENT_ID =
        "SELECT id, type, value FROM tvm_client_attributes "
        "WHERE type IN(1,5,6,18,19,21) ORDER BY id";
    static const TString SELECT_KEYS =
        "SELECT id, type, value FROM tvm_secret_key_attributes "
        "WHERE type IN(1,2,3,4) ORDER BY id";
    // Fetches client and key attribute rows from the database, builds a new cache
    // from them, and writes the raw rows to DiskCache_ as a protobuf snapshot (used
    // by LoadFromFile()). Query retries/failures are added to UnistatQueryErrors_.
    // Any exception while building or writing is counted in UnistatParsingErrors_,
    // logged and re-thrown.
    TDbFetcher::TCachePtr TDbFetcher::Load() const {
        auto runQuery = [this](const TString& query) -> NDbPool::TTable {
            try {
                auto response = NDbPool::NUtils::DoQueryTries(Db_, query, Conf_.Retries);
                UnistatQueryErrors_ += response.Retries;
                return response.Result->ExctractTable();
            } catch (...) {
                // The query gave up: account for all configured retries.
                UnistatQueryErrors_ += Conf_.Retries;
                throw;
            }
        };

        NDbPool::TTable clientRows = runQuery(SELECT_CLIENT_ID);
        TLog::Debug() << "DbFetcher: got " << clientRows.size()
                      << " rows (attributes) from " << Db_.GetDbInfo().Serialized;

        NDbPool::TTable keyRows = runQuery(SELECT_KEYS);
        TLog::Debug() << "DbFetcher: got " << keyRows.size()
                      << " rows (keys) from " << Db_.GetDbInfo().Serialized;

        try {
            disk_cache::Data snapshot;
            TDbCacheBuilder builder(Decryptor_, UnistatParsingErrors_, MinKeyCount_);

            ui32 id = 0;
            ui32 type = 0;
            TString val;

            // Copies the most recently fetched row into a snapshot entry.
            auto remember = [&id, &type, &val](disk_cache::Row* entry) {
                entry->set_id(id);
                entry->set_type(type);
                entry->set_value(val);
            };

            for (const NDbPool::TRow& row : clientRows) {
                row.Fetch(id, type, val);
                remember(snapshot.add_clients());
                builder.AddClientAttr(id, type, val);
            }
            for (const NDbPool::TRow& row : keyRows) {
                row.Fetch(id, type, val);
                remember(snapshot.add_keys());
                builder.AddKeyAttr(id, type, val);
            }

            builder.PostProcess();
            builder.Serialize(Conf_.PreferedPrivateKeyIdx, PasspIds_);

            TUtils::WriteFileViaTmp(DiskCache_, snapshot.SerializeAsString());
            return std::make_shared<TDbCache>(builder.Finalize());
        } catch (const std::exception& e) {
            ++UnistatParsingErrors_;
            TLog::Error() << "DbFetcher: Failed to fetch client info: " << e.what();
            throw;
        }
    }

    void TDbFetcher::LoadFromFile() {
        try {
            TDbCacheBuilder b(Decryptor_, UnistatParsingErrors_, MinKeyCount_);
            Y_ENSURE(NFs::Exists(DiskCache_),
                     "Disk cache does not exist: " << DiskCache_);

            TFileInput fileInput(DiskCache_.c_str());
            TString s = fileInput.ReadAll();
            disk_cache::Data proto;
            Y_ENSURE(proto.ParseFromString(s), "Failed to parse proto from disk cache");

            for (int idx = 0; idx < proto.clients_size(); ++idx) {
                b.AddClientAttr(proto.clients(idx).id(),
                                proto.clients(idx).type(),
                                TString(proto.clients(idx).value()));
            }
            for (int idx = 0; idx < proto.keys_size(); ++idx) {
                b.AddKeyAttr(proto.keys(idx).id(),
                             proto.keys(idx).type(),
                             TString(proto.keys(idx).value()));
            }

            b.PostProcess();
            b.Serialize(Conf_.PreferedPrivateKeyIdx, PasspIds_);

            Cache_.Set(std::make_shared<TDbCache>(b.Finalize()));
            TLog::Info() << "DbFetcher: TVM loaded data from disk cache instead of db";
        } catch (const std::exception& e) {
            TLog::Error() << "DbFetcher: Failed to fetch client info: " << e.what();
            throw;
        }
    }
}
