#include "smartqueue_ng.h"

#include <library/cpp/logger/global/global.h>

#include <util/generic/fastqueue.h>
#include <util/generic/utility.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/guard.h>
#include <util/system/mutex.h>
#include <util/system/condvar.h>

#include <atomic>

namespace NUtil {
namespace {
// Signals passed between the task queue, the pool controller and the worker threads.
enum class EThreadSignal {
    // A task was popped from the queue; TSmartTask::Object carries it.
    Data,
    // The queue had nothing to hand out. Also returned by TQueueControl::OnThreadStart()
    // while the pool is active.
    NoData,
    // Returned by TQueueControl::WaitNewThread(): the parked thread may start serving the queue.
    WorkerStart,
    // A pending stop signal was consumed from the queue: the worker has to park itself.
    WorkerStop,
    // The pool is not active: the thread has to leave its processing loop.
    Shutdown
};

// Result of a single queue poll: either a task to run or a control signal for the worker.
struct TSmartTask {
    // The task to process; nullptr for every signal other than EThreadSignal::Data.
    IObjectInQueue* Object;
    // What the worker is expected to do next.
    EThreadSignal Signal;

    // True only when this poll produced a task (Signal == Data), which lets callers write `while (task)`.
    inline operator bool() const {
        return Signal == EThreadSignal::Data;
    }
};

// Unsigned counter type used for iteration numbers, periods and masks.
using TStat = ui32;
// Signed companion of TStat, used for accumulated differences.
using TStatDiff = i32;

//
// The Queue itself. Transmits data from clients to workers.
//
class TQueueData {
public:
    TMutex Mutex;
    int StopSignals = 0;
    int NumWaiters = 0;
    std::atomic<TStat> StatCounter = 1;
    size_t MaxQueueSize = 0;
    TFastQueue<IObjectInQueue*> Data;
    TCondVar Event;

public:
    TSmartTask GetNextTask() {
        with_lock (Mutex)
            return GetNextTaskUnsafe();
    }

    TSmartTask WaitNextTask() {
        with_lock (Mutex)
            return WaitNextTaskUnsafe();
    }

    bool Push(IObjectInQueue* task) {
        with_lock (Mutex) {
            if (Y_UNLIKELY(Data.Size() >= MaxQueueSize)) {
                ++StatCounter;
                return false;
            }
            PushUnsafe(task);
            return true;
        }
    }

public:
    Y_FORCE_INLINE TSmartTask GetNextTaskUnsafe() {
        // this->Mutex must be locked
        TSmartTask task;
        if (Y_UNLIKELY(StopSignals)) {
            --StopSignals;
            task = {nullptr, EThreadSignal::WorkerStop};
        } else if (Data.Empty()) {
            task = {nullptr, EThreadSignal::NoData};
        } else {
            task = {Data.Pop(), EThreadSignal::Data};
        }
        return task;
    }

    Y_FORCE_INLINE void PushUnsafe(IObjectInQueue* task) {
        // this->Mutex must be locked
        if (NumWaiters > 0)
            Event.Signal();
        ++StatCounter;
        Data.Push(task);
    }

    Y_FORCE_INLINE TSmartTask WaitNextTaskUnsafe() {
        // this->Mutex must be locked
        TSmartTask result = GetNextTaskUnsafe();
        if (result.Signal == EThreadSignal::NoData) {
            ++NumWaiters;
            do {
                Event.Wait(Mutex);
                result = GetNextTaskUnsafe();
            } while (result.Signal == EThreadSignal::NoData);
            --NumWaiters;
        }
        return result;
    }
};

//
// Thread pool management context
//
class TQueueControl {
public:
    std::atomic<bool> IsActive; // true between Start() and Stop(); also read without the mutex
    TMutex Mutex;               // guards NumActive / NumAdviced and is paired with Event
    int NumActive;              // registered threads that are not parked inside WaitNewThread()
    int NumAdviced;             // number of threads the adviser currently wants to be working
    TCondVar Event;             // wakes threads parked inside WaitNewThread()
    TQueueData& Queue;          // the queue whose workers are being managed

public:
    // `explicit`: a TQueueData must never silently convert into a controller.
    explicit TQueueControl(TQueueData& queue)
        : IsActive(false)
        , NumActive(0)
        , NumAdviced(0)
        , Queue(queue)
    {
    }

    // Registers a pool thread. Returns NoData when the pool is active, Shutdown otherwise.
    EThreadSignal OnThreadStart() {
        with_lock (Mutex) {
            ++NumActive;

            Y_ASSERT(IsActive);
            return IsActive ? EThreadSignal::NoData : EThreadSignal::Shutdown;
        }
    }

    // Unregisters a pool thread (decrements NumActive).
    void OnThreadStop() {
        with_lock (Mutex)
            --NumActive;
    }

    // Parks the calling thread until more workers are advised than are active, or until
    // the pool is deactivated. The thread does not count as active while it waits.
    // Returns WorkerStart when the thread should start working, Shutdown when the pool was stopped.
    EThreadSignal WaitNewThread() {
        TGuard<TMutex> g(Mutex);
        --NumActive;
        while (NumActive >= NumAdviced && IsActive) {
            Event.Wait(Mutex);
        }
        ++NumActive;

        if (!IsActive)
            return EThreadSignal::Shutdown;

        return EThreadSignal::WorkerStart;
    }

public:
    // Moves the advised number of working threads to `threadsNum`, waking parked threads
    // or posting stop signals to the queue as needed.
    void AdviceThreadsUnsafe(int threadsNum) {
        // this->Mutex must be locked
        if (threadsNum != NumAdviced) {
            if (threadsNum > NumAdviced) {
                IncThreads(threadsNum - NumAdviced);
            } else {
                DecThreads(NumAdviced - threadsNum);
            }
        }
    }

private:
    // Raises the advice by `n`. Stop signals that no worker has consumed yet are cancelled
    // first; parked threads are woken only for the remainder.
    void IncThreads(size_t n) {
        // this->Mutex must be locked
        NumAdviced += n;

        with_lock (Queue.Mutex) {
            size_t k = Queue.StopSignals;
            if (k > n)
                k = n;
            Queue.StopSignals -= k;
            n -= k;
        }

        for (; n; --n) {
            Event.Signal(); //this->Event signal unblocks one or more WaitNewThread() calls (zero if none are waiting)
        }
    }

    // Lowers the advice by `n` by posting `n` stop signals to the queue and waking waiting workers
    // so that they can pick the signals up.
    void DecThreads(size_t n) {
        // this->Mutex must be locked
        NumAdviced -= n;

        with_lock (Queue.Mutex) {
            Queue.StopSignals += n;
            for (; n; --n) {
                Queue.Event.Signal(); //Queue.Event signal unblocks one or more WaitNextTask() calls (zero if none are waiting)
            }
        }
    }
};

//
// Helper class
//
// Unsigned fixed-point number with PrecisionLog2 fractional bits, stored in TValue.
// The represented value is Value / 2^PrecisionLog2.
template <typename TValue, size_t PrecisionLog2>
struct TFixedPoint {
    TValue Value;
    static constexpr size_t Precision = PrecisionLog2; // number of fractional bits

    // The smallest positive representable value: 1 / 2^Precision.
    static constexpr float Epsilon = 1.0f / (1 << Precision);

    // Zero.
    constexpr TFixedPoint()
        : Value(TValue())
    {
    }

    // Converts from float, dropping everything below Epsilon. Deliberately implicit:
    // the type is meant to be initialised and assigned from plain float coefficients.
    // The storage type is TValue (not a fixed 32-bit type), so wide instantiations keep their range.
    constexpr TFixedPoint(float vl)
        : Value(static_cast<TValue>(vl * (1 << Precision)))
    {
    }

    // Converts back to float.
    inline constexpr operator float() const {
        return Epsilon * Value;
    }

    // Integer scaling: floor(a * value), computed without floating point.
    inline size_t constexpr operator*(size_t a) const {
        return (a * Value) >> Precision;
    }
};

//
// An internal representation of the limits (calculated from Options and threadCount on start)
//
struct TQueueLimits {
    using TCoeff = TFixedPoint<TStat, 5>;

    size_t MinThreads = 4;
    size_t MaxThreads = 4;
    size_t StepThreads = 4;              // how many threads are added/removed per advice change
    TCoeff JobsPerWorker = 1.0f;         // target in-flight jobs per advised thread
    TCoeff BurstJobsPerWorker = 1.0f;    // per-thread load above which the pool grows immediately

    static constexpr float MinJobsPerWorker = TCoeff::Epsilon;

    // A fixed pool never changes its thread count.
    inline bool IsFixed() const {
        return MaxThreads == MinThreads;
    }

    // Derives the effective limits from user options and the actual pool capacity `newMaxThreads`.
    void SetLimits(const TSmartMtpQueueNg::TOptions& options, size_t newMaxThreads) {
        if (options.IsFixed()) {
            MinThreads = MaxThreads = newMaxThreads;
        } else {
            // Lower the minimum step by step while the options would still derive
            // a maximum above the available capacity.
            const size_t step = Max<size_t>(options.StepThreads, 1);
            size_t lowered = Min(options.MinThreads, newMaxThreads);
            for (; lowered > step && options.CalcMaxThreads(lowered) > newMaxThreads; lowered -= step) {
            }

            MinThreads = lowered;
            MaxThreads = Max(lowered, newMaxThreads);
        }

        // The range may have collapsed above, so fixedness is re-evaluated on the result.
        if (IsFixed()) {
            StepThreads = 1;
            JobsPerWorker = 1;
            BurstJobsPerWorker = 1;
        } else {
            StepThreads = ClampVal<size_t>(options.StepThreads, 1, MaxThreads - MinThreads);
            JobsPerWorker = options.JobsPerWorker;
            BurstJobsPerWorker = Max(options.JobsPerWorker, options.BurstJobsPerWorker);
        }

        CHECK_WITH_LOG(IsFixed() || JobsPerWorker >= MinJobsPerWorker);
    }
};

//
// Helper class to adjust TQueueAdviser latency to RPS
//
// Decides how often the adviser runs, expressed as a power-of-two number of queue operations.
// The period is adapted from the wall-clock time measured between consecutive Update() calls.
struct TPeriodTuner {
public:
    // all the Period values in this class are Log2(iterationsCount)

    // tuning boundaries
    static constexpr TStat InitialPeriod = 4;
    static constexpr TStat MinPeriod = 2;       //Advice every 4 requests at ~120 RPS
    static constexpr TStat MaxAdvicePeriod = 8; //Advice every 256 requests at ~4000 RPS
    static constexpr TStat MaxPeriod = 17;

    // desired advice frequency
    static constexpr ui32 AdviceIntervalMs = 30; // we would like to get advice every 30ms
    static constexpr TStat TunePeriod = 4;       // and update AdvicePeriod every 16 advices
    static constexpr TDuration TunerUpdateInterval = TDuration::MilliSeconds(AdviceIntervalMs * (1 << TunePeriod));

public:
    std::atomic<TStat> Mask = (1 << InitialPeriod) - 1; // TPeriodTuner result
    TStat AdvicePeriod = InitialPeriod;                 // Period capped at MaxAdvicePeriod; Mask is derived from it
    TStat Period = InitialPeriod;                       // the tuned period, may exceed MaxAdvicePeriod
    TInstant LastUpdateTime = TInstant::Zero();         // time of the previous Update(); zero until the first call

    // Hand-written because the std::atomic member is not copy-assignable.
    TPeriodTuner& operator=(const TPeriodTuner& other) {
        if (this != &other) {
            Mask.store(other.Mask.load());
            AdvicePeriod = other.AdvicePeriod;
            Period = other.Period;
            LastUpdateTime = other.LastUpdateTime;
        }
        return *this;
    }

public:
    // True when iterNo falls on an advice boundary, i.e. its low AdvicePeriod bits are all zero.
    Y_FORCE_INLINE bool AllowAdvice(ui32 iterNo) const {
        return (iterNo & Mask) == 0;
    }

    void Update() {
        // This function slowly tunes Period in a way that this->Update() is called roughly once in TunerUpdateInterval
        // when Period == MaxAdvicePeriod is still not enough (>4000 RPS), it just lowers tuning frequency & report frequency

        const TInstant now = Now();
        const TInstant prev = LastUpdateTime;
        LastUpdateTime = now;
        // The very first call only records the timestamp: there is no interval to measure yet.
        if (!prev)
            return;

        // Updates arriving more than 1.25x the target interval apart shorten the period;
        // updates arriving less than 1/3 of the target interval apart lengthen it.
        static constexpr TDuration IncThreshold = TDuration::MicroSeconds(TunerUpdateInterval.MicroSeconds() / 3);
        static constexpr TDuration DecThreshold = TDuration::MicroSeconds(TunerUpdateInterval.MicroSeconds() * 5 / 4);
        if ((now - prev) > DecThreshold) {
            if (Period > MinPeriod)
                SetPeriod(now, Period - 1);
        } else if ((now - prev) < IncThreshold) {
            if (Period < MaxPeriod)
                SetPeriod(now, Period + 1);
        }
        //Cerr <<"    Tuner.Update: period=" << Period << ", diff=" << (now - prev).MicroSeconds() << "us"  << Endl;
    }

    // Stores the new period and refreshes the values derived from it (AdvicePeriod and Mask).
    inline void SetPeriod(const TInstant& now, TStat period) {
        LastUpdateTime = now;
        Period = period;
        AdvicePeriod = period < MaxAdvicePeriod ? period : MaxAdvicePeriod;
        Mask = (1 << AdvicePeriod) - 1;
    }
};

//
// Hi-Low queue logic (a device that calculates and advices the number of threads)
//
// Computes the advised number of working threads from the observed load (queued tasks plus
// busy workers) and applies it through TQueueControl. Also drives the period tuner and
// publishes statistics to the optional handler.
class TQueueAdviser final {
private:
    static constexpr TStat CheckPeriodLog2 = 2;  // 1 << (2 + 2) = do check every 16 calls of Add() ( 2 is init.value from Tuner)
    static constexpr TStat ReportPeriodLog2 = 6; // 1 << 6 = publish stats every 64 executions of Apply()
    static constexpr TStat DecAveragingLog2 = 4; // 1 << 4 = average 16 checks to decrement
    static constexpr TStat IncAveragingLog2 = 4; // 1 << 4 = average 16 checks to increment

private:
    using TStatRecord = TSmartMtpQueueNg::IStatHandler::TStats;

private:
    TPeriodTuner Tuner;
    TStat LastIterNo = 0; // iteration number of the last Apply() that was not skipped
    TStatDiff Sum = 0;    // accumulated (inFly - targetInFly); reset whenever the advice changes
    TQueueLimits Limits;
    TSmartMtpQueueNg::TStatHandlerRef Stats;

public:
    // Cheap pre-filter: true when the low CheckPeriodLog2 bits of the counter are zero.
    Y_FORCE_INLINE static constexpr bool NeedCheck0(ui32 statCounter) {
        return (statCounter & ((1 << CheckPeriodLog2) - 1)) == 0;
    }

    // True when Apply() should run for this counter value (pre-filter plus the tuner's mask).
    Y_FORCE_INLINE bool NeedApply(ui32 statCounter) const {
        return NeedCheck0(statCounter) && Tuner.AllowAdvice(statCounter);
    }

    // True once per 2^(TunePeriod + tunePeriod) iterations.
    Y_FORCE_INLINE static constexpr bool NeedTune(ui32 iterNo, ui32 tunePeriod) {
        return (iterNo & ((1 << TPeriodTuner::TunePeriod << tunePeriod) - 1)) == 0;
    }

    // True once per 2^(ReportPeriodLog2 + extraLatency) iterations.
    Y_FORCE_INLINE static constexpr bool NeedReport(ui32 statCounter, ui32 extraLatency) {
        return (statCounter & ((1 << ReportPeriodLog2 << extraLatency) - 1)) == 0;
    }

    Y_FORCE_INLINE static constexpr size_t Unsign(int stat) {
        // stat-to-value conversion
        return stat < 0 ? 0 : (size_t)stat;
    }

    // Samples the load, recomputes the advised thread count and applies it to `control`.
    // `iterNo` is the queue operation counter observed by the caller; calls that carry
    // a counter not newer than the last processed one are ignored.
    void Apply(TQueueControl& control, size_t iterNo) {
        TGuard<TMutex> g(control.Mutex);
        if (LastIterNo >= iterNo && LastIterNo - iterNo < (Max<TStat>() >> 1) /* 2nd condition is an overflow check */)
            return;
        LastIterNo = iterNo;

        size_t statInQueue, statInWorker, inFly;
        with_lock (control.Queue.Mutex) {
            // get stats from Queue
            const TQueueData& queue = control.Queue;
            statInQueue = queue.Data.Size();
            // busy workers = active threads minus those blocked waiting for a task
            statInWorker = Unsign(control.NumActive - queue.NumWaiters);
            inFly = statInQueue + statInWorker;
        }

        const size_t currentAdvice = control.NumAdviced;
        const size_t targetInFly = Limits.JobsPerWorker * currentAdvice;
        Sum += inFly;
        Sum -= targetInFly;

        size_t advice = Advice(control, inFly, targetInFly); // suggest the optimal number of threads

        if (advice != currentAdvice || advice < Limits.StepThreads) {
            Sum = 0; //stop accumulating if the advice was changed or when the bottom is reached (hence no overflows)
        }

        advice = ClampVal(advice, Limits.MinThreads, Limits.MaxThreads);

        if (advice != currentAdvice) {
            control.AdviceThreadsUnsafe(advice); // change number of threads
        }

        if (NeedTune(iterNo, Tuner.Period)) {
            Tuner.Update();
        }

        if (Stats && NeedReport(iterNo, Tuner.Period)) {
            TStatRecord st;
            st.InQueue = statInQueue;
            st.InFly = inFly;
            st.TargetInFly = targetInFly;
            st.NumThreads = control.NumActive;
            st.NumAdviced = control.NumAdviced;
            st.NumBusyThreads = statInWorker;
            st.AdvicePeriod = Tuner.AdvicePeriod;
            st.nOperations = iterNo;
            // The handler is invoked without holding control.Mutex.
            // NOTE(review): Report() reads Stats after the guard is released, while SetStat() is
            // called under the same mutex — confirm the handler is never replaced concurrently.
            g.Release();
            Report(st);
        }
    }

    // Installs (or replaces) the statistics handler.
    void SetStat(TSmartMtpQueueNg::TStatHandlerRef& stats) {
        Stats = stats;
    }

    // Aborts via CHECK_WITH_LOG when the limits are inconsistent.
    void Validate(const TQueueLimits& limits) {
        CHECK_WITH_LOG(limits.MaxThreads >= limits.MinThreads);
        CHECK_WITH_LOG(limits.IsFixed() || limits.StepThreads > 0);
        CHECK_WITH_LOG(limits.IsFixed() || limits.JobsPerWorker >= TQueueLimits::TCoeff::Epsilon);
    }

    // Installs the limits and resets the period tuner to its initial state.
    void Init(const TQueueLimits& limits) {
        Validate(limits);
        Limits = limits;
        Tuner = TPeriodTuner();
    }

    // Replaces the limits of a running adviser. MaxThreads and fixedness must not change.
    void UpdateLimits(const TQueueLimits& limits) {
        CHECK_WITH_LOG(Limits.MaxThreads == limits.MaxThreads);
        CHECK_WITH_LOG(Limits.IsFixed() == limits.IsFixed());
        Validate(limits);
        Limits = limits;
    }

private:
    // Returns the unclamped suggested thread count: grows immediately on a burst, otherwise
    // moves by at most one StepThreads based on the averaged accumulated difference (Sum).
    inline size_t Advice(TQueueControl& control, size_t inFly, size_t targetInFly) const {
        const size_t currentAdvice = control.NumAdviced;
        const size_t burstThreshold = Limits.BurstJobsPerWorker * (currentAdvice + Limits.StepThreads);
        size_t advice = currentAdvice;
        if (inFly > burstThreshold) {
            // Burst. Add as many threads as needed
            while (inFly > Limits.BurstJobsPerWorker * (advice + Limits.StepThreads))
                advice += Limits.StepThreads;
        } else {
            // Gradual tuning

            // equivalent to:  const TStatDiff stat = (Sum < 0) ? Sum / (1 << DecAveragingLog2) : Sum / (1 << IncAveragingLog2);
            const TStat absStat = (Sum < 0) ? TStat(-Sum) >> DecAveragingLog2 : TStat(Sum) >> IncAveragingLog2;
            const TStatDiff stat = (Sum < 0) ? -TStatDiff(absStat) : TStatDiff(absStat);

            const size_t statInFly = Unsign(stat + (TStatDiff)targetInFly); // approximation
            if (Sum >= 0) {
                //Add some threads
                const size_t incThreshold = Limits.JobsPerWorker * (currentAdvice + Limits.StepThreads);
                if (statInFly > incThreshold)
                    advice += Limits.StepThreads;
            } else {
                // Suspend some threads
                const size_t decThreshold = Limits.JobsPerWorker * (currentAdvice - Min(currentAdvice, Limits.StepThreads));
                if (statInFly < decThreshold)
                    advice -= Min(Limits.StepThreads, advice);
            }
        }
        return advice;
    }

    // Passes the record to the handler, if any. Exceptions from the handler are swallowed
    // (and flagged in debug builds).
    void Report(TStatRecord& st) {
        if (Stats) {
            try {
                Stats->OnStats(st);
            } catch (...) {
                Y_VERIFY_DEBUG(false /* IStatHandler should not throw */);
            }
        }
    }
};

} // namespace anonymous


//
// TSmartThread & TImpl
//
// Implementation of TSmartMtpQueueNg: owns the task queue, the thread controller, the adviser
// and the underlying fixed-size thread pool whose threads run TSmartThread loops.
class TSmartMtpQueueNg::TImpl final : public TSmartMtpQueueNg::IUpdatable {
private:
    // The long-running job executed by every thread of the underlying pool. It alternates
    // between being parked in TQueueControl::WaitNewThread() and serving tasks from the queue.
    class TSmartThread final: public IObjectInQueue {
    public:
        TSmartThread(TSmartMtpQueueNg::TImpl& parent)
            : Parent(parent)
            , Sleeping(true)
        {
        }

    private:
        // Runs one user task; exceptions are logged to Cerr and never escape.
        Y_FORCE_INLINE static void ProcessTask(TThreadPool::TTsr& tsr, IObjectInQueue* task) noexcept {
            try {
                task->Process(tsr);
            } catch (...) {
                Cerr << "[mtp queue] " << CurrentExceptionMessage() << Endl;
            }
        }

        // IThreadFactory::IThreadAble
        // Thread main loop. Every exit path goes through control.OnThreadStop(), so the
        // registration made by OnThreadStart() is always undone and NumActive returns to zero
        // once all threads have finished (Start() checks that before a restart).
        void Process(void* /*tsp*/) override {
            IThreadPool::TTsr tsr(Parent.GetParent());

            TQueueData& queue = Parent.GetQueue();
            TQueueControl& control = Parent.GetControl();
            EThreadSignal sig = control.OnThreadStart();
            while (sig != EThreadSignal::Shutdown) {
                // Parked state: wait until the controller wants one more worker or the pool stops.
                while (Sleeping) {
                    sig = control.WaitNewThread();
                    if (Y_LIKELY(sig == EThreadSignal::WorkerStart)) {
                        Sleeping = false;
                        break;
                    }
                    if (Y_UNLIKELY(sig == EThreadSignal::Shutdown)) {
                        Sleeping = false;
                        break;
                    }
                }
                if (Y_UNLIKELY(sig == EThreadSignal::Shutdown)) {
                    // Leave through the common exit below: WaitNewThread() has counted this
                    // thread as active again, so it still has to be unregistered.
                    break;
                }

                // Working state: block for the first task, then drain without blocking.
                TSmartTask task = queue.WaitNextTask();
                while (task) {
                    ProcessTask(tsr, task.Object);
                    task = queue.GetNextTask();
                }
                sig = task.Signal;

                if (sig == EThreadSignal::WorkerStop) {
                    Sleeping = true;
                }
            }

            Sleeping = false;
            control.OnThreadStop();
        }

    private:
        TSmartMtpQueueNg::TImpl& Parent;
        bool Sleeping; // true while the thread is (or is about to be) parked in WaitNewThread()
    };

public:
    // `parent` supplies the thread factory and is passed to tasks as thread-specific context;
    // `name` is used as the thread name of the underlying pool.
    TImpl(TSmartMtpQueueNg* parent, const TString& name)
        : Queue(MakeHolder<TQueueData>())
        , Control(*Queue)
        , SlaveFixed(TThreadPool::TParams().SetFactory(parent->Pool()).SetThreadName(name))
        , Parent(parent)
    {
    }

    ~TImpl() {
        Stop();
    }

    // Queue capacity: the explicit limit `vl`, or 16 slots per thread when `vl` is zero.
    static constexpr size_t QueueSizeLimit(size_t vl, size_t maxThreads) {
        return vl ? vl : maxThreads * 16;
    }

    // Thread capacity: `staticLimit`, additionally capped by options.MaxThreadsLimit when defined.
    static constexpr size_t MaxThreadLimit(size_t staticLimit, const TOptions& options) {
        size_t maxThreads = staticLimit;
        if (options.MaxThreadsLimit.Defined())
            maxThreads = Min(maxThreads, *options.MaxThreadsLimit);
        return maxThreads;
    }

    // Activates the queue and spawns Limits.MaxThreads pool threads, of which Limits.MinThreads
    // are advised to work initially. Does nothing when the queue is already active.
    void Start(const TOptions& options, int maxThreadCount, size_t queueSizeLimit) {
        CHECK_WITH_LOG(maxThreadCount > 0);

        with_lock (Control.Mutex) {
            if (Control.IsActive)
                return;

            with_lock (Queue->Mutex) {
                // check & ensure consistency
                CHECK_WITH_LOG(Control.NumActive == 0);
                CHECK_WITH_LOG(Queue->Data.Empty());
                Queue->NumWaiters = 0;
                Queue->StopSignals = 0;
            }

            Limits.SetLimits(options, MaxThreadLimit(maxThreadCount, options));

            Adviser.Init(Limits);
            Queue->MaxQueueSize = QueueSizeLimit(queueSizeLimit, Limits.MaxThreads);
            Control.IsActive = true;
            Control.AdviceThreadsUnsafe(Limits.MinThreads); // was:MaxThreads, but it's better to start cold
        }

        ThreadPoolSize = Limits.MaxThreads;
        SlaveFixed.Start(ThreadPoolSize);
        for (size_t i = 0; i < ThreadPoolSize; ++i) {
            SlaveFixed.SafeAddAndOwn(THolder(new TSmartThread(*this)));
        }
        //Cerr << "Start TSmartMtpQueueNg, MinThreadCount=" << Limits.MinThreads << ", MaxThreadCount=" << Limits.MaxThreads << Endl;
    }

    // Recomputes the limits from new options, keeping the current pool size as the capacity.
    bool ResetOptions(const TOptions& options) override {
        with_lock(Control.Mutex) {
            Limits.SetLimits(options, MaxThreadLimit(ThreadPoolSize, options));
            Adviser.UpdateLimits(Limits);
        }
        return true;
    }

    // Deactivates the queue, wakes both parked and waiting threads so that they exit,
    // then stops the underlying pool. Does nothing when the queue is not active.
    void Stop() {
        with_lock (Control.Mutex) {
            if (!Control.IsActive)
                return;

            Control.IsActive = false;
            Control.NumAdviced = 0;
            Control.Event.BroadCast();

            with_lock (Queue->Mutex) {
                // one stop signal per pool thread, so every worker leaves the queue
                Queue->StopSignals = Limits.MaxThreads;
                Queue->Event.BroadCast();
            }
        }

        SlaveFixed.Stop();
    }

    // Enqueues a task. Returns false when the queue is not active or is full.
    // Periodically (by operation counter) re-evaluates the advised thread count.
    bool Add(IObjectInQueue* task) {
        VERIFY_WITH_LOG(task, "New task should not be null");

        if (!Control.IsActive) // deliberately unsafe
            return false;

        const bool added = Queue->Push(task);

        TStat iterNo = Queue->StatCounter; // deliberately unsafe

        if (Adviser.NeedApply(iterNo)) {
            Adviser.Apply(Control, iterNo);
        }

        return added;
    }

    // Number of registered threads that are not parked.
    size_t GetThreadCount() {
        with_lock (Control.Mutex)
            return Control.NumActive;
    }

    // Number of tasks waiting in the queue.
    size_t GetQueueSize() const {
        with_lock (Queue->Mutex)
            return Queue->Data.Size();
    }

    void SetStatHandler(TStatHandlerRef& st) {
        with_lock (Control.Mutex)
            Adviser.SetStat(st);
    }

    // Interface for TSmartThread
    TSmartMtpQueueNg* GetParent() {
        return Parent;
    }

    TQueueData& GetQueue() {
        return *Queue;
    }

    TQueueControl& GetControl() {
        return Control;
    }

private:
    THolder<TQueueData> Queue;
    TQueueLimits Limits;
    TQueueControl Control;
    TQueueAdviser Adviser;
    size_t ThreadPoolSize = 0;

    TThreadPool SlaveFixed;

    TSmartMtpQueueNg* Parent;
};

// Creates the queue with the given options; the implementation's threads are named
// options.ThreadName. No threads are started here (see Start()).
TSmartMtpQueueNg::TSmartMtpQueueNg(const TSmartMtpQueueNg::TOptions& options)
    : Impl(MakeHolder<TImpl>(this, options.ThreadName))
    , Options(options)
{
}

// Same as the single-argument constructor, but passes `pool` to the TThreadFactoryHolder base.
TSmartMtpQueueNg::TSmartMtpQueueNg(IThreadFactory* pool, const TSmartMtpQueueNg::TOptions& options)
    : TThreadFactoryHolder(pool)
    , Impl(MakeHolder<TImpl>(this, options.ThreadName))
    , Options(options)
{
}

// Nothing to do explicitly: destroying Impl stops the queue (TImpl::~TImpl calls Stop()).
TSmartMtpQueueNg::~TSmartMtpQueueNg() {
}

// Installs the statistics handler that receives periodic reports from the adviser.
void TSmartMtpQueueNg::SetStatHandler(TSmartMtpQueueNg::TStatHandlerRef& stats) {
    CHECK_WITH_LOG(Impl);
    Impl->SetStatHandler(stats);
}

// Starts the queue using the options given at construction.
// maxThreadCount must be positive; a zero queueSizeLimit selects the default of 16 slots per thread.
void TSmartMtpQueueNg::Start(size_t maxThreadCount, size_t queueSizeLimit) {
    CHECK_WITH_LOG(Impl);

    Impl->Start(Options, maxThreadCount, queueSizeLimit);
}

// Drops the updatable handle first, then stops the implementation.
void TSmartMtpQueueNg::Stop() noexcept {
    CHECK_WITH_LOG(Impl);
    Updatable.Reset();
    Impl->Stop();
}

// Returns a handle to the IUpdatable interface of the implementation (used to reset options
// at runtime). The handle over Impl is created on the first call and reused afterwards.
THandle<TSmartMtpQueueNg::IUpdatable> TSmartMtpQueueNg::GetUpdatable() {
    if (!Updatable)
        Updatable = THandle<IUpdatable>(Impl.Get());

    return Updatable.AsHandle();
}

// Number of tasks currently waiting in the queue.
size_t TSmartMtpQueueNg::Size() const noexcept {
    CHECK_WITH_LOG(Impl);
    return Impl->GetQueueSize();
}

// Enqueues a task; `obj` must not be null. Returns false when the queue is not active or is full.
bool TSmartMtpQueueNg::Add(IObjectInQueue* obj) {
    CHECK_WITH_LOG(Impl);
    return Impl->Add(obj);
}

// Number of pool threads that are currently registered and not parked.
size_t TSmartMtpQueueNg::ThreadCount() {
    CHECK_WITH_LOG(Impl);
    return Impl->GetThreadCount();
}

} // namespace NUtil
