#include "keyring.h"

#include "random.h"
#include "sessionutils.h"

#include <passport/infra/libs/cpp/dbpool/handle.h>
#include <passport/infra/libs/cpp/dbpool/util.h>
#include <passport/infra/libs/cpp/dbpool/value.h>
#include <passport/infra/libs/cpp/juggler/status.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <memory>

namespace NPassport::NAuth {
    // Retry budget handed to NDbPool::NUtils::DoQueryTries for every keyring query.
    static constexpr size_t RETRIES = 3;

    // Builds a keyring for `keyspace` and performs the first synchronous load.
    //
    // Throws if the keys table cannot be resolved (LookupKeyTable) or if the
    // initial TrySyncKeyRing() fails. Unlike the periodic syncs, a failed
    // first sync aborts construction.
    TKeyRing::TKeyRing(NDbPool::TDbPool& dbp, const TString& keyspace, const TKeyRingSettings& settings, TDuration timeout)
        : Dbp_(dbp)
        , Keyspace_(keyspace)
        , Settings_(settings)
        , Signkey_(Keys_.end())
        , Timeout_(timeout)
    {
        // GetJugglerStatus() dereferences the last error unconditionally,
        // so start with an empty string rather than nothing.
        LastError_.Set(std::make_shared<TString>());

        // Web keyspaces (the ones containing '_', i.e. NOT cookiel, oauth,
        // signsmth) additionally get a domain suffix derived from the name.
        const bool isWebKeyspace = Keyspace_.find('_') != TString::npos;
        if (isWebKeyspace) {
            DomSuff_ = Keyspace_;
            TSessionUtils::KeyspaceToDomain(DomSuff_);
        }

        // Resolve the table holding this keyspace's keys, then load them once.
        // Later syncs are driven from outside the constructor.
        LookupKeyTable();
        const bool loaded = TrySyncKeyRing();
        if (!loaded) {
            throw yexception() << "Couldn't load keyring at startup: " << keyspace;
        }
    }

    // Nothing to release manually: keys are held through TRandomPtr values
    // (created with std::make_shared in TrySyncKeyRing), so the defaulted
    // destructor frees everything.
    TKeyRing::~TKeyRing() = default;

    // Returns the random registered under `id`, or an empty pointer when no
    // such key is currently loaded.
    //
    // Takes Mutex_ in shared mode. The syncing thread is the only writer and
    // locks it exclusively only while publishing a new key map.
    TKeyRing::TRandomPtr TKeyRing::GetRandomById(const TStringBuf id) const {
        TRandomPtr random;
        {
            std::shared_lock readLock(Mutex_);
            if (const auto it = Keys_.find(id); it != Keys_.end()) {
                random = it->second;
            }
        }
        return random;
    }

    // Returns the random currently selected for signing (chosen in
    // TrySyncKeyRing), or an empty pointer when the ring is empty.
    TKeyRing::TRandomPtr TKeyRing::GetRandomForSign() const {
        std::shared_lock readLock(Mutex_);
        if (Signkey_ == Keys_.end()) {
            return {};
        }
        return Signkey_->second;
    }

    // Resolves the key referenced by the public key id `id` and hands it to
    // the gamma keeper's GetForCheck().
    //
    // `id` must be the decimal form of a TCombinedKeyId, which carries both
    // the random id and the gamma id. Returns an error TKeyWithGamma when
    // `id` is not a decimal number or when no random with the decoded id is
    // loaded in this ring.
    TKeyWithGamma TKeyRing::GetKeyById(const TStringBuf id, TRandom::EView randomView) const {
        ui32 number = 0;
        if (!TryIntFromString<10>(id, number)) {
            return TKeyWithGamma::FromError("id is not a number: '", id, "'");
        }

        const TCombinedKeyId combinedId(number);

        // Keys_ is indexed by the textual random id, so convert it back.
        const TRandomPtr key = GetRandomById(ToString(combinedId.RandomId()));
        if (!key) {
            return TKeyWithGamma::FromError("missing keyid=", id);
        }

        return Settings_.GammaKeeper->GetForCheck(
            Keyspace_,
            *key,
            combinedId.GammaId(),
            randomView);
    }

    // Returns the current signing random combined with the gamma chosen by
    // the gamma keeper's GetForSign(). Returns an error TKeyWithGamma when
    // the ring holds no signing key.
    TKeyWithGamma TKeyRing::GetKeyForSign(TRandom::EView randomView) const {
        if (const TRandomPtr signRandom = GetRandomForSign()) {
            return Settings_.GammaKeeper->GetForSign(Keyspace_, *signRandom, randomView);
        }
        return TKeyWithGamma::FromError("missing key for signing");
    }

    // Start time of the last entry of Keys_ in map order, or
    // TInstant::Seconds(0) when no keys are loaded.
    // NOTE(review): "most recent" assumes TKeyMap iterates in ascending
    // numeric id order (newest key last); confirm the comparator in keyring.h.
    TInstant TKeyRing::GetMostRecentStartTime() const {
        std::shared_lock readLock(Mutex_);
        if (Keys_.empty()) {
            return TInstant::Seconds(0);
        }
        return Keys_.rbegin()->second->StartTime();
    }

    // Reports a Warning when the newest loaded key started more than Timeout_
    // ago (no fresh keys have shown up). The message includes the last sync
    // error, which is empty after a successful sync. Otherwise reports OK.
    NJuggler::TStatus TKeyRing::GetJugglerStatus() const {
        const TDuration sinceNewest = TInstant::Now() - GetMostRecentStartTime();
        if (sinceNewest <= Timeout_) {
            return {};
        }

        return NJuggler::TStatus(NJuggler::ECode::Warning,
                                 "Keys for keyring '", Keyspace_, "' weren't updated for ", sinceNewest.Minutes(),
                                 "min. last error: ", *LastError_.Get());
    }

    // Dumps the ring in lrandoms.txt format, walking Keys_ from its last
    // entry to its first, one key per line:
    //     <id>;<body>;<start time in unix seconds>\n
    void TKeyRing::PrintAsTxt(IOutputStream& out) const {
        std::shared_lock readLock(Mutex_);

        for (auto rit = Keys_.rbegin(); rit != Keys_.rend(); ++rit) {
            const TRandom& random = *rit->second;
            out << random.Id() << ";"
                << random.GetBody() << ";"
                << random.StartTime().Seconds() << "\n";
        }
    }

    //
    // Query prefix used once per keyspace, when the keyring is constructed,
    // to load its configuration: the key group id and the name of the table
    // holding the keys. LookupKeyTable() appends the keyspace name and the
    // closing quote, and throws if no row comes back.
    //
    // NOTE(review): the keyspace name is spliced into the SQL text without
    // escaping. That is only acceptable if keyspace names come from trusted
    // configuration; confirm.
    //
    static const TString SQL_QUERY("select groupid, tablename from keyspaces where inuse<>0 and domainsuff='");

    void TKeyRing::LookupKeyTable() {
        // Load keyspace config -- table name and addprefix flag
        NDbPool::NUtils::TResponseWithRetries response =
            NDbPool::NUtils::DoQueryTries(Dbp_, SQL_QUERY + Keyspace_ + "'", RETRIES);

        if (!response.Result->Fetch(Groupid_, Tablename_)) {
            throw yexception() << "Failed to load keyspace name for domain suffix " << Keyspace_;
        }
        TLog::Debug() << "KeyRing: " << Tablename_
                      << ": table name loaded. Took " << response.LastResponseTime;

        // Format other queries depending on the keyspace
        SqlLoadAllKeys_.assign("select id, keybody, UNIX_TIMESTAMP(start) from ");
        SqlLoadAllKeys_.append(Tablename_).append(" order by id asc");
        SqlLoadMinMax_.assign("select min(id), max(id), count(id) from ");
        SqlLoadMinMax_.append(Tablename_);
    }

    std::unique_ptr<NDbPool::TResult> TKeyRing::FetchNewKeys() {
        // first, check for any changes
        NDbPool::NUtils::TResponseWithRetries response =
            NDbPool::NUtils::DoQueryTries(Dbp_, SqlLoadMinMax_, RETRIES);
        TLog::Debug() << "KeyRing: " << Tablename_
                      << ": min-max key id loaded. Took " << response.LastResponseTime;

        const NDbPool::TTable table = response.Result->ExctractTable();
        if (table.empty()) {
            throw yexception() << "Failed to execute: " << SqlLoadMinMax_;
        }

        const NDbPool::TRow& row = table[0];
        int minid = 0;
        int maxid = 0;
        int newcount = 0;
        if (!row[0].IsNull()) {
            minid = row[0].AsInt();
            maxid = row[1].AsInt();
            newcount = row[2].AsInt();
        }
        if (minid == OldestKeyId_ && maxid == NewestKeyId_ && newcount == Lastcount_) {
            return {};
        }

        response = NDbPool::NUtils::DoQueryTries(Dbp_, SqlLoadAllKeys_, RETRIES);
        TLog::Debug() << "KeyRing: " << Tablename_
                      << ": new keys loaded. Took " << response.LastResponseTime;

        return std::move(response.Result);
    }

    // Human-readable description of the DB and table this keyring syncs
    // from. It is appended to sync error messages.
    // NOTE(review): the trailing ")" has no matching "(" in this function.
    // Check whether GetDbInfo().Serialized opens one; otherwise the output
    // is unbalanced.
    TString TKeyRing::GenDebugInfo() const {
        const auto& dbInfo = Dbp_.GetDbInfo();
        return NUtils::CreateStr(dbInfo.Serialized, ", table=", Tablename_, ")");
    }

    //
    // All keyring syncing magic happens here.
    //
    // Called synchronously by the keyring constructor (before the keyring is
    // registered with the syncing thread), then periodically from the
    // syncing thread.
    //
    // Returns true when the in-memory ring is up to date: either nothing
    // changed in the DB, or the new key set was published and LastError_ was
    // cleared. Returns false when a std::exception was thrown. In that case
    // the message is stored in LastError_ (shown by GetJugglerStatus()), and
    // Keys_ is left untouched if the failure happened before publishing,
    // because Keys_ is only assigned at the very end.
    //
    bool TKeyRing::TrySyncKeyRing() {
        try {
            // NOTE: Keys_ is read below without taking Mutex_. This function is
            // the only writer: the constructor calls it before the keyring is
            // handed to the syncing thread, and afterwards only that thread
            // calls it. Readers only read, so reading here cannot race with a
            // write. The exclusive lock is taken only to publish the new map.

            std::unique_ptr<NDbPool::TResult> dbResult = FetchNewKeys();
            if (!dbResult) {
                return true; // everything is up to date
            }

            // Something has changed in the database, so re-sync. Walk the rows
            // coming from the DB (ordered by id) in lockstep with the currently
            // loaded keys, and copy into a fresh map only the keys that pass
            // the checks below. ringsize is the number of DB rows, including
            // rejected ones, so it matches count(id) in FetchNewKeys().
            int ringsize = dbResult->size();
            TKeyMap new_keys;

            // Merge rules while both sequences still have elements:
            //  - old id <  new id: the old key was deleted from the DB, so drop it;
            //  - old id == new id: keep the old key only if its body is unchanged;
            //  - old id >  new id: a DB key "re-appeared" below an id we already
            //    hold, so never accept it.
            // NOTE(review): the lockstep walk assumes TKeyMap iterates in numeric
            // id order, matching the SQL "order by id asc". Confirm in keyring.h.
            TKeyMap::iterator current = Keys_.begin();

            TString new_id;
            TString keybody;
            time_t start;

            bool gotKey = dbResult->Fetch(new_id, keybody, start);

            while (current != Keys_.end() && gotKey) {
                bool advance_old = false;
                bool advance_new = false;

                int curold = current->second->IntId();
                int curnew = IntFromString<int, 10>(new_id);

                if (curold < curnew) {
                    // The current old key is older than the current DB key and
                    // was not matched, so it is gone from the DB. Skip it.
                    advance_old = true;

                    TLog::Info("Drop key with id=%d", curold);

                } else if (curold == curnew) {
                    // Check that nobody has tampered with this key. If the body
                    // changed, skip the key (it is not carried into the new map).
                    // Otherwise keep the already-loaded key object.
                    if (current->second->GetBody() != keybody) {
                        TLog::Warning("Key body changed in table %s for key with id %s",
                                      Tablename_.c_str(),
                                      new_id.c_str());
                    } else {
                        new_keys.insert(TKeyMap::value_type(new_id, current->second));
                    }

                    advance_old = true;
                    advance_new = true;

                } else {
                    // Never use a key that has "resurrected", i.e. one whose id
                    // falls below an old key we still hold: once a key is
                    // deleted, its id may not be reused.
                    TLog::Warning("Key reappeared in table %s with id %s",
                                  Tablename_.c_str(),
                                  new_id.c_str());

                    advance_new = true;
                }

                // advance either one or both keyring iterators
                if (advance_old) {
                    ++current;
                }
                if (advance_new) {
                    gotKey = dbResult->Fetch(new_id, keybody, start);
                }
            }

            int keysloaded = 0;
            TString first_new_key;
            TString last_new_key;

            if (gotKey) {
                first_new_key = new_id;
            }

            // Any DB rows left now have ids greater than every id in the old
            // ring, so each one is accepted as a brand-new key.
            // NOTE(review): this re-admits a key rejected on a previous sync
            // (body changed) when it lies after the last key kept in Keys_.
            // Example: Keys_={5,6,7}, DB={6,7'} drops 5 and rejects 7', so {6}
            // is published. On the next sync, max(id)=7 != NewestKeyId_=6
            // triggers a reload, and 7' is accepted here as new. Consider
            // tracking the highest id ever processed instead of relying on
            // the contents of Keys_.
            while (gotKey) {
                // Don't log each key since it's too verbose
                last_new_key = new_id;

                TRandomPtr sap = std::make_shared<TRandom>(new_id, keybody, start);
                new_keys.insert(TKeyMap::value_type(last_new_key, std::move(sap)));
                gotKey = dbResult->Fetch(new_id, keybody, start);

                ++keysloaded;
            }

            // Log how many new keys were added (also logged when zero).
            TLog::Debug("Loaded %d new key(s) for keyspace '%s'. Key ids: %s-%s",
                        keysloaded,
                        Keyspace_.c_str(),
                        first_new_key.c_str(),
                        last_new_key.c_str());

            // Publish the new map under the exclusive lock. Move-assignment
            // replaces the previous map, and its entries are released here,
            // during the assignment. Contrary to an older comment, they do NOT
            // end up in new_keys. TRandom objects are shared (TRandomPtr), so
            // anything already handed out by the getters stays usable.
            // NOTE(review): OldestKeyId_/NewestKeyId_ are taken from the
            // filtered ring, not from the DB's min/max. If the DB's min or max
            // row was rejected, FetchNewKeys() sees a mismatch on every cycle
            // and forces a full reload (and repeats the warnings) indefinitely.
            std::unique_lock swapLock(Mutex_);
            Keys_ = std::move(new_keys);
            Lastcount_ = ringsize;
            if (Keys_.empty()) {
                OldestKeyId_ = NewestKeyId_ = 0;
                Signkey_ = Keys_.end();
            } else {
                OldestKeyId_ = Keys_.begin()->second->IntId();
                NewestKeyId_ = Keys_.rbegin()->second->IntId();
                // The signing key is the Signkeydepth-th key counting back
                // from the last one (1 = last), clamped to the first key.
                // Values <= 1 select the last key.
                Signkey_ = --(Keys_.end());
                int ctr = Settings_.Signkeydepth;
                while (Signkey_ != Keys_.begin() && --ctr > 0) {
                    --Signkey_;
                }
            }

            LastError_.Set(std::make_shared<TString>());
            return true;
        } catch (std::exception& e) {
            // Keep the keyring running on the previous keys; remember why the
            // sync failed so monitoring can report it.
            TString error = NUtils::CreateStr("Failed to load random key from db: ",
                                              e.what(), " : ", GenDebugInfo().c_str());
            TLog::Warning() << error;
            LastError_.Set(std::make_shared<TString>(std::move(error)));
        }

        return false;
    }

}
