#include "timer_service.h"

#include <atomic>
#include <list>
#include <vector>

namespace quasar {
    namespace {
        // Process-wide source of task ids, shared by every TimerService instance.
        // createTask() pre-increments it, so ids handed out start at 1 (until the counter wraps);
        // id 0 is what the inert handle returned on failure carries.
        std::atomic<uint32_t> globalTaskId{0};
    } // namespace

    /// Shared state behind a TimerService: the scheduled tasks plus the worker thread that
    /// dispatches them.
    ///
    /// `taskData` and `dirty` are guarded by `mutex`. `taskData` is kept ordered by
    /// `nextTimePoint` lazily: writers that may break the order set `dirty`, and the worker
    /// re-sorts before it relies on the order. Due callbacks are not invoked by the worker
    /// itself; they are passed to the task's ICallbackQueue through `add()`.
    struct TimerService::Core {
        const ExceptionHandler exceptionHandler;
        mutable std::mutex mutex;
        std::atomic<bool> shutdown{false};
        std::thread timerThread;
        SteadyConditionVariable wakeUpVar;

        struct TaskData {
            uint32_t taskId{0};
            bool singleShot{true};                       // true: fire once and drop; false: re-arm every `interval`
            std::chrono::milliseconds interval;          // initial delay, and the period of a periodic task
            std::function<void()> callback;
            Lifetime::Tracker tracker;                   // a due task whose tracker cannot be locked is dropped unfired
            std::weak_ptr<ICallbackQueue> callbackQueue; // a due task whose queue is gone is dropped unfired

            std::chrono::steady_clock::time_point nextTimePoint; // next deadline
        };
        std::list<TaskData> taskData;
        bool dirty{false}; // taskData may be out of nextTimePoint order; cleared by sortIfDirty()

        Core(TimerService::ExceptionHandler eh)
            : exceptionHandler(std::move(eh))
        {
            // Run thread after full initialization
            timerThread = std::thread(&TimerService::Core::workerLoop, this);
        }
        ~Core()
        {
            destroy();
            timerThread.join();
        }

        /// Asks the worker thread to exit; calling it more than once is harmless.
        void destroy() {
            {
                // The flag must be published while holding the mutex: the worker evaluates its
                // wait predicates under this mutex, so a store made without it can land between
                // a predicate check and the actual wait, and the notification below would be
                // lost (with no tasks scheduled the worker would then sleep forever and the
                // destructor's join() would hang).
                std::lock_guard<std::mutex> lock(mutex);
                shutdown = true;
            }
            wakeUpVar.notify_all();
        }

        /// Restores nextTimePoint order if some writer flagged the list as dirty.
        /// Must be called with `mutex` held.
        void sortIfDirty() {
            if (dirty) {
                taskData.sort([](const TaskData& lhs, const TaskData& rhs) { return lhs.nextTimePoint < rhs.nextTimePoint; });
                dirty = false;
            }
        }

        /// Passes the exception currently being handled to `exceptionHandler`, if one is set.
        /// Must be called from inside a catch handler; an exception thrown by the handler is swallowed.
        void reportCurrentException() noexcept {
            if (exceptionHandler) {
                try {
                    exceptionHandler(std::current_exception());
                } catch (...) {
                    // The handler itself failed; there is nowhere left to report it.
                }
            }
        }

        /// Body of `timerThread`: hands every due task to its callback queue, then sleeps until
        /// the earliest deadline or a notification. Returns once `shutdown` is set; an
        /// unexpected exception also sets `shutdown` and is reported to `exceptionHandler`.
        void workerLoop()
        {
            try {
                struct FireItem {
                    std::shared_ptr<ICallbackQueue> callbackQueue;
                    std::function<void()> callback;
                    Lifetime::Tracker tracker;
                };
                std::vector<FireItem> fireItems;
                fireItems.reserve(128);
                std::unique_lock<std::mutex> lock(mutex);
                while (!shutdown.load()) {
                    auto now = std::chrono::steady_clock::now();
                    sortIfDirty();

                    // Collect due tasks; tasks whose callback queue or tracker is gone are dropped.
                    for (auto it = taskData.begin(); it != taskData.end();) {
                        if (it->nextTimePoint > now) {
                            // since all timers are sorted by time, we iterate only until the first one which does not need to be started
                            break;
                        }
                        bool removeIt{false};
                        if (auto cq = it->callbackQueue.lock()) {
                            dirty = true;
                            if (!it->tracker.lock()) {
                                removeIt = true;
                            } else if (it->singleShot) {
                                fireItems.push_back(FireItem{cq, std::move(it->callback), it->tracker});
                                removeIt = true;
                            } else {
                                fireItems.push_back(FireItem{cq, it->callback, it->tracker});
                                it->nextTimePoint += it->interval;
                            }
                        } else {
                            removeIt = true;
                        }
                        if (removeIt) {
                            it = taskData.erase(it);
                        } else {
                            ++it;
                        }
                    }

                    if (!fireItems.empty()) {
                        // Call into the callback queues with the lock released so that
                        // createTask()/restart()/stop() are not blocked by them.
                        lock.unlock();
                        for (auto& fireItem : fireItems) {
                            try {
                                fireItem.callbackQueue->add(std::move(fireItem.callback), std::move(fireItem.tracker));
                                fireItem.callbackQueue = nullptr;
                            } catch (...) {
                                reportCurrentException();
                            }
                        }
                        fireItems.clear();
                        lock.lock();
                    }

                    sortIfDirty();

                    if (taskData.empty()) {
                        wakeUpVar.wait(lock, [this]() {
                            return shutdown.load() || !taskData.empty();
                        });
                    } else {
                        auto time0 = std::chrono::steady_clock::now();
                        auto nextTimePoint = taskData.front().nextTimePoint;
                        // `dirty` is part of the predicate because restart() can move a task that is
                        // not at the front to an earlier deadline: front() stays the same until the
                        // list is re-sorted, so without it that wake-up would be ignored and the
                        // restarted task would fire late.
                        wakeUpVar.wait_until(lock, nextTimePoint, [&] {
                            return shutdown.load() || dirty || taskData.empty() || taskData.front().nextTimePoint != nextTimePoint;
                        });
                        if (std::chrono::steady_clock::now() - time0 < std::chrono::milliseconds{1}) {
                            // Overheat! Force wait 1 msec, but let a shutdown request cut it short.
                            wakeUpVar.wait_for(lock, std::chrono::milliseconds{1}, [this] { return shutdown.load(); });
                        }
                    }
                }
            } catch (...) {
                shutdown = true;
                reportCurrentException();
            }
        }
    };

    /// Handle to one task scheduled on a TimerService.
    ///
    /// It holds only a weak reference to the service core, so it may outlive the TimerService;
    /// once the core is gone every query reports the task as no longer alive. A handle built
    /// with an empty core is inert from the start.
    class TimerServiceTask: public IPeriodicTask, public IDelayedTask {
    public:
        TimerServiceTask(std::weak_ptr<TimerService::Core> core, uint32_t taskId)
            : core_(std::move(core))
            , taskId_(taskId)
        {
        }

    public: // ITimer methods
        /// True when the task is no longer scheduled (a one-shot that already fired, a stopped
        /// task, or a dead service), or when its tracker or callback queue has expired.
        bool expired() const noexcept override {
            const auto core = core_.lock();
            if (!core) {
                return true;
            }
            std::lock_guard<std::mutex> guard(core->mutex);
            const auto& tasks = core->taskData;
            const auto pos = findIn(tasks);
            if (pos == tasks.end()) {
                return true;
            }
            return pos->tracker.expired() || pos->callbackQueue.expired();
        }

        /// Re-arms the task to fire `interval` from now; a periodic task then keeps that period.
        /// Returns false when the task is unknown or its tracker / callback queue has expired;
        /// in the latter case the task is removed.
        bool restart(std::chrono::milliseconds interval) noexcept override {
            const auto core = core_.lock();
            if (!core) {
                return false;
            }
            bool mustWakeWorker = false;
            bool alive = false;
            {
                std::lock_guard<std::mutex> guard(core->mutex);
                auto& tasks = core->taskData;
                const auto pos = findIn(tasks);
                if (pos != tasks.end()) {
                    pos->interval = interval;
                    pos->nextTimePoint = std::chrono::steady_clock::now() + interval;
                    alive = !pos->tracker.expired() && !pos->callbackQueue.expired();
                    // Wake the worker if this task now precedes the front, or it is the front the worker waits for.
                    mustWakeWorker = pos->nextTimePoint < tasks.front().nextTimePoint || pos->taskId == tasks.front().taskId;
                    core->dirty = true;
                    if (!alive) {
                        tasks.erase(pos);
                    }
                }
            }
            if (mustWakeWorker) {
                core->wakeUpVar.notify_all();
            }
            return alive;
        }

        /// Removes the task from the schedule. Returns true only if it was still scheduled and
        /// both its tracker and callback queue were alive.
        bool stop() noexcept override {
            const auto core = core_.lock();
            if (!core) {
                return false;
            }
            bool mustWakeWorker = false;
            bool alive = false;
            {
                std::lock_guard<std::mutex> guard(core->mutex);
                auto& tasks = core->taskData;
                const auto pos = findIn(tasks);
                if (pos != tasks.end()) {
                    // Removing the front changes the deadline the worker is waiting for.
                    mustWakeWorker = tasks.front().taskId == taskId_;
                    alive = !pos->tracker.expired() && !pos->callbackQueue.expired();
                    tasks.erase(pos);
                }
            }
            if (mustWakeWorker) {
                core->wakeUpVar.notify_all();
            }
            return alive;
        }

    private:
        /// Linear search for this handle's entry; returns `tasks.end()` when absent.
        /// The caller must hold the core mutex.
        template <class Tasks>
        auto findIn(Tasks& tasks) const {
            auto pos = tasks.begin();
            while (pos != tasks.end() && pos->taskId != taskId_) {
                ++pos;
            }
            return pos;
        }

        const std::weak_ptr<TimerService::Core> core_;
        const uint32_t taskId_;
    };

    TimerService::TimerService(ExceptionHandler eh)
        : core_(std::make_shared<TimerService::Core>(std::move(eh)))
    {
    }

    void TimerService::shutdown() {
        core_->destroy();
    }

    /// Schedules `callback` to be posted to `callbackQueue` every `interval`, first `interval` from now.
    std::shared_ptr<IPeriodicTask> TimerService::createPeriodicTask(std::chrono::milliseconds interval, std::function<void()> callback, const Lifetime::Tracker& tracker, std::weak_ptr<ICallbackQueue> callbackQueue) noexcept {
        constexpr bool kSingleShot = false;
        std::shared_ptr<IPeriodicTask> task = createTask(kSingleShot, interval, std::move(callback), tracker, std::move(callbackQueue));
        return task;
    }

    /// Schedules `callback` to be posted to `callbackQueue` once, `interval` from now.
    std::shared_ptr<IDelayedTask> TimerService::createDelayedTask(std::chrono::milliseconds interval, std::function<void()> callback, const Lifetime::Tracker& tracker, std::weak_ptr<ICallbackQueue> callbackQueue) noexcept {
        constexpr bool kSingleShot = true;
        std::shared_ptr<IDelayedTask> task = createTask(kSingleShot, interval, std::move(callback), tracker, std::move(callbackQueue));
        return task;
    }

    std::shared_ptr<TimerServiceTask> TimerService::createTask(bool singleShot, std::chrono::milliseconds interval, std::function<void()> callback, Lifetime::Tracker tracker, std::weak_ptr<ICallbackQueue> callbackQueue) noexcept {
        try {
            if (!callback) {
                throw std::invalid_argument("argument \"callback\" can't be null");
            }

            auto now = std::chrono::steady_clock::now();
            uint32_t taskId = ++globalTaskId;
            Core::TaskData td{taskId, singleShot, interval, std::move(callback), std::move(tracker), std::move(callbackQueue), now + interval};

            bool wakeUpNow{false};
            {
                std::lock_guard<std::mutex> lock(core_->mutex);
                if (core_->taskData.empty() || core_->taskData.front().nextTimePoint >= td.nextTimePoint) {
                    core_->taskData.push_front(std::move(td));
                    wakeUpNow = true;
                } else {
                    core_->taskData.push_back(std::move(td));
                    core_->dirty = true;
                }
            }
            if (wakeUpNow) {
                core_->wakeUpVar.notify_all();
            }
            return std::make_shared<TimerServiceTask>(core_, taskId);
        } catch (...) {
            if (core_->exceptionHandler) {
                try {
                    core_->exceptionHandler(std::current_exception());
                } catch (...) {
                }
            }
        }
        return std::make_shared<TimerServiceTask>(std::weak_ptr<TimerService::Core>{}, 0);
    }

} // namespace quasar
