#include "db_pool.h"

#include "choose_policy.h"
#include "handle.h"
#include "host.h"
#include "pinger.h"
#include "producer.h"

#include <passport/infra/libs/cpp/dbpool/db_pool_ctx.h>
#include <passport/infra/libs/cpp/dbpool/exception.h>

#include <passport/infra/libs/cpp/json/writer.h>
#include <passport/infra/libs/cpp/unistat/time_stat.h>

#include <util/stream/format.h>
#include <util/system/thread.h>

#include <algorithm>
#include <memory>
#include <numeric>
#include <thread>

namespace NPassport::NDbPool {
    // Builds a list of non-owning references to the hosts owned by `hosts`.
    // The references stay valid only as long as the owning unique_ptrs are alive.
    static THostRefs ToRefs(const std::vector<std::unique_ptr<THost>>& hosts) {
        THostRefs refs;
        for (size_t i = 0; i < hosts.size(); ++i) {
            THost& host = *hosts[i];
            refs.push_back(std::ref(host));
        }
        return refs;
    }

    static size_t CalcWeightSum(const std::vector<TDbHost>& hosts) {
        size_t res = 0;

        for (const TDbHost& d : hosts) {
            res += d.Weight;
        }

        return res;
    }

    // Builds the pool from `settings`: computes the total host weight, creates one THost
    // per configured host via InitHosts() (initial handle counts split by weight), and
    // finally starts the background service thread, which runs ServiceThread() until the
    // destructor sets Exit_.
    // NOTE(review): several initializers below read other members (Settings_, DbInfo_,
    // Log_); this relies on the member declaration order in the header — confirm it matches.
    TDbPoolImpl::TDbPoolImpl(const TDbPoolSettings& settings,
                             const TDbPoolLog& logger)
        : Settings_(settings)
        , DbInfo_(*Settings_.Dsn)
        , WeightSum_(CalcWeightSum(Settings_.Hosts))
        , Exit_(false)
        , Log_(logger)
        , State_(Log_, TDuration(), DbInfo_.Serialized + ";id=" + DbInfo_.DisplayName)
        , Counters_(std::make_shared<TCounters>(Settings_.Dsn->DisplayName, Settings_.Size))
        , TimeStats_(THost::MakeStats(*Settings_.Dsn, Settings_.Dsn->DisplayName))
        , ChooseFactory_(TRandomFactory::Create(Settings_.Hosts))
    {
        Log_.Debug() << DbInfo_ << " Start dbpool";

        // Hosts_ must be fully populated before the service thread starts:
        // ServiceThread() reads it immediately (ToRefs(Hosts_)).
        InitHosts();

        // The thread is named "db_service_<driver>" so it can be told apart from others.
        Thread_ = std::thread([this]() {
            TThread::SetCurrentThreadName(("db_service_" + Settings_.Dsn->Driver).c_str());
            ServiceThread();
        });
    }

    // Stops the service thread: raises the exit flag, wakes the thread so it notices
    // the flag without waiting for its sleep to expire, then joins it.
    TDbPoolImpl::~TDbPoolImpl() {
        Log_.Debug() << DbInfo_ << " Stop dbpool";

        Exit_ = true;
        WakeUpServiceThread_.Signal();

        Thread_.join();
    }

    // Returns the description of the database this pool was built for.
    const TDbInfo& TDbPoolImpl::GetDbInfo() const {
        const TDbInfo& info = this->DbInfo_;
        return info;
    }

    // Reports whether the pool is usable, as decided by TPoolState::IsOk.
    // Side effect: calls CalculateState(), which may move State_ Up/Down.
    //
    // When `statusBuf` is non-null, the following is appended to it after whatever
    // TPoolState::IsOk writes:
    //   ": <alive>/<total> [<aliveWeight>/<weightSum>]"
    // and, if some hosts are down, ";down hosts:[h1,h2,...]".
    bool TDbPoolImpl::IsOk(TString* statusBuf) const {
        if (statusBuf) {
            statusBuf->reserve(512);
        }

        THostIdxs badHosts;
        const size_t aliveWeight = CalculateState(badHosts, ECollectHosts::Bad);

        const bool res = TPoolState::IsOk(State_, DbInfo_, statusBuf);

        if (!statusBuf) {
            return res;
        }

        NUtils::Append(*statusBuf,
                       ": ", (Hosts_.size() - badHosts.size()), "/", Hosts_.size(),
                       " [", aliveWeight, "/", WeightSum_, "]");

        if (badHosts.empty()) {
            return res;
        }

        NUtils::Append(*statusBuf, ";down hosts:[");

        // Decide on the separator by position, not by peeking at the buffer: an empty host
        // name (or one ending in '[') would otherwise suppress the comma.
        bool first = true;
        for (size_t idx : badHosts) {
            if (!first) {
                statusBuf->push_back(',');
            }
            first = false;
            statusBuf->append(Hosts_[idx]->GetDbHost().Host);
        }

        statusBuf->push_back(']');

        return res;
    }

    // Aggregates the handle counters of every host into a single THandlesInfo.
    THandlesInfo TDbPoolImpl::GetHandlesInfo() const {
        THandlesInfo total;
        for (size_t i = 0; i < Hosts_.size(); ++i) {
            total += Hosts_[i]->GetHandlesInfo();
        }
        return total;
    }

    void TDbPoolImpl::TryPing() {
        size_t aliveWeight = 0;

        for (std::unique_ptr<THost>& h : Hosts_) {
            try {
                h->TryPing(Settings_.GetTimeout);

                aliveWeight += h->GetDbHost().Weight;
                Log_.Debug() << h->GetDbInfo()
                             << " tryPing succeed. id=" << h->GetDbInfo().DisplayName;
            } catch (const std::exception& e) {
                Log_.Warning() << h->GetDbInfo()
                               << "  tryPing failed: " << e.what()
                               << ". id=" << h->GetDbInfo().DisplayName;
            }
        }

        // to update host sizes
        WakeUpServiceThread_.Signal();

        const double succeedRate = double(aliveWeight) / WeightSum_;
        if (succeedRate < Settings_.BalancingOpenRate) {
            throw TException(DbInfo_)
                << "DbPool failed to ping: too few alive weight: "
                << aliveWeight << "/" << WeightSum_
                << " (rate=" << Prec(succeedRate, 3)
                << "). Min required: " << Settings_.BalancingOpenRate;
        }
    }

    // Writes a JSON object named after the pool's DisplayName into `out`, containing:
    // a nested "hosts" object with per-host stats, the pool status ("OK"/"ERROR") with
    // its error text, and the aggregated handle counters.
    void TDbPoolImpl::GetExtendedStats(NJson::TObject& out) const {
        NJson::TObject db(out, DbInfo_.DisplayName);

        // The nested "hosts" writer lives only inside this lambda, so it is destroyed
        // before any pool-level field is added to `db` (same scoping as an explicit block).
        const THandlesInfo total = [&]() {
            THandlesInfo acc;
            NJson::TObject hosts(db, "hosts");
            for (size_t i = 0; i < Hosts_.size(); ++i) {
                acc += Hosts_[i]->GetExtendedStats(hosts);
            }
            return acc;
        }();

        TString err;
        const bool ok = IsOk(&err);
        db.Add("status", ok ? "OK" : "ERROR");
        db.Add("error", ok ? TString() : err);

        db.Add("config_size", total.ConfigCount);
        db.Add("workers_idle", total.Idle);
        db.Add("workers_in_use", total.InUse);
        db.Add("workers_dying", total.Dying);
    }

    // Publishes pool-level signals: the shared counters first, then the pool time stats.
    void TDbPoolImpl::AddUnistat(NUnistat::TBuilder& builder) const {
        auto& counters = *Counters_;
        counters.AddUnistat(builder);

        auto& timeStats = *TimeStats_;
        timeStats.AddUnistat(builder);
    }

    // Publishes per-host signals by delegating to every host in Hosts_ order.
    void TDbPoolImpl::AddUnistatExtended(NUnistat::TBuilder& builder) const {
        for (size_t i = 0; i < Hosts_.size(); ++i) {
            THost& host = *Hosts_[i];
            host.AddUnistatExtended(builder);
        }
    }

    // Returns a handle from any healthy host, retrying until Settings_.GetTimeout elapses.
    // Hosts are visited in the order produced by a fresh chooser on every pass.
    // At least one pass is always made, so an immediately available handle is returned
    // even when GetTimeout is zero. On timeout, increments Counters_->CantGetConnection
    // and returns nullptr.
    std::unique_ptr<THandle> TDbPoolImpl::Get() {
        const TInstant deadline = TInstant::Now() + Settings_.GetTimeout;

        while (true) {
            TChooserPtr chooser = ChooseFactory_->CreateChooser();

            // Try get handle from any available host
            IChooser::TOptionalIdx idx;
            while ((idx = chooser->TryGetIdx())) {
                Y_VERIFY(*idx < Hosts_.size());
                THost& host = *Hosts_[*idx];

                if (!host.IsOk()) {
                    continue;
                }

                std::unique_ptr<THandle> handle = host.TryGetNonblocking();
                if (handle) {
                    return handle;
                }
            }

            // Check the deadline only after a full pass, and do not sleep once it is reached
            if (TInstant::Now() >= deadline) {
                break;
            }

            // We use sleep instead of signal through condvar + mutex because:
            // 1. it is rare case
            // 2. it would add more complexity - to send signal from Producer
            // 3. it would be too expensive to send signal on put() - it is hot path
            Sleep(TDuration::MicroSeconds(500));
        }

        // Count every timed-out get(); the pinger uses the same failure accounting
        Log_.Debug() << DbInfo_ << " get() failed";
        ++Counters_->CantGetConnection;

        return nullptr;
    }

    void TDbPoolImpl::Put(std::unique_ptr<THandle> handle) {
        size_t idx = handle->GetHostIdx();
        Y_VERIFY(idx < Hosts_.size(), "Internal error: invalid idx: %ld", idx);

        switch (Hosts_[idx]->Put(std::move(handle))) {
            case EPutHandleStatus::Ok:
                break;
            case EPutHandleStatus::Bad:
                WakeUpServiceThread_.Signal();
                break;
        }
    }

    // Creates one THost per configured DB host, in Settings_.Hosts order, so a host's
    // position in Hosts_ equals the idx handed to it (Put() relies on GetHostIdx()).
    // Initial handle counts are split over all configured hosts by weight (CalculateSizes).
    // Throws (Y_ENSURE) if no hosts are configured.
    void TDbPoolImpl::InitHosts() {
        Y_ENSURE(!Settings_.Hosts.empty(), "dbpool cannot be empty");

        const size_t hostCount = Settings_.Hosts.size();

        // At startup every configured host takes part in sizing
        THostIdxs allHosts;
        allHosts.resize(hostCount);
        std::iota(allHosts.begin(), allHosts.end(), size_t{0});

        THostSizes hostSizes;
        CalculateSizes(Settings_, allHosts, WeightSum_, hostSizes);

        Hosts_.reserve(hostCount);
        for (size_t idx = 0; idx < hostCount; ++idx) {
            const TDbHost& d = Settings_.Hosts[idx];
            const size_t initialSize = hostSizes[idx];

            Hosts_.push_back(std::make_unique<THost>(
                THostSettings{
                    .Dsn = Settings_.Dsn,
                    .DbHost = d,
                    .ConnectionTimeout = Settings_.ConnectionTimeout,
                    .QueryTimeout = Settings_.QueryTimeout,
                    .FailThreshold = Settings_.FailThreshold,
                    .TimeToInit = Settings_.TimeToInit,
                    .InitialSize = initialSize,
                    .FetchStatusOnPing = Settings_.FetchStatusOnPing,
                    .DefaultQueryOpts = Settings_.DefaultQueryOpts,
                },
                THostUnistatCtx{
                    .PoolTimeStats = TimeStats_,
                    .PoolCounters = Counters_,
                },
                Log_,
                idx));
        }
    }

    void TDbPoolImpl::CalculateSizes(const TDbPoolSettings& settings,
                                     const THostIdxs& indexes,
                                     size_t weightSum,
                                     TDbPoolImpl::THostSizes& out) {
        out.clear();
        out.reserve(indexes.size());

        // Allowes to provide at least poolsize overall hosts
        size_t sum = 0;

        // Allowes to make uniform distribution of handles among hosts
        double accumulatedSize = 0;

        for (size_t idx : indexes) {
            Y_VERIFY(weightSum > 0);

            accumulatedSize += settings.Hosts[idx].Weight * double(settings.Size) / weightSum;
            size_t newValue = size_t(accumulatedSize);
            accumulatedSize -= newValue;

            newValue = std::max(newValue, size_t(1));
            out.push_back(newValue);

            sum += newValue;
        }

        // To avoid missing some handles count from config
        if (!out.empty() && sum < settings.Size) {
            out.back() += settings.Size - sum;
        }
    }

    // Walks all hosts, sums the weight of those whose IsOk() is true, and appends to
    // `hosts` the indexes of the bad or the alive hosts, depending on `collect`.
    // Also drives State_ with hysteresis on the alive-weight rate:
    //   rate <  BalancingCloseRate -> try to switch Down;
    //   rate >= BalancingOpenRate  -> try to switch Up;
    //   otherwise the state is left unchanged.
    // The Up threshold is inclusive so it matches TryPing(), which accepts
    // rate == BalancingOpenRate. With a strict comparison, BalancingOpenRate == 1.0
    // could never bring the pool back Up.
    // Returns the summed weight of the alive hosts.
    size_t TDbPoolImpl::CalculateState(THostIdxs& hosts, ECollectHosts collect) const {
        size_t aliveWeight = 0;

        hosts.reserve(Hosts_.size());
        for (size_t idx = 0; idx < Hosts_.size(); ++idx) {
            if (!Hosts_[idx]->IsOk()) {
                if (collect == ECollectHosts::Bad) {
                    hosts.push_back(idx);
                }
                continue;
            }

            aliveWeight += Hosts_[idx]->GetDbHost().Weight;
            if (collect == ECollectHosts::Alive) {
                hosts.push_back(idx);
            }
        }

        const double aliveRate = double(aliveWeight) / WeightSum_;

        if (aliveRate < Settings_.BalancingCloseRate) {
            State_.TryMake(TPoolState::Down);
        } else if (aliveRate >= Settings_.BalancingOpenRate) {
            State_.TryMake(TPoolState::Up);
        }

        return aliveWeight;
    }

    void TDbPoolImpl::UpdateHostSizes() {
        THostIdxs aliveHosts;
        size_t aliveWeight = CalculateState(aliveHosts, ECollectHosts::Alive);

        THostSizes hostSizes;
        CalculateSizes(Settings_, aliveHosts, aliveWeight, hostSizes);

        for (size_t i = 0; i < aliveHosts.size(); ++i) {
            const size_t idx = aliveHosts[i];
            const size_t newValue = hostSizes[i];

            Hosts_[idx]->SetHandlesCount(newValue);
        }
    }

    // Body of the background service thread. Until Exit_ is set, each iteration:
    //   1. recomputes pool state and per-host sizes (UpdateHostSizes);
    //   2. lets the producer create handles (non-blocking);
    //   3. lets the pinger ping hosts (non-blocking);
    // then sleeps for the shortest delay requested by producer/pinger, or until
    // WakeUpServiceThread_ is signalled (by Put(), TryPing() or the destructor).
    // Errors in any step are logged and the loop keeps running. They are not expected,
    // so they are logged at Warning level to stay visible in production logs.
    void TDbPoolImpl::ServiceThread() {
        Log_.Debug() << DbInfo_ << " DbPool->service: thread entry. id=" << DbInfo_.DisplayName;

        // Start populating the pool
        TProducer producer(Log_, Settings_, ToRefs(Hosts_));

        // Periodic pings keep idle handles from going stale during long inactivity
        TPinger pinger(Log_, Settings_.PingPeriod, Settings_.QueryTimeout, ToRefs(Hosts_));

        while (!Exit_.load(std::memory_order_relaxed)) {
            TDuration toSleep = TDuration::Max();

            try {
                UpdateHostSizes();
            } catch (const std::exception& e) {
                // Impossible case
                Log_.Warning() << DbInfo_ << " DbPool->service: State fatal error: <"
                               << e.what() << ">. id=" << DbInfo_.DisplayName;
            }

            try {
                TDuration s = producer.ProduceNonblocking(WakeUpServiceThread_);
                toSleep = std::min(toSleep, s);
            } catch (const std::exception& e) {
                // Impossible case
                Log_.Warning() << DbInfo_ << " DbPool->service: Producer fatal error: <" << e.what()
                               << ">. id=" << DbInfo_.DisplayName;
            }

            try {
                TDuration s = pinger.PingNonblocking(WakeUpServiceThread_);
                toSleep = std::min(toSleep, s);
            } catch (const std::exception& e) {
                // Impossible case
                Log_.Warning() << DbInfo_ << " DbPool->service: Pinger fatal error: <" << e.what()
                               << ">. id=" << DbInfo_.DisplayName;
            }

            WakeUpServiceThread_.WaitT(toSleep);
            WakeUpServiceThread_.Reset();
        }

        Log_.Debug() << DbInfo_ << " DbPool->service: exiting. id=" << DbInfo_.DisplayName;
    }
}
