#include "pinger.h"

#include "handle.h"

#include <passport/infra/libs/cpp/dbpool/exception.h>
#include <passport/infra/libs/cpp/dbpool/result.h>

#include <util/stream/format.h>

namespace NPassport::NDbPool {
    TPingGuard::TPingGuard(THost& host, std::unique_ptr<THandle> h)
        : Host(host)
        , Handle(std::move(h))
    {
    }

    // Hands the owned handle back to its host via THost::Put.
    // A null handle means there is nothing to return.
    TPingGuard::~TPingGuard() {
        if (!Handle) {
            return;
        }
        Host.Put(std::move(Handle));
    }

    TPingableHost::TPingableHost(TDbPoolLog log,
                                 TDuration pingPeriod,
                                 TDuration queryTimeout,
                                 std::reference_wrapper<THost> host)
        : Log_(log)
        , PingPeriod_(pingPeriod)
        , QueryTimeout_(queryTimeout)
        , Host_(host)
    {
    }

    // Connection metadata of the wrapped host; used to prefix log lines and errors.
    const TDbInfo& TPingableHost::DbInfo() const {
        THost& host = Host_.get();
        return host.GetDbInfo();
    }

    // Advances this host's ping state machine by one step without blocking and returns
    // how long the caller may sleep before the next step:
    //  - no ping in flight, period since the last finished ping elapsed:
    //      start a ping and return QueryTimeout_;
    //  - no ping in flight, too early: return the time left until the next ping is due;
    //  - ping in flight and finished: record the completion time, return PingPeriod_;
    //  - ping in flight and still running: return the time left until its deadline.
    // Any std::exception is logged. The in-flight guard is then released, returning its
    // handle to the host. The period restarts from now, and the host is notified via
    // OnPingException(). In that case the returned value is whatever was chosen before
    // the throw (PingPeriod_ unless already changed).
    // NOTE(review): once the deadline has passed and the ping is still running, the
    // subtraction presumably saturates to zero, so a caller sleeping for the returned
    // duration would poll in a tight loop until the handle reports completion — confirm
    // the handle enforces its own timeout.
    TDuration TPingableHost::PingNonblocking(const TManualEvent& wakeUpServiceThread) {
        TDuration toSleep = PingPeriod_;

        try {
            if (!Guard_) {
                const TInstant nextPingAt = LastPing_ + PingPeriod_;
                if (nextPingAt < TInstant::Now()) {
                    StartPing(wakeUpServiceThread);
                    toSleep = QueryTimeout_;
                } else {
                    toSleep = nextPingAt - TInstant::Now();
                }
            } else if (IsPingFinished()) {
                // The ping started during an earlier step has completed.
                LastPing_ = TInstant::Now();
            } else {
                toSleep = GetGuardDeadline() - TInstant::Now();
            }
        } catch (const std::exception& e) {
            Log_.Debug() << DbInfo()
                         << " Pinger for host: error <" << e.what()
                         << "> id=" << DbInfo().DisplayName;
            LastPing_ = TInstant::Now();
            Guard_.reset();
            Host_.get().OnPingException();
        }

        return toSleep;
    }

    // Waits for the in-flight ping, if any, via THandle::WaitToFinish().
    // The guard is not released here; its handle stays borrowed until the guard is reset
    // or destroyed.
    void TPingableHost::Wait() {
        if (!Guard_) {
            return;
        }
        Guard_->Handle->WaitToFinish();
    }

    void TPingableHost::StartPing(const TManualEvent& wakeUpServiceThread) {
        Y_VERIFY(!Guard_);

        Guard_.emplace(Host_.get(), Host_.get().TryGetNonblocking());

        if (!Guard_->Handle) {
            throw TCantGetConnection(DbInfo()) << "Can't get connection";
        }

        Guard_->Handle->StartNonblockingPing([ev = wakeUpServiceThread](const auto&) mutable {
            ev.Signal();
        });
    }

    // Polls the in-flight ping without blocking.
    // Returns false while no result is available yet. Once a result arrives, the host is
    // forced down if the ping was not OK (and not forced down otherwise), the guard is
    // released (returning the handle to the host), and true is returned.
    // Precondition: a ping is in flight (Guard_ is engaged).
    bool TPingableHost::IsPingFinished() {
        Y_VERIFY(Guard_);

        const std::unique_ptr<TResult> pingResult = Guard_->Handle->CheckPingFinished();
        if (pingResult) {
            const bool pingFailed = !pingResult->ToPingResult().IsOk;
            Host_.get().ForceDown(pingFailed);
            Guard_.reset();
        }

        return static_cast<bool>(pingResult);
    }

    // Instant by which the in-flight ping is expected to finish: its start time plus
    // QueryTimeout_.
    // Precondition: a ping is in flight (Guard_ is engaged).
    TInstant TPingableHost::GetGuardDeadline() const noexcept {
        Y_VERIFY(Guard_);
        const TInstant startedAt = Guard_->Start;
        return startedAt + QueryTimeout_;
    }

    // Defaulted out of line rather than in the header — presumably so that member types
    // which are only forward-declared in pinger.h (e.g. THandle inside the guard) are
    // complete at the point of destruction; confirm against the header.
    TPingableHost::~TPingableHost() = default;

    // Creates one TPingableHost per host in `hosts`. All of them share this pinger's log
    // and the same ping period and query timeout.
    TPinger::TPinger(TDbPoolLog log,
                     TDuration pingPeriod,
                     TDuration queryTimeout,
                     const THostRefs& hosts)
        : Log_(log)
    {
        for (const std::reference_wrapper<THost>& hostRef : hosts) {
            auto pingable = std::make_unique<TPingableHost>(Log_, pingPeriod, queryTimeout, hostRef);
            Hosts_.push_back(std::move(pingable));
        }
    }

    // Waits for every in-flight ping before the hosts are destroyed.
    // WaitAll() only catches std::exception per host, so anything else could still escape.
    // A destructor must not throw, so all such exceptions are deliberately swallowed here.
    TPinger::~TPinger() {
        try {
            WaitAll();
        } catch (...) {
            // Best effort: there is nothing sensible to do with an error during destruction.
        }
    }

    // Runs one non-blocking ping step for every host. Returns the shortest sleep any host
    // requested, or TDuration::Max() when there are no hosts.
    // TPingableHost::PingNonblocking catches std::exception itself. The handler below is
    // only a safety net, e.g. for an exception raised inside that function's own catch
    // block; it logs and moves on to the next host.
    TDuration TPinger::PingNonblocking(const TManualEvent& wakeUpServiceThread) {
        TDuration shortest = TDuration::Max();

        for (auto& pingable : Hosts_) {
            try {
                const TDuration requested = pingable->PingNonblocking(wakeUpServiceThread);
                if (requested < shortest) {
                    shortest = requested;
                }
            } catch (const std::exception& e) {
                // Impossible case
                Log_.Debug() << pingable->DbInfo()
                             << " Pinger: error <" << e.what()
                             << "> id=" << pingable->DbInfo().DisplayName;
            }
        }

        return shortest;
    }

    // Waits for the in-flight ping of every host. A std::exception from one host is logged
    // and does not prevent waiting on the remaining hosts.
    void TPinger::WaitAll() {
        for (auto& pingable : Hosts_) {
            try {
                pingable->Wait();
            } catch (const std::exception& e) {
                Log_.Debug() << pingable->DbInfo()
                             << " Pinger: error on waiting <" << e.what()
                             << "> id=" << pingable->DbInfo().DisplayName;
            }
        }
    }
}
