#include "asio_callback_pool.h"
#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <algorithm>

using namespace quasar;
using namespace quasar::ipc::detail::asio_ipc;

namespace {
    // Sentinel stored in AsioCallbackQueue::worker() when no worker thread owns the queue.
    // A default-constructed std::thread::id never compares equal to the id of a running thread.
    const std::thread::id NO_WORKER{};
} // namespace

/**
 * IAsioCallbackController implementation shared by every AsioCallbackQueue created by the pool.
 *
 * Forwards queue notifications back to the owning pool through a non-owning raw pointer,
 * so it must only be used while that pool object is alive.
 */
class AsioCallbackPool::Controller: public IAsioCallbackController {
public:
    /**
     * @param asioCallbackPool  Owning pool (not owned by the controller).
     * @param dirty             Counter of destroyed queues, shared with the pool.
     */
    Controller(AsioCallbackPool* asioCallbackPool, std::shared_ptr<std::atomic<size_t>> dirty)
        : asioCallbackPool_(asioCallbackPool)
        , dirty_(std::move(dirty))
    {
    }

    // Counts one more destroyed queue; the pool's workers read this counter to decide
    // when to prune expired entries from the pool's list of queues.
    void notifyDestroy() override {
        dirty_->fetch_add(1);
    }

    // Appends the queue to the pool's order queue (under the pool mutex) and wakes one worker.
    void notifyTask(std::weak_ptr<AsioCallbackQueue> asioCallbackQueue) override {
        std::lock_guard lock(asioCallbackPool_->mutex_);
        // The parameter is a by-value sink: move it instead of paying another atomic refcount copy.
        asioCallbackPool_->orderQueue_.push_back(std::move(asioCallbackQueue));
        asioCallbackPool_->wakeUpCv_.notify_one();
    }

    /**
     * Delegates a delayed task to the pool's timer service.
     *
     * @throws std::runtime_error if the pool was created without a timer service.
     */
    void scheduleTask(std::weak_ptr<AsioCallbackQueue> asioCallbackQueue, std::function<void()> callback, std::chrono::milliseconds timeOut, Lifetime::Tracker tracker) override {
        if (!asioCallbackPool_->timerService_) {
            throw std::runtime_error("No TimerService in AsioCallbackPool \"" + asioCallbackPool_->name_ + "\"");
        }
        // All by-value parameters are sinks: forward them without extra copies.
        asioCallbackPool_->timerService_->createDelayedTask(timeOut, std::move(callback), std::move(tracker), std::move(asioCallbackQueue));
    }

private:
    AsioCallbackPool* asioCallbackPool_;
    std::shared_ptr<std::atomic<size_t>> dirty_;
};

/**
 * Creates the pool and immediately starts its worker threads.
 *
 * @param name          Pool name, used in log prefixes and error messages.
 * @param threadCount   Number of worker threads. 0 means "use std::thread::hardware_concurrency()",
 *                      falling back to a single worker when that value is unavailable (reported as 0).
 * @param timerService  Service used by the controller's scheduleTask(); may be null, in which
 *                      case scheduling a delayed task throws std::runtime_error.
 * @throws std::system_error if a worker thread cannot be started. Workers that were already
 *         started are stopped and joined before the exception propagates.
 */
AsioCallbackPool::AsioCallbackPool(std::string name, size_t threadCount, std::shared_ptr<ITimerService> timerService)
    : name_(std::move(name))
    // hardware_concurrency() returns 0 when it cannot be determined; never start a pool with no workers.
    , threadCount_(threadCount > 0 ? threadCount : std::max<size_t>(1, std::thread::hardware_concurrency()))
    , timerService_(std::move(timerService))
    , dirty_(std::make_shared<std::atomic<size_t>>(0))
    , controller_(std::make_shared<Controller>(this, dirty_))
{
    callbackQueues_.reserve(32);
    workers_.reserve(threadCount_);
    try {
        for (size_t i = 0; i < threadCount_; ++i) {
            workers_.emplace_back(&AsioCallbackPool::workerLoop, this, i);
        }
    } catch (...) {
        // The destructor does not run for a partially constructed object, and destroying a
        // joinable std::thread calls std::terminate(). Stop and join what was started, then rethrow.
        {
            std::lock_guard lock(mutex_);
            stopped_ = true;
        }
        wakeUpCv_.notify_all();
        for (auto& worker : workers_) {
            if (worker.joinable()) {
                worker.join();
            }
        }
        throw;
    }
}

// Tears the pool down through destroy(); destroy() does nothing if it already ran.
AsioCallbackPool::~AsioCallbackPool() {
    this->destroy();
}

/**
 * Creates a new callback queue bound to this pool's controller and registers it
 * (as a weak reference) in the pool's queue list.
 *
 * @param name  Name of the new queue.
 * @return the new queue, or nullptr once the pool has been shut down by destroy().
 */
std::shared_ptr<AsioCallbackQueue> AsioCallbackPool::createAsioCallbackQueue(std::string name) {
    std::lock_guard guard(mutex_);
    std::shared_ptr<AsioCallbackQueue> queue;
    if (!shutdown_) {
        queue = std::make_shared<AsioCallbackQueue>(std::move(name), controller_);
        callbackQueues_.push_back(queue);
    }
    return queue;
}

/**
 * Returns the number of entries currently waiting in the order queue.
 * Entries are weak references, so the count may include queues that have already expired.
 */
size_t AsioCallbackPool::size() {
    const std::lock_guard guard{mutex_};
    const size_t pending = orderQueue_.size();
    return pending;
}

void AsioCallbackPool::destroy() {
    decltype(callbackQueues_) tmpCallbackQueues;
    {
        std::lock_guard lock(mutex_);
        if (shutdown_) {
            return;
        }
        std::swap(tmpCallbackQueues, callbackQueues_);
        shutdown_ = true;
    }

    for (auto& cq : tmpCallbackQueues) {
        if (auto callbackQueue = cq.lock()) {
            callbackQueue->shutdown();
        }
    }

    for (auto& cq : tmpCallbackQueues) {
        if (auto callbackQueue = cq.lock()) {
            callbackQueue->destroy();
        }
    }

    stopped_ = true;
    wakeUpCv_.notify_all();
    for (size_t i = 0; i < threadCount_; ++i) {
        if (workers_[i].joinable()) {
            workers_[i].join();
        }
    }

    workers_.clear();
}

/**
 * Body of worker thread number `n` (started by the constructor).
 *
 * The worker handles at most one callback queue at a time (`lastCallbackQueue`). It keeps
 * running tasks from that queue while the queue's `iterations()` budget stays positive.
 * When the budget is used up, it tries to select another queue from `orderQueue_` without
 * blocking. Only when it has nothing to run does it block on `mutex_` and wait on `wakeUpCv_`.
 * The loop ends once `stopped_` is observed true.
 *
 * Exceptions thrown by a task are logged and swallowed. Any other exception ends this
 * worker thread and is logged as well.
 */
void AsioCallbackPool::workerLoop(size_t n) {
    const std::string logPrefix = makeString("AsioCallbackPool ", name_, "[", n, "]: ");
    size_t statTaskCount = 0;
    try {
        YIO_LOG_INFO(logPrefix << "Start worker thread");
        std::weak_ptr<AsioCallbackQueue> lastCallbackQueue;
        do {
            bool fContinue = false;
            bool fNoTask = false;
            if (auto callbackQueue = lastCallbackQueue.lock()) {
                auto task = callbackQueue->pop();
                if (task.callback) {
                    // Consume one unit of this queue's budget; stay on the queue while budget remains.
                    auto restIteration = --callbackQueue->iterations();
                    fContinue = (restIteration > 0);
                    try {
                        // The callback runs only if the task's tracker still yields a live lock.
                        if (auto tlock = task.tracker.lock()) {
                            ++statTaskCount;
                            task.callback();
                        }
                        task.callback = nullptr; // clear lambda capturing in "try/catch"
                    } catch (const std::exception& ex) {
                        YIO_LOG_WARN(logPrefix << callbackQueue->name() << " raise exception: " << ex.what());
                    } catch (...) {
                        YIO_LOG_WARN(logPrefix << callbackQueue->name() << " raise unknown exception");
                    }
                } else {
                    fNoTask = true;
                }
            }
            if (!fContinue) {
                // Must be called with mutex_ held. Returns true if the current queue still has
                // budget; otherwise releases the queue (worker() = NO_WORKER) so it can be re-claimed.
                auto hasUnprocessedIterations = [&] {
                    if (auto callbackQueue = lastCallbackQueue.lock()) {
                        if (callbackQueue->iterations() > 0) {
                            return true;
                        }
                        // Reset current worker from last queue
                        callbackQueue->worker() = NO_WORKER;
                    }
                    return false;
                };
                // Must be called with mutex_ held. Drops expired weak references from callbackQueues_.
                auto cleanupDeadCallbackQueues = [&] {
                    // Reset the counter before scanning, not after. A queue that dies during the scan
                    // is then either erased now or counted toward the next cleanup, instead of having
                    // its count wiped. (Assumes notifyDestroy() fires once the queue has expired.)
                    dirty_->store(0);
                    callbackQueues_.erase(
                        std::remove_if(callbackQueues_.begin(), callbackQueues_.end(), [](const auto& cq) { return cq.expired(); }),
                        callbackQueues_.end());
                };

                // Must be called with mutex_ held. Keeps the current queue if it still has budget.
                // Otherwise it walks orderQueue_ from the front:
                //  - drops expired entries;
                //  - credits one iteration to queues already owned by some worker;
                //  - claims the first unowned queue for this thread, crediting it one iteration;
                //  - stops at the next unowned queue after that, leaving it for another worker.
                auto selectNextCallbackQueue = [&] {
                    if (hasUnprocessedIterations()) {
                        return lastCallbackQueue;
                    }
                    std::weak_ptr<AsioCallbackQueue> nextCq;
                    while (!orderQueue_.empty()) {
                        auto scq = orderQueue_.front().lock();
                        if (!scq) {
                            orderQueue_.pop_front();
                            continue;
                        } else if (scq->worker() != NO_WORKER) {
                            ++scq->iterations();
                            orderQueue_.pop_front();
                            continue;
                        } else if (nextCq.expired()) {
                            nextCq = scq;
                            scq->worker() = std::this_thread::get_id();
                            ++scq->iterations();
                            orderQueue_.pop_front();
                            continue;
                        }
                        break;
                    }
                    return nextCq;
                };

                // I. All planned iterations have been completed. We are trying to enter the order
                //    queue for the next task.
                if (std::unique_lock lock{mutex_, std::try_to_lock}) {
                    lastCallbackQueue = selectNextCallbackQueue();
                    fNoTask = false;
                }

                // II. If it was not possible to "quickly" capture the mutex and maintain the current
                //     queue with tasks, then we continue their execution.
                //     NOTE(review): on this path iterations() may already be 0, so the next pop
                //     decrements it below zero. Confirm iterations() is a signed counter; if it is
                //     unsigned it wraps around.

                if (fNoTask || lastCallbackQueue.expired()) {
                    // III. If there is nothing to do, then we get up waiting for the next task
                    std::unique_lock lock(mutex_);
                    if (!hasUnprocessedIterations()) {
                        // Prune callbackQueues_ once at least minDirty queues have died, and either
                        // maxDirty died or the dead ones exceed 1/partDirty of the list.
                        constexpr size_t minDirty = 8;
                        constexpr size_t maxDirty = 128;
                        constexpr size_t partDirty = 4;
                        auto dirty = dirty_->load();
                        if (dirty >= minDirty && (dirty >= maxDirty || dirty * partDirty > callbackQueues_.size())) {
                            cleanupDeadCallbackQueues();
                        }
                        wakeUpCv_.wait(lock, [&] { return stopped_ || !orderQueue_.empty(); });
                        lastCallbackQueue = selectNextCallbackQueue();
                    }
                }
            }
        } while (!stopped_);
    } catch (const std::exception& ex) {
        YIO_LOG_WARN(logPrefix << "Shutdown worker thread due exception: " << ex.what());
    } catch (...) {
        YIO_LOG_WARN(logPrefix << "Shutdown worker thread due exception unknown exception");
    }
    YIO_LOG_INFO(logPrefix << "Stop worker thread. Total " << statTaskCount << " task executed.");
}
