#include "handle.h"

#include "driver.h"

#include <passport/infra/libs/cpp/dbpool/destination.h>
#include <passport/infra/libs/cpp/dbpool/exception.h>
#include <passport/infra/libs/cpp/dbpool/query.h>
#include <passport/infra/libs/cpp/dbpool/result.h>

#include <passport/infra/libs/cpp/unistat/time_stat.h>
#include <passport/infra/libs/cpp/utils/thread_local_id.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <util/datetime/base.h>
#include <util/generic/string.h>
#include <util/stream/format.h>
#include <util/system/thread.h>

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>

#include <sys/syscall.h>

namespace NPassport::NDbPool {
    // State shared by a THandle (client side) and its worker thread (THandle::Proc).
    //
    // Both sides hold it through a std::shared_ptr, so a worker thread detached by
    // TProcHolder's destructor keeps it alive until the thread finishes.
    // Requests are handed over under Mutex and signalled through QueryCond.
    class TWorkerShared {
    public:
        enum EStatus {
            InProgress,
            Success,
            Failed,
        };

        TWorkerShared(TDestinationPtr dsn,
                      const TDbHost& dbHost,
                      const TDbInfo& dbInfo,
                      TDbPoolLog log,
                      const TDuration connectTimeout,
                      const TDuration queryTimeout,
                      bool fetchStatusOnPing,
                      THandleUnistatCtx&& unistat,
                      TString&& threadName)
            : Dsn(std::move(dsn))
            , DbHost(dbHost)
            , DbInfo(dbInfo)
            , Randid(IntToString<10>(rand() % 1000000l))
            , Log(log)
            , StartTime(TInstant::Now())
            , ConnectTimeout(connectTimeout)
            , QueryTimeout(queryTimeout)
            , FetchStatusOnPing(fetchStatusOnPing)
            , Unistat(std::move(unistat))
            , CountHolder(Unistat.Counters->TotalHandles)
            , PoolCountHolder(Unistat.PoolCounters->TotalHandles)
            , ThreadName(std::move(threadName))
            , ConnectedPromise(NThreading::NewPromise<THandleInitError>())
        {
        }

        // Target database: connection parameters and a printable description.
        TDestinationPtr Dsn;
        TDbHost DbHost;
        TDbInfo DbInfo;

        // Request hand-over: the client writes a request under Mutex and notifies
        // QueryCond; the worker wakes up and runs it.
        std::mutex Mutex;
        std::condition_variable QueryCond;
        // Set by TProcHolder's destructor to stop the worker loop.
        std::atomic_bool Exit{false};
        // InProgress while connecting or running a request, then Success/Failed.
        std::atomic<EStatus> Status{InProgress};

        // Run Sql->Ping() instead of Sql->Query() for the current request.
        bool IsPing = false;
        // Current query; reset by the worker after each request.
        TQuery Query;
        // Driver connection, created by the worker thread when it connects.
        std::unique_ptr<IDriver> Sql;
        // Random tag used to correlate log lines of one handle.
        TString Randid;
        TDbPoolLog Log;
        // Client's request id, installed in the worker while it runs the request.
        TString RequestId;
        // Worker thread id, for logs.
        pid_t Tid = 0;
        // Start of the current request; wait deadlines are measured from it.
        TInstant StartTime;
        const TDuration ConnectTimeout;
        // Timeout of the current request.
        TDuration QueryTimeout;
        bool FetchStatusOnPing = false;
        THandleUnistatCtx Unistat;
        // Presumably keep the TotalHandles counters in sync with the lifetime of this object.
        const TSelfCounter CountHolder;
        const TSelfCounter PoolCountHolder;
        const TString ThreadName;
        // Fulfilled once by the worker with the connect outcome.
        NThreading::TPromise<THandleInitError> ConnectedPromise;
        // Recreated by the client for every request.
        NThreading::TPromise<TResponse> ResponsePromise;
        NThreading::TFuture<TResponse> ResponseFuture; // TODO: store it on the client side
    };

    // Creates a handle whose worker thread starts connecting in the background.
    //
    // The result carries the handle plus the future that the worker fulfils with the
    // connect outcome: an empty THandleInitError on success, the error text on failure.
    // Requests are rejected (Status != Success) until the connect has succeeded.
    THandle::TNonblockingInit THandle::CreateNonblocking(
        const THandleSettings& settings,
        THandleUnistatCtx unistat,
        TDbPoolLog logger,
        size_t hostIdx) {
        // struct is required to make clear: what you should do later
        TNonblockingInit res{
            .Handle =
                std::unique_ptr<THandle>(new THandle(
                    settings,
                    // Both are by-value sinks that are not used afterwards: move, don't copy.
                    std::move(unistat),
                    std::move(logger))),
        };

        res.Handle->HostIdx_ = hostIdx;
        res.Randid = res.Handle->Shared_->Randid;
        res.InitionError = res.Handle->Shared_->ConnectedPromise.GetFuture();

        return res;
    }

    // Defined out of line, presumably because some member types (e.g. TWorkerShared)
    // are only complete in this file. TODO: confirm against handle.h. Destroying the
    // members runs the ProcHolder_ destructor, which asks the worker thread to exit and
    // then joins or detaches it.
    THandle::~THandle() = default;

    // Returns the description (destination + host) of the database this handle serves.
    const TDbInfo& THandle::GetDbInfo() const {
        return this->DbInfo_;
    }

    size_t THandle::GetHostIdx() const {
        return HostIdx_;
    }

    bool THandle::Ping() {
        StartNonblockingPing();
        return WaitResult()->ToPingResult().IsOk;
    }

    // Starts a ping on the worker thread without waiting for it; poll with
    // CheckPingFinished() or wait with WaitResult().
    //
    // @param func  optional callback subscribed to the response future. It may run on the
    //              worker thread while that thread holds Shared_->Mutex (the promise is
    //              fulfilled under the lock in THandle::Proc), so it must not start a new
    //              request on this handle.
    // @throws yexception if the handle is not idle (a previous request is still running
    //         or has failed). Shared state is left untouched in that case.
    void THandle::StartNonblockingPing(TSubscribeFunc func) {
        // Check readiness before writing IsPing: if the previous request is still running
        // (e.g. it timed out on our side), the worker reads IsPing without the mutex.
        // Rejected attempts are still counted, exactly as NonBlockingQueryImpl() does.
        const TWorkerShared::EStatus status = Shared_->Status.load(std::memory_order_acquire);
        if (status != TWorkerShared::Success) {
            ++Shared_->Unistat.Counters->AllRequests;
            ++Shared_->Unistat.PoolCounters->AllRequests;
        }
        Y_ENSURE(status == TWorkerShared::Success,
                 "Handle is not ready for new query: " << (int)status);

        Shared_->IsPing = true;

        NThreading::TFuture<TResponse> f = NonBlockingQueryImpl();
        if (func) {
            f.Subscribe(std::move(func));
        }
    }

    // Escapes a value for embedding into query text, using the connected driver.
    // NOTE(review): dereferences Shared_->Sql, which the worker creates while
    // connecting, so this is presumably only valid after a successful init.
    TString THandle::EscapeQueryParam(const TStringBuf s) {
        IDriver& driver = *Shared_->Sql;
        return driver.EscapeQueryParam(s);
    }

    // Blocking query: hands `query` to the worker and waits for its result.
    // customTimeout, if set, replaces the default query timeout for this request.
    std::unique_ptr<TResult> THandle::Query(TQuery&& query, std::optional<TDuration> customTimeout) {
        NonBlockingQuery(std::move(query), customTimeout);
        std::unique_ptr<TResult> result = WaitResult();
        return result;
    }

    // Hands `query` (with the default query options applied) to the worker thread
    // without waiting for the result; collect it with WaitResult().
    //
    // @param query          query to run; moved into the shared state.
    // @param customTimeout  replaces the default query timeout; not supported by mysql.
    // @throws yexception if the handle is not idle (a previous request is still running
    //         or has failed). Shared state and `query` are left untouched in that case.
    void THandle::NonBlockingQuery(TQuery&& query, std::optional<TDuration> customTimeout) {
        if (customTimeout) {
            Y_VERIFY(Shared_->Dsn->Driver != "mysql", "custom timeout is not supported in mysql");
        }

        // Check readiness before writing Shared_->Query. If the previous request is still
        // running (e.g. it timed out on our side), the worker is reading Shared_->Query
        // without the mutex, so overwriting it would be a data race. The acquire load
        // pairs with the worker's release store that publishes Success after it has reset
        // Query. Rejected attempts are still counted, exactly as NonBlockingQueryImpl() does.
        const TWorkerShared::EStatus status = Shared_->Status.load(std::memory_order_acquire);
        if (status != TWorkerShared::Success) {
            ++Shared_->Unistat.Counters->AllRequests;
            ++Shared_->Unistat.PoolCounters->AllRequests;
        }
        Y_ENSURE(status == TWorkerShared::Success,
                 "Handle is not ready for new query: " << (int)status);

        Shared_->Query = std::move(query);
        for (const TQueryOpt& opt : *DefaultQueryOpts_) {
            opt(Shared_->Query);
        }

        NonBlockingQueryImpl(customTimeout);
    }

    // Publishes the request already prepared in the shared state (Query or IsPing) to the
    // worker and returns the future the worker will fulfil.
    //
    // @param customTimeout  timeout for this request; DefaultQueryTimeout_ if unset.
    // @throws yexception if the worker is not idle (Status != Success). The attempt is
    //         counted in AllRequests either way.
    NThreading::TFuture<TResponse> THandle::NonBlockingQueryImpl(std::optional<TDuration> customTimeout) {
        ++Shared_->Unistat.Counters->AllRequests;
        ++Shared_->Unistat.PoolCounters->AllRequests;

        const TWorkerShared::EStatus current = Shared_->Status.load(std::memory_order_relaxed);
        Y_ENSURE(current == TWorkerShared::Success,
                 "Handle is not ready for new query: " << (int)current);

        // The worker reads these fields once it wakes up, so they are written under the
        // mutex. Status flips to InProgress last, right before the wake-up signal.
        std::lock_guard guard(Shared_->Mutex);

        TWorkerShared& shared = *Shared_;
        shared.ResponsePromise = NThreading::NewPromise<TResponse>();
        shared.ResponseFuture = shared.ResponsePromise.GetFuture();
        shared.RequestId = NUtils::GetThreadLocalRequestId();
        shared.StartTime = TInstant::Now();
        shared.QueryTimeout = customTimeout.value_or(DefaultQueryTimeout_);
        shared.Status.store(TWorkerShared::InProgress, std::memory_order_release);
        shared.QueryCond.notify_one();

        return shared.ResponseFuture;
    }

    std::unique_ptr<TResult> THandle::WaitResult(std::optional<TDuration> customTimeout) {
        Badbit_ = true; // we will clean this flag if query will not fail

        const TDuration timeout = customTimeout ? *customTimeout : Shared_->QueryTimeout;
        if (!Shared_->ResponseFuture.Wait(Shared_->StartTime + timeout)) {
            ++Shared_->Unistat.Counters->QueryTimeout;
            ++Shared_->Unistat.PoolCounters->QueryTimeout;
            throw TTimeoutException(Shared_->DbInfo) << ErrMsgTimeout(timeout);
        }

        TResponse response = Shared_->ResponseFuture.ExtractValue();
        if (!response.IsOk) {
            Log_.Debug() << DbInfo_ << " " << response.Error;
            throw TException(Shared_->DbInfo) << response.Error;
        }

        Badbit_ = false;
        return std::move(response.Result);
    }

    bool THandle::WaitToFinishImpl() noexcept {
        try {
            // May be result is ready
            TWorkerShared::EStatus status = Shared_->Status.load(std::memory_order_acquire);
            if (status == TWorkerShared::Success) {
                return true;
            }
            if (status == TWorkerShared::Failed) {
                return false;
            }

            // Should wait
            const TInstant deadline = Shared_->StartTime + Shared_->QueryTimeout;
            if (!Shared_->ResponseFuture.Wait(deadline)) {
                return false;
            }

            return Shared_->ResponseFuture.ExtractValue().IsOk;
        } catch (const std::exception& e) {
            Log_.Debug() << DbInfo_ << " Got exception in waitToFinishImpl: " << e.what();
        }

        return false;
    }

    std::unique_ptr<TResult> THandle::CheckPingFinished() {
        const TInstant deadline = Shared_->StartTime + Shared_->QueryTimeout;

        if (Shared_->ResponseFuture.HasValue() || deadline < TInstant::Now()) {
            return WaitResult();
        }

        return {};
    }

    // Waits for the current request and marks the handle bad unless it is known to
    // have finished successfully. Never throws.
    void THandle::WaitToFinish() noexcept {
        const bool finishedOk = WaitToFinishImpl();
        Badbit_ = !finishedOk;
    }

    // Clears the bad flag if the worker is idle after a successful request.
    // Does not wait; returns whether the handle is usable again.
    bool THandle::TryMakeClear() {
        const bool idle = Shared_->Status.load() == TWorkerShared::Success;
        if (idle) {
            Badbit_ = false;
        }
        return idle;
    }

    bool THandle::Bad() const {
        return Badbit_;
    }

    // Builds the "query timeout (N ms)" message for a timed-out request and logs it
    // together with the worker's randid/tid for correlation.
    TString THandle::ErrMsgTimeout(TDuration timeout) const {
        const auto ms = timeout.MilliSeconds();
        TString text = NUtils::CreateStr("query timeout (", ms, " ms)");

        Log_.Debug() << DbInfo_ << " " << text
                     << " (randid=" << Shared_->Randid << ") tid=" << Shared_->Tid;
        return text;
    }

    // Creates the handle and immediately starts its worker thread (ProcHolder_ runs
    // THandle::Proc), which begins connecting in the background. Because the thread is
    // already running, its first log lines may precede the "cstor entry" line below.
    //
    // NOTE(review): initializer order matters. Shared_ is built from DbInfo_ and Log_,
    // and ProcHolder_ starts the thread with Shared_, so handle.h presumably declares
    // these members in this order with ProcHolder_ last. Confirm when editing the header.
    THandle::THandle(const THandleSettings& settings,
                     THandleUnistatCtx unistat,
                     TDbPoolLog logger)
        : DefaultQueryTimeout_(settings.QueryTimeout)
        , Log_(logger)
        , DbInfo_(*settings.Dsn, settings.DbHost)
        , Shared_(std::make_shared<TWorkerShared>(
              settings.Dsn,
              settings.DbHost,
              DbInfo_,
              Log_,
              settings.ConnectionTimeout,
              settings.QueryTimeout,
              settings.FetchStatusOnPing,
              std::move(unistat),
              "db_w_" + settings.Dsn->Driver))
        , DefaultQueryOpts_(settings.DefaultQueryOpts)
        , ProcHolder_(Shared_)
    {
        Log_.Debug() << DbInfo_ << " Worker cstor entry (randid=" << Shared_->Randid << ")";
    }

    // Worker thread body. Connects to the database, reports the outcome through
    // ConnectedPromise, then executes requests handed over by the owning THandle until
    // Exit is set.
    //
    // Hand-over protocol (client side: NonBlockingQueryImpl):
    //   - The client fills the request fields under Mutex, stores Status = InProgress and
    //     signals QueryCond.
    //   - The worker runs the request without holding Mutex.
    //   - Afterwards, under Mutex, the worker resets the request fields, publishes
    //     Success/Failed and fulfils ResponsePromise.
    //
    // @param shared  state shared with the THandle. Ownership is joint, so a detached
    //                worker keeps it alive until it exits.
    void THandle::Proc(std::shared_ptr<TWorkerShared> shared) {
        TThread::SetCurrentThreadName(shared->ThreadName.c_str());

        const TDestination& dsn(*shared->Dsn);
        TDbPoolLog log = shared->Log;
        shared->Tid = TThread::CurrentThreadNumericId();
        log.Debug() << shared->DbInfo
                    << " Worker thread entry (randid=" << shared->Randid << ") tid=" << shared->Tid;

        TInstant start = TInstant::Now();

        // Initialize: a) obtain database driver first; b) establish connection
        // with the DB. Notify the caller about success.
        //
        // If any of this fails, mark the handle Failed, notify the caller about the
        // failure and leave the thread right away.
        //
        try {
            shared->Sql = IDriver::GetDrv(dsn.Driver);

            if (!shared->Sql->Connect(
                    shared->DbHost.Host,
                    shared->DbHost.Port,
                    dsn.User,
                    dsn.Passwd,
                    dsn.Db,
                    shared->ConnectTimeout,
                    shared->QueryTimeout,
                    dsn.Extended,
                    shared->FetchStatusOnPing))
            {
                throw yexception() << shared->Sql->Error();
            }

            if (!dsn.LocaleCmd.empty() && !shared->Sql->Query(dsn.LocaleCmd, shared->QueryTimeout)) {
                throw yexception() << "couldn't execute command <" << dsn.LocaleCmd
                                   << "> upon connect: '" << shared->Sql->Error() << "'";
            }

        } catch (const std::exception& e) {
            log.Debug() << shared->DbInfo
                        << " Worker thread connect failure: " << e.what()
                        << " (took " << HumanReadable(TInstant::Now() - start)
                        << ", randid=" << shared->Randid << ") tid=" << shared->Tid;
            shared->Status.store(TWorkerShared::Failed);

            shared->ConnectedPromise.SetValue(e.what());

            // NOTE(review): shared->Sql is not reset on this path, so the driver is
            // destroyed by whichever thread drops the last reference to `shared`.
            log.Debug() << shared->DbInfo
                        << " Worker thread couldn't start, exiting (randid=" << shared->Randid
                        << ") tid=" << shared->Tid;
            return;
        }

        // How long did it take to complete most recent operation
        TDuration t = TInstant::Now() - start;

        // Query text of oauthdbshard* databases is never written to logs.
        const bool hideQueryText = dsn.Db.StartsWith("oauthdbshard");

        // Lock before publishing Success. The client hands over a request only after it
        // sees Success and then takes Mutex, so it cannot notify QueryCond before this
        // thread is waiting on it.
        std::unique_lock lock(shared->Mutex);

        shared->Status.store(TWorkerShared::Success);
        shared->ConnectedPromise.SetValue(THandleInitError{});
        log.Debug() << shared->DbInfo << " Worker thread started (took " << HumanReadable(t)
                    << ", randid=" << shared->Randid << ") tid=" << shared->Tid;

        // From now on wait for query requests in a loop; individual
        // failures are not considered fatal; poll the exit flag.
        while (!shared->Exit.load(std::memory_order_acquire)) {
            shared->QueryCond.wait(lock);

            // No pending request: a spurious wake-up or the exit notification.
            if (shared->Status.load(std::memory_order_relaxed) != TWorkerShared::InProgress) {
                continue;
            }

            TString requestId = std::move(shared->RequestId);
            NUtils::TRequestIdGuard g(&requestId);

            TResponse response;

            lock.unlock();

            {
                // Points into shared->Query, which stays untouched until it is reset below.
                const char* logableQuery = hideQueryText
                                               ? ""
                                               : shared->Query.Query().c_str();
                start = TInstant::Now();
                try {
                    // Is it lengthy or dangerous? We're unsure.
                    // So we catch exceptions and don't hold mutex, just in case.
                    // Here Result saves 'ready' time
                    response.Result = shared->IsPing
                                          ? shared->Sql->Ping(shared->QueryTimeout)
                                          : shared->Sql->Query(shared->Query, shared->QueryTimeout);
                    response.IsOk = (bool)response.Result && shared->Sql->ErrNum() == 0;

                    t = TInstant::Now() - start;

                    if (response.IsOk) {
                        if (t > TDuration::MilliSeconds(50)) {
                            log.Debug() << shared->DbInfo
                                        << " Worker thread: Long query: took " << HumanReadable(t)
                                        << ". Query='" << logableQuery
                                        << "' (randid=" << shared->Randid << ") tid=" << shared->Tid;
                        }
                    } else {
                        response.Error = NUtils::CreateStr(
                            shared->IsPing ? "ping" : "query",
                            " failure: <",
                            shared->Sql->Error(),
                            ">[",
                            shared->Sql->ErrNum(),
                            ']');
                        log.Debug() << shared->DbInfo << " Worker thread " << response.Error
                                    << ". Query='" << logableQuery
                                    << "' (took " << HumanReadable(t)
                                    << ", randid=" << shared->Randid << ") tid=" << shared->Tid;
                    }
                } catch (const std::exception& e) {
                    t = TInstant::Now() - start;
                    response.Error = NUtils::CreateStr("query() exception: <", e.what(), '>');
                    log.Debug() << shared->DbInfo << " Worker thread " << response.Error
                                << ". Query='" << logableQuery
                                << "' (took " << HumanReadable(t)
                                << ", randid=" << shared->Randid << ") tid=" << shared->Tid;
                }

                // Count every failed request exactly once, whether the driver reported an
                // error or threw (previously thrown failures were not counted at all).
                if (!response.IsOk) {
                    ++shared->Unistat.Counters->QueryError;
                    ++shared->Unistat.PoolCounters->QueryError;
                }
            }

            lock.lock();

            // Order is important for the nonblocking interface. Reset the request fields
            // and publish Status before fulfilling the future: a client woken by the future
            // may immediately check Status and hand over the next request.
            shared->IsPing = false;
            shared->Query = {};
            shared->Status.store(response.IsOk ? TWorkerShared::Success : TWorkerShared::Failed,
                                 std::memory_order_release);
            shared->ResponsePromise.SetValue(std::move(response));

            if (shared->Unistat.TimeStats) {
                shared->Unistat.TimeStats->Insert(t);
            }
            if (shared->Unistat.PoolTimeStats) {
                shared->Unistat.PoolTimeStats->Insert(t);
            }
        }

        shared->Sql.reset(); // to clear thread local resources
        log.Debug() << shared->DbInfo
                    << " Worker thread cancelled, exiting (randid=" << shared->Randid
                    << ") tid=" << shared->Tid;
    }

    // Starts the worker thread running THandle::Proc. The thread receives its own
    // reference to the shared state, so it can outlive this holder when detached.
    //
    // NOTE(review): `shared` is copied into Shared_ before being moved into the thread.
    // This relies on Shared_ being declared before Thread_ in handle.h; confirm.
    THandle::TProcHolder::TProcHolder(std::shared_ptr<TWorkerShared> shared)
        : Shared_(shared)
        , Thread_(THandle::Proc, std::move(shared))
    {
    }

    // Stops the worker: raises Exit and wakes the thread under the mutex, then joins it.
    // If a request (or the initial connect) is still running, the thread is detached
    // instead of blocking here. It keeps its own reference to the shared state and
    // leaves its loop once the running operation completes.
    THandle::TProcHolder::~TProcHolder() {
        {
            std::lock_guard guard(Shared_->Mutex);
            Shared_->Exit.store(true);
            Shared_->QueryCond.notify_one();
        }

        const bool busy = Shared_->Status.load(std::memory_order_acquire) == TWorkerShared::InProgress;
        if (!busy) {
            Thread_.join();
            return;
        }
        Thread_.detach();
    }
}
