#pragma once

#include "playercore/Log.hpp"
#include "debug/trace.hpp"
#include <atomic>
#include <cassert>
#include <chrono>
#include <limits>
#include <mutex> // for std::unique_lock and std::lock_guard definitions

namespace twitch {
namespace ps4 {
/// Outcome of the predicate-less ConditionVariable::wait_for, modelled on std::cv_status.
enum class CvStatus {
    timeout,    ///< the timed wait reported a timeout (also returned when the kernel call fails)
    no_timeout  ///< the wait returned before the timeout (notified or woken spuriously)
};

class Mutex;

/**
 * Condition variable built on a ScePthreadCond, shaped after the
 * std::condition_variable interface but bound to twitch::ps4::Mutex.
 * The constructor, destructor, wait(lock), notify_one() and notify_all() are
 * defined out of line; the templated waits are defined later in this header.
 */
class ConditionVariable {
public:
    ConditionVariable();
    ~ConditionVariable();

    // The ScePthreadCond is held by value inside this object, so copying is disabled.
    ConditionVariable(const ConditionVariable&) = delete;
    ConditionVariable& operator=(const ConditionVariable&) = delete;

    /// Waits on the condition with `lock` held (defined out of line; presumably may wake spuriously).
    void wait(std::unique_lock<Mutex>& lock);

    /// Waits, re-checking after each wake-up, until pred() returns true.
    template <class Predicate>
    void wait(std::unique_lock<Mutex>& lock, Predicate pred);

    /// Waits for at most rel_time; returns CvStatus::timeout when the wait timed out.
    template <class Rep, class Period>
    CvStatus wait_for(std::unique_lock<Mutex>& lock, const std::chrono::duration<Rep, Period>& rel_time);

    /// Waits until pred() is true or rel_time runs out; returns whether pred() held (false on a kernel error).
    template <class Rep, class Period, class Predicate>
    bool wait_for(std::unique_lock<Mutex>& lock, const std::chrono::duration<Rep, Period>& rel_time, Predicate pred);

    /// Counterpart of std::condition_variable::notify_one (defined out of line).
    void notify_one();
    /// Counterpart of std::condition_variable::notify_all (defined out of line).
    void notify_all();

private:
    ScePthreadCond m_condition;
    // Class-wide counter used by the out-of-line implementation (presumably for per-instance ids/names).
    static std::atomic_int64_t m_nextId;
};

/**
 * Mutex wrapping a ScePthreadMutex. It exposes lock()/unlock(), so it can be
 * driven by std::lock_guard and std::unique_lock.
 * NOTE(review): the mutex attributes are set in the out-of-line constructor;
 * presumably non-recursive, given the separate RecursiveMutex type.
 */
class Mutex {
public:
    Mutex();
    ~Mutex();

    // Holds the kernel mutex object by value; copying is disabled.
    Mutex(const Mutex&) = delete;
    Mutex& operator=(const Mutex&) = delete;

    void lock();
    void unlock();

    /// Pointer to the wrapped ScePthreadMutex, as passed to the scePthreadCond* waits.
    ScePthreadMutex* getMutex();

private:
    // ConditionVariable is allowed to reach the native handle directly.
    friend class ConditionVariable;

    ScePthreadMutex m_mutex;
    static std::atomic_int64_t m_nextId;
};

/**
 * Recursive counterpart of Mutex, wrapping a ScePthreadMutex (the recursive
 * attribute is presumably set in the out-of-line constructor — confirm there).
 * Note: ConditionVariable's waits only accept std::unique_lock<Mutex>, so a
 * RecursiveMutex cannot currently be used with them.
 */
class RecursiveMutex {
public:
    RecursiveMutex();
    ~RecursiveMutex();

    // Holds the kernel mutex object by value; copying is disabled.
    RecursiveMutex(const RecursiveMutex&) = delete;
    RecursiveMutex& operator=(const RecursiveMutex&) = delete;

    void lock();
    void unlock();

    /// Pointer to the wrapped ScePthreadMutex.
    ScePthreadMutex* getMutex();

private:
    friend class ConditionVariable;

    ScePthreadMutex m_mutex;
    static std::atomic_int64_t m_nextId;
};

template <class Predicate>
void ConditionVariable::wait(std::unique_lock<Mutex>& lock, Predicate pred)
{
    while (!pred()) {
        scePthreadCondWait(&m_condition, lock.mutex()->getMutex());
    }
}

/**
 * Blocks until notified, woken spuriously, or rel_time elapses.
 *
 * scePthreadCondTimedwait takes a relative timeout in whole microseconds
 * (SceKernelUseconds), so rel_time is converted as follows:
 *  - rounded up, so a positive sub-microsecond duration does not collapse to 0;
 *  - raised to 1us when rel_time is zero or negative (as before);
 *  - clamped to the largest SceKernelUseconds value instead of silently wrapping.
 *    With a clamped timeout, CvStatus::timeout may be reported before the full
 *    rel_time has passed.
 *
 * @param lock      unique_lock owning the Mutex associated with this condition.
 * @param rel_time  Maximum time to wait.
 * @return CvStatus::timeout if the wait timed out (or the kernel call failed),
 *         CvStatus::no_timeout otherwise.
 */
template <class Rep, class Period>
CvStatus ConditionVariable::wait_for(std::unique_lock<Mutex>& lock, const std::chrono::duration<Rep, Period>& rel_time)
{
    static_assert(std::numeric_limits<SceKernelUseconds>::max() <= std::numeric_limits<long long>::max(),
        "SceKernelUseconds is expected to fit in long long");
    const long long maxWait = std::numeric_limits<SceKernelUseconds>::max();

    auto micros = std::chrono::duration_cast<std::chrono::microseconds>(rel_time);
    if (micros < rel_time) {
        ++micros; // duration_cast truncates; round up so the request is not shortened
    }

    SceKernelUseconds waitTime = 1;
    if (micros.count() > maxWait) {
        waitTime = static_cast<SceKernelUseconds>(maxWait);
    } else if (micros.count() > 0) {
        waitTime = static_cast<SceKernelUseconds>(micros.count());
    }

    int ret = scePthreadCondTimedwait(&m_condition, lock.mutex()->getMutex(), waitTime);
    if (ret == SCE_KERNEL_ERROR_ETIMEDOUT) {
        return CvStatus::timeout;
    } else if (ret < 0) {
        TRACE_ERROR("scePthreadCondTimedwait %d", ret);
        assert(false);
        return CvStatus::timeout;
    } else {
        return CvStatus::no_timeout;
    }
}

template <class Rep, class Period, class Predicate>
bool ConditionVariable::wait_for(std::unique_lock<Mutex>& lock, const std::chrono::duration<Rep, Period>& rel_time, Predicate pred)
{
    static auto freq = sceKernelGetProcessTimeCounterFrequency();
    auto remaining = std::chrono::duration_cast<std::chrono::microseconds>(rel_time).count();

    while (!pred()) {
        if (remaining <= 0LL) {
            return false;
        }

        auto before = sceKernelGetProcessTimeCounter();
        int ret = scePthreadCondTimedwait(&m_condition, lock.mutex()->getMutex(), remaining);
        auto delta((sceKernelGetProcessTimeCounter() - before) * 1000000LL / freq);

        if (ret == SCE_KERNEL_ERROR_ETIMEDOUT) {
            return pred();
        } else if (ret < 0) {
            TRACE_ERROR("scePthreadCondTimedwait returned error code 0x%08x", ret);
            assert(false);
            return false;
        } else {
            assert(delta > 0);
            remaining -= delta;
        }
    }

    return true;
}
}

using ConditionVariable = twitch::ps4::ConditionVariable;
using CvStatus = twitch::ps4::CvStatus;
using Mutex = twitch::ps4::Mutex;
using RecursiveMutex = twitch::ps4::RecursiveMutex;
}
