#include "callback_queue.h"

#include <util/system/yassert.h>

#include <memory>
#include <sstream>
#include <vector>

using namespace quasar;
using namespace std::chrono;

// Builds the underlying bounded queue and starts the worker thread that runs mainLoop().
//
// The thread is launched in the constructor body, not the initializer list, so every
// member (queue_ included) is fully constructed before mainLoop() can touch it.
CallbackQueue::CallbackQueue(size_t maxSize, Callback onOverflow, Callback onBecomeNormal)
    : queue_(maxSize, std::move(onOverflow), std::move(onBecomeNormal))
{
    queueThread_ = std::thread([this] { mainLoop(); });
}

// Human-readable identifier of the queue: the textual id of its worker thread.
std::string CallbackQueue::name() const {
    std::ostringstream out; // std::thread::id is rendered via operator<<
    out << queueThread_.get_id();
    return out.str();
}

// Enqueues `callback` for execution on the worker thread.
//
// Throws std::runtime_error for an empty callback (an empty function is reserved as the
// shutdown sentinel of mainLoop()). Once the queue is stopped the callback is silently dropped.
void CallbackQueue::add(std::function<void()> callback) {
    if (callback == nullptr) {
        throw std::runtime_error("Nullptr pushed");
    }
    if (!stopped_) {
        queue_.push(std::move(callback));
    }
}

// Same as add(callback), but the callback is first wrapped by makeSafeCallback() with the
// given lifetime tracker (presumably so it is skipped once the tracked owner is gone).
void CallbackQueue::add(std::function<void()> callback, Lifetime::Tracker tracker)
{
    auto guarded = makeSafeCallback(std::move(callback), std::move(tracker));
    add(std::move(guarded));
}

// Attempts to enqueue `callback` without forcing it in.
//
// Throws std::runtime_error for an empty callback. Returns false when the queue is stopped;
// otherwise returns whatever queue_.tryPush() reports.
bool CallbackQueue::tryAdd(std::function<void()> callback) {
    if (callback == nullptr) {
        throw std::runtime_error("Nullptr pushed");
    }
    return !stopped_ && queue_.tryPush(std::move(callback));
}

// Tracker-guarded variant of tryAdd(callback): wraps the callback with makeSafeCallback()
// and reports whether it was accepted.
bool CallbackQueue::tryAdd(std::function<void()> callback, Lifetime::Tracker tracker)
{
    auto guarded = makeSafeCallback(std::move(callback), std::move(tracker));
    return tryAdd(std::move(guarded));
}

// Schedules `callback` to be handed to add() once `timeOut` has elapsed.
//
// Empty callbacks and calls made after the queue has been stopped are silently ignored.
// The delayed-callback thread is started lazily on the first call.
void CallbackQueue::addDelayed(std::function<void()> callback, std::chrono::milliseconds timeOut) {
    if (stopped_ || !callback) {
        return;
    }

    std::unique_lock<std::mutex> lock(mutex_);
    // Re-check under the lock: destroy() sets stopped_ while holding mutex_. From then on the
    // delayed-callback loop exits (or is never started), so an entry stored now would never
    // run and would only keep its captures alive until the queue itself is destroyed.
    if (stopped_) {
        return;
    }
    if (!timeOutedCallbacksThread_.joinable()) {
        timeOutedCallbacksThread_ = std::thread(&CallbackQueue::timeOutedCallbacksLoop, this);
    }
    timeOutedCallbacks_.push(std::make_pair(steady_clock::now() + timeOut, std::move(callback)));
    wakeUpVar_.notify_one(); // the loop may need to shorten its current wait_until deadline
}

// Tracker-guarded variant of addDelayed(): the callback is wrapped by makeSafeCallback()
// before being scheduled.
void CallbackQueue::addDelayed(std::function<void()> callback, std::chrono::milliseconds timeOut, Lifetime::Tracker tracker) {
    auto guarded = makeSafeCallback(std::move(callback), std::move(tracker));
    addDelayed(std::move(guarded), timeOut);
}

// Blocks the calling thread until callbacks queued before the call have been processed.
//
// A marker callback is enqueued, and the caller sleeps until the worker runs it or the
// queue is stopped. For modes other than EGOIST the round repeats while queue_ still
// reports pending entries. ALTRUIST then performs one extra EGOIST round (see below).
//
// Throws std::logic_error when called from the worker thread itself, which would deadlock.
void CallbackQueue::wait(AwatingType awatingType) {
    if (queueThread_.get_id() == std::this_thread::get_id()) {
        throw std::logic_error("DEAD LOCK in wait() call: can't awaiting inside CallbackQueue");
    }

    // The marker may outlive this stack frame. If the queue is stopped after the marker has
    // been enqueued, the waiter returns on `stopped_` while the worker can still execute the
    // marker afterwards. Shared ownership keeps the flag and the condition variable alive for
    // as long as either side uses them, instead of the marker touching a dead frame.
    struct WaitState {
        std::condition_variable cv;
        bool ready{false};
    };

    do {
        auto state = std::make_shared<WaitState>();
        add([this, state] {
            std::lock_guard lock(mutex_);
            state->ready = true;
            state->cv.notify_all();
        });
        std::unique_lock lock(mutex_);
        state->cv.wait(lock, [&] { return stopped_ || state->ready; });
    } while (!stopped_ && awatingType != AwatingType::EGOIST && queue_.size());

    if (awatingType == AwatingType::ALTRUIST) {
        // queue_.size() cannot be checked atomically with task execution: the last task may
        // still be running when the loop above exits. One more EGOIST round waits for it.
        wait(AwatingType::EGOIST);
    }
}

// True when the caller is the thread that runs mainLoop(), i.e. code executing inside a callback.
bool CallbackQueue::isWorkingThread() const noexcept {
    const auto current = std::this_thread::get_id();
    return current == queueThread_.get_id();
}

// Stops the queue and joins its worker threads.
//
// Sets stopped_, after which add() and addDelayed() drop new work and tryAdd() returns false.
// It then wakes the delayed-callback thread so it observes the flag, and enqueues an empty
// callback that makes mainLoop() exit when it reaches it. Calling destroy() again is harmless:
// each thread is joined only while it is still joinable.
void CallbackQueue::destroy() {
    {
        // Set under mutex_ so timeOutedCallbacksLoop() and addDelayed() see it consistently.
        std::lock_guard<std::mutex> guard(mutex_);
        stopped_ = true;
        wakeUpVar_.notify_one();
    }
    if (queueThread_.joinable()) {
        // mainLoop() treats an empty callback as its stop signal.
        queue_.push(nullptr);
        queueThread_.join();
    }
    if (timeOutedCallbacksThread_.joinable()) {
        timeOutedCallbacksThread_.join();
    }
}

// Hook invoked by mainLoop() when a callback throws. Here the exception is deliberately
// ignored so the worker keeps processing the remaining callbacks.
void CallbackQueue::onException([[maybe_unused]] std::exception_ptr exptr) noexcept {
}

// Body of the worker thread: pops callbacks and runs them until an empty callback arrives
// (only destroy() pushes one). Exceptions thrown by callbacks are routed to onException().
void CallbackQueue::mainLoop() {
    for (;;) {
        // A fresh object per iteration: the previous task, with all its lambda captures,
        // is destroyed before the next pop() starts blocking.
        std::function<void()> task;
        queue_.pop(task);
        if (!task) {
            break;
        }
        try {
            task();
        } catch (...) {
            onException(std::current_exception());
        }
    }
}

// Body of the delayed-callback thread (started lazily by addDelayed()).
//
// Sleeps on wakeUpVar_ until the earliest entry of timeOutedCallbacks_ is due, then forwards
// every due callback to add(). Runs until destroy() sets stopped_.
//
// Due callbacks are collected under mutex_ but forwarded with the lock released. add() calls
// queue_.push(), which may invoke the user-supplied overflow callback. The worker thread also
// needs mutex_, e.g. for the marker callback enqueued by wait(). Holding mutex_ across that
// call serializes the worker against this thread and risks a self-deadlock if such a callback
// re-enters addDelayed() or wait().
void CallbackQueue::timeOutedCallbacksLoop() {
    std::unique_lock<std::mutex> lock(mutex_);
    std::vector<std::function<void()>> due; // reused across iterations to keep its capacity
    while (!stopped_) {
        const auto now = steady_clock::now();
        while (!timeOutedCallbacks_.empty() && timeOutedCallbacks_.top().first <= now) {
            due.push_back(timeOutedCallbacks_.top().second);
            timeOutedCallbacks_.pop();
        }

        if (!due.empty()) {
            lock.unlock();
            for (auto& callback : due) {
                add(std::move(callback));
            }
            due.clear();
            lock.lock();
            // Start over: stopped_ or the heap may have changed while the lock was released.
            continue;
        }

        if (timeOutedCallbacks_.empty()) {
            wakeUpVar_.wait(lock, [this]() {
                return !timeOutedCallbacks_.empty() || stopped_;
            });
        } else {
            const auto until = timeOutedCallbacks_.top().first;
            // Also wake early when an entry due before `until` is pushed.
            wakeUpVar_.wait_until(lock, until, [this, until]() {
                return stopped_ || timeOutedCallbacks_.empty() || timeOutedCallbacks_.top().first < until;
            });
        }
    }
}

// Stops the queue and joins its threads. Destroying the queue from inside one of its own
// callbacks would make destroy() try to join the current thread, so that is rejected loudly.
CallbackQueue::~CallbackQueue() {
    Y_VERIFY(!isWorkingThread(), "Destory callback queue inside itself: %s", name().c_str());
    destroy();
}

// Number of entries currently held by queue_. Delayed callbacks that have not yet been
// forwarded from timeOutedCallbacks_ are not counted.
size_t CallbackQueue::size() const {
    const size_t pending = queue_.size();
    return pending;
}

// Changes the capacity limit of queue_. The parameter is named so it does not shadow size().
void CallbackQueue::setMaxSize(size_t maxSize) {
    queue_.setMaxSize(maxSize);
}

void CallbackQueue::clear() {
    queue_.clear();
}
