from libcpp.set cimport set as cppset
from libcpp.vector cimport vector as cppvector
from cpython.ref cimport PyObject, Py_XINCREF, Py_XDECREF
# noinspection PyUnresolvedReferences
from cython.operator cimport predecrement as dec, preincrement as inc, dereference as deref

from .types cimport HostQueueItem, TaskRef


cdef cppclass HostQueueKey:
    # Key of one HostQueue entry. Ordering and equality look only at
    # (score, task_id); task_ref is payload and never takes part in comparisons.
    float score  # TODO: change to int after SANDBOX-5917
    int task_id
    PyObject* task_ref  # raw pointer to the entry's TaskRef object

    # noinspection PyUnresolvedReferences
    bint lessthan "operator<"(const HostQueueKey& other) const:
        # Lexicographic order: score first, task_id breaks ties.
        if this.score != other.score:
            return this.score < other.score
        return this.task_id < other.task_id

    # noinspection PyUnresolvedReferences
    bint equalto "operator=="(const HostQueueKey& other) const:
        # Equal exactly when both key fields match (task_ref is ignored).
        if this.score != other.score:
            return False
        return this.task_id == other.task_id


# Ordered, duplicate-free set of HostQueueKey; ordering and uniqueness come from
# HostQueueKey's operator< / operator== (score first, then task_id).
# noinspection PyUnresolvedReferences
ctypedef cppset[HostQueueKey] Queue


# noinspection PyUnresolvedReferences
cdef class HostQueue:
    """Ordered set of task entries for one host, sorted by (score, task_id).

    Each stored entry holds one strong reference to its ``task_ref``: it is taken
    in ``push`` and released when the entry is erased. Entries whose
    ``TaskRef.__task_id`` is 0 are treated as dead and are dropped lazily by
    ``__iter__``, ``cleanup`` and ``merge``.

    NOTE(review): there is no ``__dealloc__``, so references still held when a
    HostQueue is destroyed are never released. Adding one is only safe once every
    consumer of ``merge`` holds its own references to the merged entries.
    """
    # XXX: methods `__iter__`, `cleanup`, `merge` use almost equal code because of
    # XXX: inability to define custom C++ destructor in Cython,
    # XXX: that could be used for class `GarbageCollector` instead of loop at the end of these methods

    cdef Queue queue

    cpdef push(self, HostQueueItem item):
        """Insert ``item``, replacing any entry with the same (score, task_id).

        The queue takes a new reference to ``item.task_ref``. The reference held
        by a replaced entry is released.
        """
        cdef HostQueueKey key
        cdef Queue.iterator it
        cdef Queue.iterator hint
        cdef PyObject* replaced = NULL

        key.score = item.score
        key.task_id = item.task_id
        key.task_ref = <PyObject*>item.task_ref
        Py_XINCREF(key.task_ref)

        it = self.queue.lower_bound(key)
        hint = it
        if it != self.queue.end() and key.equalto(deref(it)):
            # Advance the hint *before* erasing: an erased iterator must never be
            # used again, while iterators to other elements stay valid.
            inc(hint)
            replaced = deref(it).task_ref
            self.queue.erase(it)
        # `hint` is the first element ordered after `key`, which is the position
        # std::set's hinted insert expects.
        self.queue.insert(hint, key)
        # Drop the old reference last, once the set is consistent again
        # (releasing it may run arbitrary Python code).
        Py_XDECREF(replaced)

    def __iter__(self):
        """Yield a HostQueueItem for every live entry, in (score, task_id) order.

        Dead entries (``__task_id == 0``) are skipped. They are erased, and their
        references released, only after the generator runs to completion. If
        iteration stops early, they stay until a later sweep.

        NOTE(review): the generator keeps a raw C++ iterator across ``yield``, so
        mutating this queue (``push``/``cleanup``/``merge``) during iteration can
        invalidate it. Presumably callers never do that; confirm.
        """
        cdef cppvector[Queue.iterator] garbage

        it = self.queue.begin()
        while it != self.queue.end():
            item = deref(it)
            if (<TaskRef>item.task_ref).__task_id == 0:
                garbage.push_back(it)
                inc(it)
                continue
            yield HostQueueItem(item.score, item.task_id, <TaskRef>item.task_ref)
            inc(it)
        for it in garbage:
            Py_XDECREF(deref(it).task_ref)
            self.queue.erase(it)

    cpdef cleanup(self):
        """Erase every dead entry (``__task_id == 0``) and release its reference."""
        cdef cppvector[Queue.iterator] garbage

        it = self.queue.begin()
        while it != self.queue.end():
            item = deref(it)
            if (<TaskRef>item.task_ref).__task_id == 0:
                garbage.push_back(it)
                inc(it)
                continue
            inc(it)
        # Erasing a set element only invalidates iterators to that element, so the
        # collected iterators stay valid while we erase them one by one.
        for it in garbage:
            Py_XDECREF(deref(it).task_ref)
            self.queue.erase(it)

    cdef merge(self, Queue& merged):
        """Copy every live entry into ``merged`` and erase dead entries from this queue.

        Entries are copied without taking a new reference, so ``merged`` borrows
        ``task_ref`` from this queue. If ``merged`` already holds an equal
        (score, task_id) key, std::set keeps the existing entry.
        """
        cdef cppvector[Queue.iterator] garbage

        it = self.queue.begin()
        while it != self.queue.end():
            item = deref(it)
            if (<TaskRef>item.task_ref).__task_id == 0:
                garbage.push_back(it)
                inc(it)
                continue
            merged.insert(item)
            inc(it)
        for it in garbage:
            Py_XDECREF(deref(it).task_ref)
            self.queue.erase(it)

    def __len__(self):
        """Number of stored entries, including dead ones not yet swept."""
        return self.queue.size()

    def __eq__(self, other):
        """Return True when both queues hold the same (score, task_id) keys.

        Comparing with anything that is not a HostQueue, including None, returns
        NotImplemented. The previous typed ``HostQueue other`` signature accepted
        None and then dereferenced it.
        """
        if not isinstance(other, HostQueue):
            return NotImplemented
        return self.queue == (<HostQueue>other).queue


cdef class MergeQueue:
    """Union of several HostQueue instances, iterated in (score, task_id) order.

    The merged set holds its own strong reference to every ``task_ref`` it stores.
    The snapshot therefore stays valid even if a source queue later drops those
    entries (``cleanup``/``push``) and releases its references.
    """
    cdef Queue merged

    def __init__(self, list queues):
        cdef Queue collected
        cdef HostQueueKey key
        cdef HostQueue host_queue

        # Largest queues first, so on a (score, task_id) collision the entry from
        # the larger queue is kept (std::set keeps the first inserted key).
        # sorted() is stable with reverse=True, giving the same order as the old
        # in-place sort without reordering the caller's list.
        for host_queue in sorted(queues, key=len, reverse=True):
            host_queue.merge(collected)

        # HostQueue.merge inserts borrowed pointers. Take our own reference for
        # every entry that actually lands in self.merged. insert().second reports
        # whether the key was new, so a repeated __init__ never double-counts.
        for key in collected:
            if self.merged.insert(key).second:
                Py_XINCREF(key.task_ref)

    def __dealloc__(self):
        # Release the references taken in __init__.
        cdef HostQueueKey key
        for key in self.merged:
            Py_XDECREF(key.task_ref)
        self.merged.clear()

    def __iter__(self):
        """Yield a HostQueueItem for every entry whose task is still live (``__task_id != 0``)."""
        for item in self.merged:
            if (<TaskRef>item.task_ref).__task_id != 0:
                yield HostQueueItem(item.score, item.task_id, <TaskRef>item.task_ref)
