#include <yandex/maps/wiki/threadutils/executor.h>

#include <maps/libs/common/include/exception.h>

#include <atomic>
#include <condition_variable>
#include <exception>
#include <list>
#include <memory>
#include <mutex>

namespace maps::wiki {

/**
 * Private implementation of Executor.
 *
 * Tasks are collected with addTask() and later run by executeAll() (pool plus
 * the calling thread) or executeAllInThreads() (pool only). Every stored task
 * is a copyable functor that calls oneshotWrapper(), so several copies of it
 * may exist (one in tasks_, others handed to the pool) while the user functor
 * runs at most once. Once any task has failed, tasks that have not started yet
 * are skipped, and the first failure is rethrown by the execute* methods.
 */
class Executor::Impl
{
public:
    /// Stores a task; nothing is started until an execute* method is called.
    void addTask(ThreadPool::Functor task);

    /// Runs all tasks on the pool and on the calling thread, waits for them
    /// and rethrows the first failure, if any.
    void executeAll(ThreadPool& pool);
    /// Runs all tasks on the pool only, waits for them and rethrows the first
    /// failure, if any.
    void executeAllInThreads(ThreadPool& pool);

private:
    /// State shared by all copies of one task's wrapper functor.
    struct OneshotWrapperData
    {
        explicit OneshotWrapperData(ThreadPool::Functor functor)
            : functorPtr(std::make_unique<ThreadPool::Functor>(std::move(functor)))
        {
            // Before C++20 a default-constructed atomic_flag has an unspecified
            // state; clear it so that exactly the first test_and_set() wins.
            executedFlag.clear();
        }

        /// The user functor. It is reset right after it has run, so whatever
        /// it captured is released even while copies of the wrapper survive.
        std::unique_ptr<ThreadPool::Functor> functorPtr;
        /// Set by the first invocation of the wrapper; later ones are no-ops.
        std::atomic_flag executedFlag;
    };
    /// Runs data's functor if this is the first invocation for that task,
    /// records the first failure, and counts the task as completed.
    void oneshotWrapper(const std::shared_ptr<OneshotWrapperData>& data);

    /// Blocks until the completion count reaches tasks_.size().
    void awaitCompletion();

    bool hasFailed() const;
    /// Rethrows the recorded failure, if any.
    void checkFailure() const;

    /// Increments completedCount_ and wakes awaitCompletion().
    void onTaskCompleted();

private:
    /// Wrapped tasks in insertion order.
    std::list<ThreadPool::Functor> tasks_;
    /// Number of tasks that have completed (never reset); guarded by waitMutex_.
    size_t completedCount_{0};
    std::mutex waitMutex_;
    std::condition_variable completedCond_;
    /// Becomes true on the first recorded failure; never reset in this file.
    std::atomic<bool> hasFailed_{false};
    /// Written only by whoever first flips hasFailed_ to true; read by
    /// checkFailure() after awaitCompletion().
    std::exception_ptr failureReason_;
};

Executor::Executor()
    : impl_(new Impl)
{ }

// Defined out of line, after Executor::Impl is a complete type, presumably so
// that impl_ (an owning pointer declared in the header) can destroy it.
Executor::~Executor()
    = default;

// Stores a task for a later executeAll() / executeAllInThreads() call.
void Executor::addTask(ThreadPool::Functor task)
{
    impl_->addTask(std::move(task));
}

// Runs all stored tasks on the pool and on the calling thread; see Impl.
void Executor::executeAll(ThreadPool& pool)
{
    impl_->executeAll(pool);
}

// Runs all stored tasks on the pool threads only; see Impl.
void Executor::executeAllInThreads(ThreadPool& pool)
{
    impl_->executeAllInThreads(pool);
}

/**
 * Wraps `task` so that it runs at most once no matter how many copies of the
 * wrapper are invoked, and appends the wrapper to tasks_. Nothing runs here.
 */
void Executor::Impl::addTask(ThreadPool::Functor task)
{
    // All copies of the wrapper share this state (run-once flag + functor).
    auto shared = std::make_shared<OneshotWrapperData>(std::move(task));
    tasks_.emplace_back([this, shared] { oneshotWrapper(shared); });
}

/**
 * Runs all added tasks using both the pool and the calling thread, blocks
 * until every task has completed, then rethrows the first task failure, if
 * any.
 *
 * Every task except the first is offered to the pool. Afterwards the calling
 * thread invokes every task itself. Each task is wrapped by oneshotWrapper(),
 * so whichever thread reaches a task first runs it and the other invocation is
 * a no-op. The pool therefore only speeds things up: a task it does not take
 * is simply run locally.
 */
void Executor::Impl::executeAll(ThreadPool& pool)
{
    bool first = true;
    for (ThreadPool::Functor& task : tasks_) {
        if (first) {
            // Reserve task for current thread
            first = false;
            continue;
        }
        // Offloading is best-effort: a refused task (push() result ignored) is
        // covered by the local pass below. A throwing push() must not
        // propagate either, because tasks queued earlier may already be
        // running against *this and against whatever the tasks captured.
        // Stop offloading and let the local pass run the rest.
        try {
            pool.push(task);
        } catch (...) {
            break;
        }
    }

    for (ThreadPool::Functor& task : tasks_) {
        task();
    }

    awaitCompletion();
    checkFailure();
}

/**
 * Runs all added tasks on the pool threads, blocks until every task has
 * completed, then rethrows the first failure, if any.
 *
 * If the pool refuses a task (push() returns false) or push() throws, tasks
 * already handed to the pool may still be running and referencing *this and
 * their captured state, so this method must not unwind before they finish.
 * Instead:
 *  - the push error is recorded as the executor failure, unless a task has
 *    already failed;
 *  - the remaining tasks are drained on the calling thread. oneshotWrapper()
 *    skips the user functor once a failure is recorded and only counts the
 *    task as completed.
 * After awaitCompletion(), checkFailure() rethrows the recorded error
 * ("Thread pool unavailable", or an earlier task failure).
 */
void Executor::Impl::executeAllInThreads(ThreadPool& pool)
{
    bool pushFailed = false;
    for (ThreadPool::Functor& task : tasks_) {
        if (!pushFailed) {
            try {
                REQUIRE(pool.push(task), "Thread pool unavailable");
                continue;
            } catch (...) {
                pushFailed = true;
                if (!hasFailed_.exchange(true)) {
                    failureReason_ = std::current_exception();
                }
            }
        }
        // Pool unavailable: complete the task here. Because a failure is now
        // recorded, this only marks the task as completed.
        task();
    }

    awaitCompletion();
    checkFailure();
}

void Executor::Impl::oneshotWrapper(const std::shared_ptr<OneshotWrapperData>& data)
{
    if (data->executedFlag.test_and_set()) {
        return;
    }

    if (!hasFailed()) {
        try {
            (*data->functorPtr)();
        } catch (...) {
            if (!hasFailed_.exchange(true)) {
                failureReason_ = std::current_exception();
            }
        }
    }
    data->functorPtr.reset();

    onTaskCompleted();
}

/**
 * Blocks until completedCount_ reaches the number of stored tasks.
 * completedCount_ is cumulative, so this also covers tasks from earlier
 * execute* calls.
 */
void Executor::Impl::awaitCompletion()
{
    const size_t target = tasks_.size();
    std::unique_lock<std::mutex> guard(waitMutex_);
    completedCond_.wait(guard, [&] { return completedCount_ >= target; });
    ASSERT(completedCount_ == target);
}

// True once any task failure has been recorded (sequentially consistent load).
bool Executor::Impl::hasFailed() const
{
    return hasFailed_;
}

/**
 * Rethrows the recorded failure, if there is one. Nothing in this file resets
 * the failure state, so once set it is rethrown by every later call.
 */
void Executor::Impl::checkFailure() const
{
    if (!hasFailed()) {
        return;
    }
    std::rethrow_exception(failureReason_);
}

// Counts one finished task and wakes every thread blocked in awaitCompletion().
void Executor::Impl::onTaskCompleted()
{
    const std::lock_guard<std::mutex> guard(waitMutex_);
    completedCount_ += 1;
    completedCond_.notify_all();
}

} // namespace maps::wiki
