#include "core/eventqueue.hpp"

#include <assert.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <limits>
#include <mutex>
#include <vector>

namespace {
// Returns the current reading of std::chrono::steady_clock, truncated to whole milliseconds.
//
// Despite the "SystemTime" name this is the monotonic steady clock, not wall-clock time. Its values are
// only meaningful relative to one another (a task's due time vs. "now").
uint64_t GetSystemTimeMilliseconds() {
  using Clock = std::chrono::steady_clock;
  const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now().time_since_epoch());
  return static_cast<uint64_t>(elapsed.count());
}
}  // namespace

EventQueue::EventQueue() : mCurrentTaskId(1) {}

// Destroys the queue. Debug builds assert that no task entries remain in mQueue.
// Callables parked in mCancelledTasks by RemoveTask() are not checked here.
EventQueue::~EventQueue() {
  assert(mQueue.empty());
}

/// Queues a task to be run by a later WaitForEvent()/WaitForEventWithTimeout() call once its delay elapses.
///
/// @param taskParams Task description. Its taskFunc and taskName are moved out. taskFunc must not be null.
/// @return The id assigned to the task (usable with RemoveTask()), or 0 if taskParams.taskFunc is null.
TaskId EventQueue::InsertTask(TaskParams&& taskParams) {
  assert(taskParams.taskFunc != nullptr);
  if (taskParams.taskFunc == nullptr) {
    // In release builds, reject the task instead of queueing an empty callable.
    // Ids are allocated starting from 1, so 0 is the failure value.
    return 0;
  }

  Task task;
  task.taskFunc = std::move(taskParams.taskFunc);
  task.taskName = std::move(taskParams.taskName);
  // Absolute due time, on the same clock WaitForEvent() compares against.
  task.invocationTimestamp = GetSystemTimeMilliseconds() + taskParams.delayMilliseconds.count();

  TaskId taskId = 0;
  {
    std::lock_guard<std::mutex> lock(mMutex);
    taskId = mCurrentTaskId;
    mCurrentTaskId++;
    task.taskId = taskId;

    // upper_bound keeps mQueue sorted and places the new task after every existing task that compares equal.
    // Assuming Task's operator< orders by invocationTimestamp (TODO confirm), same-time tasks run in FIFO order.
    const auto insertPos = std::upper_bound(mQueue.begin(), mQueue.end(), task);
    mQueue.insert(insertPos, std::move(task));
  }

  // Wake a waiter so it can re-evaluate the (possibly new) front of the queue.
  mCondition.notify_one();

  return taskId;
}

/// Cancels a pending task.
///
/// The task's entry stays in mQueue with a null taskFunc, and WaitForEvent() drops it once it is due.
/// The callable itself is parked in mCancelledTasks. Its destructor therefore runs on the thread that calls
/// WaitForEvent(), not on the caller of RemoveTask().
///
/// @param taskId Id returned by InsertTask().
/// @return true if a still-pending task with that id was cancelled. false if no entry has that id or it was
///         already cancelled.
bool EventQueue::RemoveTask(TaskId taskId) {
  bool cancelled = false;

  {
    std::unique_lock<std::mutex> guard(mMutex);

    const auto hasMatchingId = [taskId](const Task& candidate) { return candidate.taskId == taskId; };
    auto match = std::find_if(mQueue.begin(), mQueue.end(), hasMatchingId);

    // A null taskFunc means this entry has already been cancelled, so there is nothing left to remove.
    if (match != mQueue.end() && match->taskFunc != nullptr) {
      mCancelledTasks.emplace_back(std::move(match->taskFunc));
      match->taskFunc = nullptr;
      cancelled = true;
    }
  }

  // Wake a waiter so it re-evaluates the queue now that an entry has been cancelled.
  if (cancelled) {
    mCondition.notify_one();
  }

  return cancelled;
}

void EventQueue::Clear() {
  {
    std::unique_lock<std::mutex> lock(mMutex);
    mQueue.clear();
  }

  // Notify that the queue was cleared
  mCondition.notify_one();
}

void EventQueue::WaitForEvent() {
  bool ranTask = false;

  while (!ranTask) {
    // Clear out cancelled tasks and their associated lambdas - with any possible destructors.
    // We do this in WaitForEvent() and not in RemoveTask() to ensure destructors are called in the right thread.
    std::vector<TaskFunc> tasksToClear;
    {
      std::unique_lock<std::mutex> lock(mMutex);
      tasksToClear.swap(mCancelledTasks);
    }

    tasksToClear.clear();

    Task task;

    {
      std::unique_lock<std::mutex> lock(mMutex);

      uint64_t currentTime = GetSystemTimeMilliseconds();

      auto it = mQueue.begin();
      if (it != mQueue.end()) {
        if (it->invocationTimestamp <= currentTime) {
          // Remove the task from the EventQueue, and save the taskFunc to run outside the lock
          if (it->taskFunc != nullptr) {
            task = std::move(*it);
          }
          mQueue.erase(it);
        } else {
          // Wait until the first event is ready to fire
          uint64_t waitMilliseconds = it->invocationTimestamp - currentTime;
          mCondition.wait_for(lock, std::chrono::milliseconds(waitMilliseconds));
        }
      } else {
        mCondition.wait(lock);
      }
    }

    if (task.taskFunc != nullptr) {
      task.taskFunc();
      task.taskFunc = nullptr;
      ranTask = true;
    }
  }
}

/// Runs at most one due task, waiting up to `timeout` milliseconds for one to become due.
///
/// @param timeout Maximum time to wait, in milliseconds on the GetSystemTimeMilliseconds() clock.
/// @return true if a task was run, false if the deadline passed without a runnable task.
///         If now + timeout would reach UINT64_MAX, this falls back to an untimed WaitForEvent()
///         and always returns true.
///
/// The first pass always inspects the queue, even when timeout is 0, so an already-due task is still run.
/// Due-but-cancelled entries are also drained past the deadline: forceTry forces another pass after one is
/// erased.
///
/// Callables cancelled while this call is waiting are destroyed only at the top of a later pass. If the
/// call then times out, they remain in mCancelledTasks until the next WaitForEvent*() call.
///
/// NOTE(review): after a wait, the loop exits once the clock is past endTime. A task whose timestamp is
/// exactly endTime can therefore be left for the next call if wait_for wakes late. Confirm this is intended.
bool EventQueue::WaitForEventWithTimeout(uint64_t timeout) {
  uint64_t currentTime = GetSystemTimeMilliseconds();
  if (currentTime >= std::numeric_limits<uint64_t>::max() - timeout) {
    // CurrentTime + timeout will overflow or hit the max, so we will wait without a timeout
    WaitForEvent();
    return true;
  }

  // Absolute deadline on the same clock as the task timestamps.
  uint64_t endTime = currentTime + timeout;

  bool ranTask = false;
  bool forceTry = true;

  // Loop while there's still time remaining before the timeout
  // Also always try if it's the first time OR the last try we hit a cancelled task
  while (currentTime <= endTime || forceTry) {
    forceTry = false;

    // Clear out cancelled tasks and their associated lambdas - with any possible destructors.
    // We do this here (as in WaitForEvent()) and not in RemoveTask() to ensure destructors are called in the right thread.
    std::vector<TaskFunc> tasksToClear;
    {
      std::unique_lock<std::mutex> lock(mMutex);
      tasksToClear.swap(mCancelledTasks);
    }

    // Destroy the cancelled callables with mMutex released.
    tasksToClear.clear();

    Task task;

    {
      std::unique_lock<std::mutex> lock(mMutex);

      currentTime = GetSystemTimeMilliseconds();

      auto it = mQueue.begin();
      if (it != mQueue.end()) {
        if (it->invocationTimestamp <= currentTime) {
          // Remove the task from the EventQueue, and save the taskFunc to run outside the lock
          if (it->taskFunc != nullptr) {
            task = std::move(*it);
          } else {
            // Force another loop iteration since we're just cleaning up a cancelled task in this iteration
            forceTry = true;
          }
          mQueue.erase(it);
        } else {
          // Wait difference from now until endTime OR now until first timestamp (whichever is shorter)
          if (endTime > currentTime) {
            uint64_t waitMilliseconds = endTime - currentTime;
            waitMilliseconds = std::min(waitMilliseconds, it->invocationTimestamp - currentTime);
            mCondition.wait_for(lock, std::chrono::milliseconds(waitMilliseconds));
          } else {
            // The timeout has been reached
            // (this break leaves the while loop; `lock` is released as the scope unwinds)
            break;
          }
        }
      } else {
        // Queue is empty, so wait the difference from now until endTime
        if (endTime > currentTime) {
          uint64_t waitMilliseconds = endTime - currentTime;
          mCondition.wait_for(lock, std::chrono::milliseconds(waitMilliseconds));
        } else {
          // The timeout has been reached
          break;
        }
      }
    }

    // Run the task outside the lock; at most one task is run per call.
    if (task.taskFunc != nullptr) {
      task.taskFunc();
      task.taskFunc = nullptr;
      ranTask = true;
      break;
    }

    // Refresh the clock for the loop condition after a wait or after erasing a cancelled entry.
    currentTime = GetSystemTimeMilliseconds();
  }

  return ranTask;
}
