#include "threaded_executor.h"

#include <util/system/info.h>
#include <util/system/thread.h>

#include <util/generic/hash.h>

using namespace NYasm::NCommon;

namespace {
    // Fire-and-forget queue job: runs an offloaded task and logs any failure to stderr.
    // The job owns itself and is deleted when Process() returns.
    class TSimpleJob: public IObjectInQueue {
    public:
        TSimpleJob(TOffloadedTask::TRef&& task)
            : Task(std::move(task))
        {
        }

        void Process(void*) override {
            // Adopt `this` so the job is freed on every exit path.
            const THolder<TSimpleJob> owner(this);
            try {
                Task->operator()();
            } catch (...) {
                // Nobody waits on a result here, so the error can only be reported.
                Cerr << CurrentExceptionMessage() << Endl;
            }
        }

    private:
        THolder<TOffloadedTask> Task;
    };

    // Queue job that runs an offloaded task and publishes the outcome through a promise:
    // SetValue() on success, SetException() with the exception message on failure.
    // The job owns itself and is deleted when Process() returns.
    class TJob: public IObjectInQueue {
    public:
        TJob(TOffloadedTask::TRef&& task)
            : Task(std::move(task))
            , Promise(NThreading::NewPromise<void>())
        {
        }

        void Process(void*) override {
            // Adopt `this`; the promise is completed before the job is destroyed.
            const THolder<TJob> owner(this);
            bool failed = false;
            try {
                (*Task)();
            } catch (...) {
                failed = true;
                Promise.SetException(CurrentExceptionMessage());
            }
            if (!failed) {
                Promise.SetValue();
            }
        }

        // Gives the submitter access to the promise, e.g. to fail it if enqueueing fails.
        TBaseThreadedExecutor::TPromise& GetPromise() {
            return Promise;
        }

        // Future completed when Process() finishes running the task.
        TBaseThreadedExecutor::TFuture GetFuture() {
            return Promise.GetFuture();
        }

    private:
        THolder<TOffloadedTask> Task;
        TBaseThreadedExecutor::TPromise Promise;
    };

    // Adapts a std::function<void()> to the TOffloadedTask interface by storing a copy of it.
    class TFuncExecutor: public TOffloadedTask {
    public:
        TFuncExecutor(const std::function<void()>& func)
            : Callback_(func)
        {
        }

        // Invokes the stored callable.
        void operator()() override {
            Callback_();
        }

    private:
        std::function<void()> Callback_;
    };

    class TNamedThreadFactories: public TNonCopyable {
        Y_DECLARE_SINGLETON_FRIEND();
    public:
        static TNamedThreadFactories& Get() {
            return *SingletonWithPriority<TNamedThreadFactories, 100000>();
        }

        IThreadFactory* GetThreadFactory(const TString& name) {
            IThreadFactory* threadFactory = nullptr;

            TGuard<TMutex> guard(Mutex);
            auto it(ThreadFactories.find(name));
            if (it.IsEnd()) {
                auto* newThreadFactory = new TNamedThreadFactory(name);
                ThreadFactories[name].Reset(newThreadFactory);
                threadFactory = newThreadFactory;
            } else {
                threadFactory = it->second.Get();
            }

            Y_ASSERT(threadFactory != nullptr);
            return threadFactory;
        }

    private:
        // Based on TSystemThreadFactory from util/thread/factory.cpp
        class TNamedThreadFactory: public IThreadFactory {
        public:
            class TNamedThread: public IThread {
            public:
                TNamedThread(const TString& name)
                    : Name_(name)
                {
                }

                ~TNamedThread() override {
                    if (Thr_) {
                        Thr_->Detach();
                    }
                }

                void DoRun(IThreadAble* func) override {
                    TThread::TParams params(ThreadProc, func);
                    if (!Name_.empty()) {
                        params.SetName(Name_);
                    }
                    Thr_.Reset(new TThread(params));

                    Thr_->Start();
                }

                void DoJoin() noexcept override {
                    if (!Thr_) {
                        return;
                    }

                    Thr_->Join();
                    Thr_.Destroy();
                }

            private:
                static void* ThreadProc(void* func) {
                    ((IThreadAble*)(func))->Execute();

                    return nullptr;
                }

            private:
                TString Name_;
                THolder<TThread> Thr_;
            };

            inline TNamedThreadFactory(const TString& name)
                : Name(name)
            {
            }

            IThread* DoCreate() override {
                return new TNamedThread(Name);
            }

        private:
            TString Name;
        };

        THashMap<TString, THolder<TNamedThreadFactory>> ThreadFactories;
        TMutex Mutex;
    };

    // Process-wide mutex that serializes temporary swaps of the global system thread factory.
    TMutex& SystemThreadPoolMutex() {
        TMutex* const mutex = SingletonWithPriority<TMutex, 100000>();
        return *mutex;
    }
}

void TBaseThreadedExecutor::Start(const TString& name, size_t queueSizeLimit) {
    auto* threadFactory(TNamedThreadFactories::Get().GetThreadFactory(name));
    Y_ASSERT(threadFactory != nullptr);

    /* TWorkStealingMtpQueue::Start() uses global systemThreadPool.
        So save previous systemThreadPool and set it to our named thread pool.
        Restore initial systemThreadPool after queue start.
        Defend assigning to global systemThreadPool with mutex. */

    TGuard<TMutex> guard(SystemThreadPoolMutex());

    auto* prevThreadFactory = SystemThreadFactory();
    SetSystemThreadFactory(threadFactory);

    Queue_->Init(queueSizeLimit);
    Queue_->Start(Size_, 0);

    SetSystemThreadFactory(prevThreadFactory);
}

// Creates an executor whose worker threads are named `poolName` and starts it immediately.
// poolSize == 0 means "number of CPUs, but at least 2".
TBaseThreadedExecutor::TBaseThreadedExecutor(const TString& poolName, size_t poolSize, size_t queueSizeLimit)
    : Queue_(new TWorkStealingMtpQueue())
    // Explicit template argument: a mixed-type call like Max(size_t, unsigned long) fails to
    // deduce on platforms where those types differ (e.g. 64-bit Windows).
    , Size_(poolSize ? poolSize : Max<size_t>(NSystemInfo::NumberOfCpus(), 2))
{
    Start(poolName, queueSizeLimit);
}

// Stops and releases the worker queue (if still present) before the executor goes away.
TBaseThreadedExecutor::~TBaseThreadedExecutor() {
    this->Stop();
}

// Enqueues `task` and returns a future that completes when it has run.
// `major` selects the main queue; otherwise the task goes to the minor job queue.
// If the task cannot be enqueued (no queue, or the queue rejects it), the returned
// future already holds the exception "Can not add function to queue".
TBaseThreadedExecutor::TFuture TBaseThreadedExecutor::Add(TOffloadedTask::TRef&& task, bool major) {
    auto pending = MakeHolder<TJob>(std::move(task));
    // Grab the future before handing the job over: once queued, the job may run and
    // delete itself at any moment.
    TFuture result = pending->GetFuture();

    bool accepted = false;
    if (Queue_) {
        if (major) {
            accepted = Queue_->Add(pending.Get());
        } else {
            accepted = Queue_->MinorJobQueue()->Push(pending.Get());
        }
    }

    if (!accepted) {
        pending->GetPromise().SetException("Can not add function to queue");
        return result;
    }

    // The queue now owns the job; TJob::Process() deletes it.
    Y_UNUSED(pending.Release());
    return result;
}

// Convenience overload: wraps a copy of `func` in an offloaded task and enqueues it.
TBaseThreadedExecutor::TFuture TBaseThreadedExecutor::Add(const std::function<void()>& func, bool major) {
    return Add(TOffloadedTask::TRef(MakeHolder<TFuncExecutor>(func)), major);
}

bool TBaseThreadedExecutor::AddAndForget(const std::function<void()>& func, bool major) {
    return AddAndForget(MakeHolder<TFuncExecutor>(func), major);
}

// Enqueues `task` without tracking its result; failures inside the task are only logged.
// `major` selects the main queue; otherwise the minor job queue is used.
// Returns false (and discards the task) if there is no queue or it rejects the job.
bool TBaseThreadedExecutor::AddAndForget(TOffloadedTask::TRef&& task, bool major) {
    // The task is always moved into a job first, even when there is no queue.
    auto pending = MakeHolder<TSimpleJob>(std::move(task));

    bool accepted = false;
    if (Queue_) {
        if (major) {
            accepted = Queue_->Add(pending.Get());
        } else {
            accepted = Queue_->MinorJobQueue()->Push(pending.Get());
        }
    }

    if (!accepted) {
        return false;
    }

    // The queue now owns the job; TSimpleJob::Process() deletes it.
    Y_UNUSED(pending.Release());
    return true;
}

// Stops the worker queue and releases it. Safe to call repeatedly: later calls are no-ops,
// and subsequent Add()/AddAndForget() calls fail because there is no queue.
void TBaseThreadedExecutor::Stop() {
    if (!Queue_) {
        return;
    }
    Queue_->Stop();
    Queue_.Reset();
}
