#pragma once

#include <util/system/yassert.h>

#include <cstdint>

namespace NSolomon::NIntern {
namespace NPrivate {

template <typename T>
struct TDefaultComparator {
    // Three-way comparison built from operator< and operator>:
    // -1 when a < b, 1 when a > b, and 0 when neither holds.
    static constexpr int Compare(T a, T b) noexcept {
        if (a < b) {
            return -1;
        }
        if (a > b) {
            return 1;
        }
        return 0;
    }
};

/**
 * Pointer that uses its least significant bit as a one-bit tag.
 *
 * The pointee type must be at least 2-byte aligned, so that the low bit of
 * every valid T* is zero. The combined value is stored as an integer rather
 * than as a T*, so a tagged address is never materialized as a pointer value.
 * Only untagged addresses, which originated from real pointers, are converted
 * back to T*; that pointer -> integer -> pointer round trip is the one the
 * language guarantees.
 *
 * A newly constructed TTaggedPtr is tagged. Assigning a new pointer keeps the
 * current tag.
 */
template <typename T>
class TTaggedPtr {
public:
    explicit TTaggedPtr(T* ptr) noexcept
        : Bits_(reinterpret_cast<uintptr_t>(ptr) | TagBit)
    {
        // Checked here rather than at class scope: T may still be incomplete
        // when TTaggedPtr<T> is instantiated as a member of T itself.
        static_assert(alignof(T) >= 2, "TTaggedPtr needs the low bit of T* to be free for the tag");
    }

    // Returns the stored pointer with the tag bit stripped.
    operator T*() const noexcept {
        return reinterpret_cast<T*>(Bits_ & ~TagBit);
    }

    // Replaces the stored pointer while preserving the current tag.
    TTaggedPtr<T>& operator=(T* ptr) noexcept {
        Bits_ = reinterpret_cast<uintptr_t>(ptr) | (Bits_ & TagBit);
        return *this;
    }

    bool HasTag() const noexcept {
        return (Bits_ & TagBit) != 0;
    }

    void Tag() noexcept {
        Bits_ |= TagBit;
    }

    void UnTag() noexcept {
        Bits_ &= ~TagBit;
    }

    void SetTag(bool tag) noexcept {
        Bits_ = (Bits_ & ~TagBit) | static_cast<uintptr_t>(tag);
    }

private:
    static constexpr uintptr_t TagBit = 1;

    uintptr_t Bits_;
};

} // namespace NPrivate

/**
 * Intrusive node of TRbTree, used as a CRTP base: TDerived inherits from
 * TRbTreeNode<TKey, TDerived>, and AsDerived() casts back down to it.
 *
 * Per-node overhead is the key plus two pointers. The node colour is packed
 * into the low bit of Right, so there is no separate colour field; with an
 * 8-byte key on a 64-bit target that comes to 24 bytes.
 */
template <typename TKey, typename TDerived>
struct TRbTreeNode {
    using TSelf = TRbTreeNode<TKey, TDerived>;

    const TKey Key;
    TSelf* Left;
    NPrivate::TTaggedPtr<TSelf> Right; // if HasTag() == true, then this node is red,
                                       // otherwise it is black

    // A fresh node has no children and is red, because the TTaggedPtr
    // constructor sets the tag. TRbTree::Insert relies on that state.
    TRbTreeNode(TKey key)
        : Key(key)
        , Left(nullptr)
        , Right(nullptr)
    {
    }

    TDerived* AsDerived() noexcept {
        return static_cast<TDerived*>(this);
    }

    const TDerived* AsDerived() const noexcept {
        return static_cast<const TDerived*>(this);
    }
};

/**
 * Intrusive left-leaning red–black tree, because we do not want to waste memory
 * for additional pointer to a parent node.
 *
 * Based on https://github.com/jemalloc/jemalloc/blob/dev/include/jemalloc/internal/rb.h
 * Read more https://en.wikipedia.org/wiki/Left-leaning_red%E2%80%93black_tree
 */
template <typename TKey, typename TNodeDerived, typename TComparator = NPrivate::TDefaultComparator<TKey>>
class TRbTree {
    using TNode = TRbTreeNode<TKey, TNodeDerived>;

public:
    TRbTree()
        : Root_(nullptr)
    {
    }

    TNodeDerived* Find(TKey key) const noexcept {
        int cmp;
        TNode* node = Root_;
        while (node && (cmp = TComparator::Compare(key, node->Key)) != 0) {
            if (cmp < 0) {
                node = node->Left;
            } else {
                node = node->Right;
            }
        }
        return node ? node->AsDerived() : nullptr;
    }

    void Insert(TNode* node) {
        struct {
            TNode* Node;
            int Cmp;
        } path[sizeof(void*) << 4], *pathp; // up to 2^128 nodes

        // (1) wind
        path->Node = Root_;
        for (pathp = path; pathp->Node; pathp++) {
            int cmp = pathp->Cmp = TComparator::Compare(node->Key, pathp->Node->Key);
            Y_ASSERT(cmp != 0);
            if (cmp < 0) {
                pathp[1].Node = pathp->Node->Left;
            } else {
                pathp[1].Node = pathp->Node->Right;
            }
        }
        pathp->Node = node;

        // (2) unwind
        for (pathp--; reinterpret_cast<uintptr_t>(pathp) >= reinterpret_cast<uintptr_t>(path); pathp--) {
            TNode* cnode = pathp->Node;
            if (pathp->Cmp < 0) {
                TNode* left = pathp[1].Node;
                cnode->Left = left;

                if (left->Right.HasTag()) {
                    TNode* leftLeft = left->Left;
                    if (leftLeft && leftLeft->Right.HasTag()) {
                        // fix up 4-node
                        leftLeft->Right.UnTag();

                        // rotate right
                        TNode* tnode = cnode->Left;
                        cnode->Left = tnode->Right;
                        tnode->Right = cnode;
                        cnode = tnode;
                    }
                } else {
                    return;
                }
            } else {
                TNode* right = pathp[1].Node;
                cnode->Right = right;

                if (right->Right.HasTag()) {
                    TNode* left = cnode->Left;
                    if (left && left->Right.HasTag()) {
                        // split 4-node
                        left->Right.UnTag();
                        right->Right.UnTag();
                        cnode->Right.Tag();
                    } else {
                        // lean left
                        bool isRed = cnode->Right.HasTag();

                        // rotate left
                        TNode* tnode = cnode->Right;
                        cnode->Right = tnode->Left;
                        tnode->Left = cnode;

                        tnode->Right.SetTag(isRed);
                        cnode->Right.Tag();

                        cnode = tnode;
                    }
                } else {
                    return;
                }
            }
            pathp->Node = cnode;
        }

        // set root, and make it black
        Root_ = path->Node;
        Root_->Right.UnTag();
    }

    // Remove() intentionally not implemented

private:
    TNode* Root_;
};

} // namespace NSolomon::NIntern
