#include "ThreadScheduler.hpp"
#include "debug/TraceCall.hpp"
#include <algorithm>
#include <utility>

namespace twitch {
// Creates a scheduler backed by `threads` worker threads, each running processQueue().
// An empty `name` falls back to "ThreadScheduler"; it is used as the prefix of each
// worker's thread name. Construction does not finish until every worker has registered
// itself (incremented m_threadsRunning) at the top of processQueue().
ThreadScheduler::ThreadScheduler(NativePlatform& platform, const std::string& name, int threads)
    : m_platform(platform)
    , m_name(name.empty() ? "ThreadScheduler" : name)
    , m_run(true)
    , m_threadPoolSize(threads)
    , m_threadsRunning(0)
{
    // Launch the pool; each worker announces itself by bumping m_threadsRunning.
    for (int worker = 0; worker < m_threadPoolSize; ++worker) {
        m_threads.emplace_back(&ThreadScheduler::processQueue, this);
    }

    // Hold construction until the whole pool has checked in, so every worker has been
    // named and reported to the platform before the scheduler is handed out.
    const auto poolStarted = [this]() { return m_threadsRunning == m_threadPoolSize; };
    std::unique_lock<Mutex> startupLock(m_threadMutex);
    m_threadCondition.wait(startupLock, poolStarted);
}

// Stops the worker pool. Clears m_run under m_mutex and wakes both idle workers
// (m_queueAvailable) and callers blocked in scheduleAndWait()/cancel()
// (m_waitCondition), then joins every worker thread. A worker that is itself running
// this destructor cannot join itself, so it is detached instead.
//
// NOTE(review): in that self-destruct case the detaching worker is still counted in
// m_threadsRunning (processQueue() only decrements after leaving its loop), so the
// final wait for m_threadsRunning == 0 below appears unable to complete on that
// thread. Even if it did return, the worker would resume processQueue() against a
// destroyed object (m_run, m_threadMutex). Fixing this needs a coordinated change to
// this destructor and processQueue() — TODO confirm and redesign.
ThreadScheduler::~ThreadScheduler()
{
    TraceCall trace(m_name + " destructor");

    {
        std::lock_guard<Mutex> lock(m_mutex);
        m_run = false;
        // wake workers waiting for work and callers waiting for a task to complete
        m_queueAvailable.notify_all();
        m_waitCondition.notify_all();
    }

    for (auto& thread : m_threads) {
        if (thread.joinable()) {
            // can happen if last shared thread scheduler reference is captured in a scheduled task
            if (thread.get_id() == std::this_thread::get_id()) {
                thread.detach();
            } else {
                thread.join();
            }
        }
    }

    // block until threads stop (in case threads are detached)
    // every joined worker has already decremented m_threadsRunning before exiting, so
    // this wait only matters for a detached worker (see NOTE(review) above)
    std::unique_lock<Mutex> lock(m_threadMutex);
    m_threadCondition.wait(lock, [this]() { return m_threadsRunning == 0; });
}

// Schedules `action` to run on a pool thread once `time` has elapsed. When `repeating`
// is true, the task is re-queued `time` after each run finishes, until it is cancelled.
// Returns the task itself as a Cancellable handle. The scheduler must be owned by a
// std::shared_ptr, because the task keeps a weak reference obtained via shared_from_this().
std::shared_ptr<Cancellable> ThreadScheduler::schedule(Action action, Microseconds time, bool repeating)
{
    auto task = std::make_shared<Task>();
    // `action` is a by-value sink parameter: move it in instead of copying its captures.
    task->action = std::move(action);
    task->interval = MediaTime(time);
    task->when = MediaTime::now() + time;
    task->repeating = repeating;
    task->owner = shared_from_this();

    {
        std::lock_guard<Mutex> lock(m_mutex);
        m_queue.push(task);
    }

    // Notify after releasing m_mutex so the woken worker can take the lock immediately.
    m_queueAvailable.notify_one();
    return task;
}

// Runs `action` on the pool and blocks the caller until it has completed. If called from
// one of the pool's own threads, the action runs inline instead; queueing it would make
// that worker wait on itself. Each calling thread reuses a single Task cached in
// m_waitTasks. The wait also ends, without the action necessarily having run, once the
// scheduler is shutting down (m_run cleared).
void ThreadScheduler::scheduleAndWait(Action action)
{
    auto threadId = std::this_thread::get_id();
    for (const auto& thread : m_threads) {
        if (thread.get_id() == threadId) {
            action();
            return;
        }
    }

    // Declared before the lock so it is destroyed after m_mutex is released: destroying a
    // finished action runs its captures' destructors, which must not happen under m_mutex.
    Action finished;

    std::unique_lock<Mutex> lock(m_mutex);

    auto& task = m_waitTasks[threadId];

    if (!task) {
        task = std::make_shared<Task>();
        task->owner = shared_from_this();
    }
    task->action = std::move(action);
    task->when = MediaTime::now();
    task->complete = false;

    m_queue.push(task);
    m_queueAvailable.notify_one();
    m_waitCondition.wait(lock, [this, task]() { return task->complete || !m_run; });

    // The cached task outlives this call. Release the completed action so its captures are
    // not kept alive until this thread's next scheduleAndWait(); a captured shared_ptr to
    // the scheduler would otherwise form a cycle through m_waitTasks. This is only safe
    // once a worker has marked the task complete; on shutdown it may still be queued or running.
    if (task->complete) {
        std::swap(finished, task->action);
    }
}

// Cancels `task`. A task still waiting in the queue is simply removed. A task that has
// already been dequeued may be running, so the caller blocks until it is marked complete
// (or the scheduler stops) — unless the caller is the very thread that dequeued it,
// where waiting would deadlock. Repeated cancellation returns immediately.
void ThreadScheduler::cancel(std::shared_ptr<Task> task)
{
    std::unique_lock<Mutex> guard(m_mutex);

    // Already cancelled: nothing to do, and a second caller must not block.
    if (task->cancelled) {
        return;
    }
    task->cancelled = true;

    // Still pending: dropping it from the queue is sufficient.
    if (m_queue.remove(task)) {
        return;
    }

    // Dequeued by a worker. Waiting from that same thread (e.g. from inside the task's own
    // action) would never see `complete` become true.
    const bool onOwnThread = task->thread == std::this_thread::get_id();
    if (!onOwnThread) {
        m_waitCondition.wait(guard, [this, task]() { return task->complete || !m_run; });
    }
}

// Worker thread body; every pool thread started by the constructor runs this.
//
// On entry it names the thread "<m_name>-<index>", reports it to the platform and
// increments m_threadsRunning; the last worker to start wakes the constructor.
// It then loops while m_run is set:
//   1. Under m_mutex, sleep until the queue is non-empty or the scheduler stops.
//   2. If the task at m_queue.top() is due (presumably the earliest `when`, per the
//      Queue comparator — TODO confirm), pop it. Otherwise sleep until it is due or
//      until notified, and pop it only if the timed wait actually expired.
//   3. Run the popped task outside m_mutex. Afterwards, under m_mutex, either re-queue
//      it (repeating and not cancelled) or mark it complete and wake waiters in
//      cancel()/scheduleAndWait().
// On exit it decrements m_threadsRunning and wakes the destructor when it reaches zero.
void ThreadScheduler::processQueue()
{
    {
        std::lock_guard<Mutex> lock(m_threadMutex);
        // the running count before incrementing doubles as this worker's index in its name
        m_platform.setCurrentThreadName(m_name + "-" + std::to_string(m_threadsRunning));
        m_platform.onThreadCreated(std::this_thread::get_id(), m_name);
        m_threadsRunning++;
        if (m_threadsRunning == m_threadPoolSize) {
            m_threadCondition.notify_one(); // signal the thread is started/running
        }
    }

    while (m_run) {

        // NOTE(review): if this local holds the last reference to a Task whose action
        // captured the last shared_ptr to the scheduler, releasing it at the end of the
        // iteration runs the destructor on this worker; `m_run` above would then be read
        // from a destroyed object (see the destructor's review note) — TODO confirm.
        std::shared_ptr<Task> task; // task to run
        {
            std::unique_lock<Mutex> lock(m_mutex);
            // sleep until something is queued (or the scheduler is stopping)
            m_queueAvailable.wait(lock, [this]() { return !m_queue.empty() || !m_run; });

            if (!m_queue.empty()) {
                MediaTime now = MediaTime::now();
                MediaTime delta = m_queue.top()->when - now;

                if (delta <= MediaTime::zero()) {
                    // task is scheduled to run
                    task = m_queue.top();
                    m_queue.pop();
                }
                // wait until entry should be run; a notification (new task queued, stop
                // requested) or spurious wakeup pops nothing and the loop re-evaluates
                else if (m_queueAvailable.wait_for(lock, delta.microseconds()) == CvStatus::timeout) {

                    // the top may have changed while unlocked, so re-check before popping
                    if (!m_queue.empty() && MediaTime::now() >= m_queue.top()->when) {
                        task = m_queue.top();
                        m_queue.pop();
                    }
                }
            }

            // recorded so cancel() can tell whether it is being called from this thread
            if (task) {
                task->thread = std::this_thread::get_id();
            }
        }

        // run the task
        if (task) {
            task->run();

            if (task->owner.expired()) {
                // the scheduler was destroyed while the task ran (the task held its last
                // reference), so exit the thread without touching the queue.
                // NOTE(review): `this` is dangling here, yet the code after the loop still
                // locks m_threadMutex — TODO confirm and redesign together with the destructor.
                break;
            }

            std::unique_lock<Mutex> lock(m_mutex);
            if (task->repeating && !task->cancelled) {
                // next run is `interval` after this run finished, not after it started
                task->when = MediaTime::now() + task->interval;
                m_queue.push(task);
            } else {
                task->complete = true;
                m_waitCondition.notify_all();
            }
        }
    }

    std::lock_guard<Mutex> lock(m_threadMutex);
    m_threadsRunning--;

    // Notify the destructor this is the last thread to stop
    if (m_threadsRunning == 0) {
        m_threadCondition.notify_one();
    }
}

// Cancels this task through its owning scheduler, if that scheduler is still alive
// (ThreadScheduler::cancel dequeues the task or waits for a run in progress).
// Once the scheduler is gone, only the cancelled flag is set.
void ThreadScheduler::Task::cancel()
{
    if (auto scheduler = owner.lock()) {
        scheduler->cancel(shared_from_this());
        return;
    }
    cancelled = true;
}

// Invokes the stored action, unless the task has been cancelled or holds no action.
void ThreadScheduler::Task::run()
{
    if (cancelled || !action) {
        return;
    }
    action();
}

// Removes `task` from the priority queue's underlying storage, if present, and restores
// the heap invariant. Returns true when the task was found and removed. Linear time:
// one std::find plus one std::make_heap.
bool ThreadScheduler::Queue::remove(const std::shared_ptr<Task>& task)
{
    auto& heap = this->c;
    const auto match = std::find(heap.begin(), heap.end(), task);
    if (match == heap.end()) {
        return false;
    }

    // Move the match to the back so it can be dropped cheaply, then re-heapify the rest.
    std::iter_swap(match, heap.end() - 1);
    heap.pop_back();
    std::make_heap(heap.begin(), heap.end(), this->comp);
    return true;
}
}
