//
// Red-Black (balanced) tree
//
#ifndef RBTREE_H
#define RBTREE_H

#include <cassert>

// A node in Red-Black tree
// A node of a Red-Black tree.
// All members are public: RBTree manipulates the links and the color
// directly. A node never owns its neighbours; the tree owns all nodes.
template <class TreeNodeValue>
class RBTreeNode {
public:
    RBTreeNode* left;   // pointer to the left son
    RBTreeNode* right;  // pointer to the right son
    RBTreeNode* parent; // pointer to the parent
    bool        red;    // the node is red (true) or black (false)
    TreeNodeValue value;  // The value of tree node: normally, it is
                          //     a pair (key, value of key)

    // Create a detached black node with a default-constructed value.
    RBTreeNode():
        left(nullptr),
        right(nullptr),
        parent(nullptr),
        red(false),
        value()
    {
    }

    // Create a detached black node holding a copy of v.
    // Lets the tree build a node in one step, so that a throwing copy
    // of the value cannot leak a half-initialized node.
    explicit RBTreeNode(const TreeNodeValue& v):
        left(nullptr),
        right(nullptr),
        parent(nullptr),
        red(false),
        value(v)
    {
    }
};

// Every node of a Red-Black tree is either red or black.
// A leaf is an external NULL-node and has the black color.
// The root is black.
// The sons of a red node must be black.
// Every path from the root to a leaf has the same number
// of black nodes (not including a root, but including a leaf).
//
// TreeNodeValue must be default-constructible (for the header node),
// copy-constructible, and comparable with operator<.
template <class TreeNodeValue>
class RBTree {
public:
    // Header contains a pointer to the root of the tree
    // as a left son.
    // The tree may be empty, in this case header.left == nullptr.
    // header.right and header.parent are always nullptr.
    // The header also serves as the "end" position for iteration.
    RBTreeNode<TreeNodeValue> header;
    int numNodes;

    RBTree():
        header(),
        numNodes(0)
    {
        header.red = true;      // The header has the red color!
    }

    // Deep copy: the new tree gets its own nodes with the same shape,
    // colors and values. (A member-wise copy would share nodes between
    // two trees and delete them twice.)
    RBTree(const RBTree& other):
        header(),
        numNodes(0)
    {
        header.red = true;
        // Build the copy inside a scratch tree: if copying a value
        // throws, the scratch destructor releases the partial copy.
        RBTree scratch;
        scratch.cloneSubtree(other.root(), &scratch.header, true);
        takeNodesFrom(scratch);
    }

    // Move: steal the nodes of other and leave it empty.
    RBTree(RBTree&& other):
        header(),
        numNodes(0)
    {
        header.red = true;
        takeNodesFrom(other);
    }

    // Copy assignment. The copy is completed before the old contents
    // are released, so *this is unchanged if copying throws.
    RBTree& operator=(const RBTree& other) {
        if (this != &other) {
            RBTree scratch(other);
            clear();
            takeNodesFrom(scratch);
        }
        return *this;
    }

    // Move assignment: release own nodes, then steal those of other.
    RBTree& operator=(RBTree&& other) noexcept {
        if (this != &other) {
            clear();
            takeNodesFrom(other);
        }
        return *this;
    }

    // Remove all nodes; the tree becomes empty.
    void clear() {
        removeSubtree(root());
    }
    void erase() { clear(); }
    void removeAll() { clear(); }

    ~RBTree() { clear(); }

    // The root of the tree, or nullptr when the tree is empty.
    RBTreeNode<TreeNodeValue>* root() { return header.left; }
    const RBTreeNode<TreeNodeValue>* root() const { return header.left; }

    // The number of nodes in the tree.
    int size() const { return numNodes; }

    // Find a key in a subtree
    // In: key         -- a key to find;
    //     subTreeRoot -- a root of a subtree. If subTreeRoot == nullptr,
    //                    then find in complete tree
    // Out: node       -- a pointer to the node that contains a key,
    //                    if key is in the set,
    //                    or a pointer to a node that should be parent to
    //                    the node with the key given
    //                    (the header, if the tree is empty).
    // Return value: 0,   if the key is found;
    //               < 0, if the key is not found, and the key is less
    //                       than the value in the node;
    //               > 0, if the key is not found, and the key is greater
    //                       than the value in the node.
    int find(
        const TreeNodeValue& key,
        const RBTreeNode<TreeNodeValue>* subTreeRoot = nullptr,
        RBTreeNode<TreeNodeValue>** node = nullptr
    ) const {
        const RBTreeNode<TreeNodeValue>* x;        // current node
        const RBTreeNode<TreeNodeValue>* y;        // its parent
        if (subTreeRoot == nullptr) {
            x = root();
            y = &header;
        } else {
            x = subTreeRoot;
            y = x->parent;
        }
        int n = (-1); // If tree is empty, add the new node as a left son of header
        while (x != nullptr) {
            const TreeNodeValue& currentKey = x->value;
            if (key < currentKey) {
                n = (-1);
            } else if (currentKey < key) {
                n = 1;
            } else {
                // key is found
                if (node != nullptr)
                    *node = const_cast<RBTreeNode<TreeNodeValue>*>(x);
                return 0;
            }
            y = x;
            x = (n < 0) ? x->left : x->right;
        }

        // key is not in the tree: report the would-be parent
        if (node != nullptr)
            *node = const_cast<RBTreeNode<TreeNodeValue>*>(y);
        return n;
    }

    // Insert a key into the tree:
    //     create a new node and insert it as a leaf.
    // The color of a new node is red (black, if it becomes the root).
    // Should be called after the "find" method, which has returned
    //        a nonzero value (i.e. a key was not found)
    // Input: parentNode = the node reported by "find";
    //        v          = the value to store;
    //        compare    = result of comparing with a key of a parent node
    void insert(
        RBTreeNode<TreeNodeValue>* parentNode,
        TreeNodeValue v,
        int compare // negative => add as a left son, positive => right
    ) {
        assert(parentNode != nullptr && compare != 0);
        // The target slot must be free; the header has a left son only.
        assert(
            compare < 0 ?
                parentNode->left == nullptr :
                (parentNode != &header && parentNode->right == nullptr)
        );

        // Construct the node directly from the value: if copying the
        // value throws, operator new releases the memory itself.
        RBTreeNode<TreeNodeValue>* x = new RBTreeNode<TreeNodeValue>(v);
        x->parent = parentNode;
        x->red = (parentNode != &header); // The root is black, others start red
        if (compare < 0) {
            parentNode->left = x;       // Insert as a left son
        } else {
            parentNode->right = x;      // Insert as a right son
        }
        ++numNodes;

        if (x != root())
            rebalanceAfterInsert(x);
    }

    // Rotate a node x to the left    //
    //        x                y      //
    //       / \              / \     //
    //      a   y    --->    x   c    //
    //         / \          / \       //
    //        b   c        a   b      //
    // x must have a right son. Colors are not changed.
    void rotateLeft(RBTreeNode<TreeNodeValue>* x) {
        RBTreeNode<TreeNodeValue>* y = x->right;
        assert(y != nullptr);
        RBTreeNode<TreeNodeValue>* p = x->parent;
        y->parent = p;
        if (x == p->left) {
            // x is the left son of its parent
            p->left = y;
        } else {
            // x is the right son of its parent
            p->right = y;
        }
        x->right = y->left;
        if (y->left != nullptr)
            y->left->parent = x;
        y->left = x;
        x->parent = y;
    }

    // Rotate a node x to the right   //
    //        x                y      //
    //       / \              / \     //
    //      y   c    --->    a   x    //
    //     / \                  / \   //
    //    a   b                b   c  //
    // x must have a left son. Colors are not changed.
    void rotateRight(RBTreeNode<TreeNodeValue>* x) {
        RBTreeNode<TreeNodeValue>* y = x->left;
        assert(y != nullptr);
        RBTreeNode<TreeNodeValue>* p = x->parent;
        y->parent = p;
        if (x == p->left) {
            // x is the left son of its parent
            p->left = y;
        } else {
            // x is the right son of its parent
            p->right = y;
        }
        x->left = y->right;
        if (y->right != nullptr)
            y->right->parent = x;
        y->right = x;
        x->parent = y;
    }

    // Restore the red-black properties after the red node x
    // has been attached as a leaf.
    // While the parent of x is red, either recolor (red uncle) and
    // continue from the grandparent, or rotate (black uncle) and stop.
    // A red parent is never the root, so the grandparent is a real node.
    void rebalanceAfterInsert(RBTreeNode<TreeNodeValue>* x) {
        assert(x->red);
        while (x != root() && x->parent->red) {
            if (x->parent == x->parent->parent->left) {
                // parent of x is a left son

                RBTreeNode<TreeNodeValue>* y = x->parent->parent->right; // y is the sibling of
                                                 //            parent of x
                if (y != nullptr && y->red) {    // if y is red
                    x->parent->red = false;      //     color parent of x in black
                    y->red = false;              //     color y in black
                    x = x->parent->parent;       //     x = grandparent of x
                    x->red = true;               //     color x in red
                } else {                         // else y is black
                    if (x == x->parent->right) { //     if x is a right son
                        x = x->parent;           //         x = parent of x
                        rotateLeft(x);           //         left-rotate x
                    }                            //     end if
                    assert(x == x->parent->left);//     assert: x is a left son
                    x->parent->red = false;      //     color parent of x in black
                    x->parent->parent->red = true;  // color grandparent in red
                    rotateRight(x->parent->parent); // right-rotate grandparent
                }                                // endif
            } else {
                // Mirror case: parent of x is a right son
                assert(x->parent == x->parent->parent->right);

                RBTreeNode<TreeNodeValue>* y = x->parent->parent->left; // y is the sibling of
                                                 //           parent of x
                if (y != nullptr && y->red) {    // if y is red
                    x->parent->red = false;      //     color parent of x in black
                    y->red = false;              //     color y in black
                    x = x->parent->parent;       //     x = grandparent of x
                    x->red = true;               //     color x in red
                } else {                         // else y is black
                    if (x == x->parent->left) {  //     if x is a left son
                        x = x->parent;           //         x = parent of x
                        rotateRight(x);          //         right-rotate x
                    }                            //     end if
                    assert(x == x->parent->right); //   assert: x is a right son
                    x->parent->red = false;      //     color parent of x in black
                    x->parent->parent->red = true; //  color grandparent in red
                    rotateLeft(x->parent->parent); //  left-rotate grandparent
                }                                // endif
            }
        } // end while

        // Always color the root in black
        if (x == root()) {
            x->red = false;
        }
    }

    // Remove a subtree and return the number of nodes removed.
    // The subtree is detached from its parent and all its nodes are
    // deleted. No rebalancing is performed.
    // subTreeRoot may be nullptr (nothing is removed), but must not be
    // the header.
    int removeSubtree(RBTreeNode<TreeNodeValue>* subTreeRoot) {
        if (subTreeRoot == nullptr)
            return 0;
        assert(subTreeRoot != &header);

        int numRemoved = 0;
        if (subTreeRoot->left != nullptr)
            numRemoved += removeSubtree(subTreeRoot->left); // recursive call
        if (subTreeRoot->right != nullptr)
            numRemoved += removeSubtree(subTreeRoot->right); // recursive call

        // Detach the node from its parent before deleting it
        if (subTreeRoot->parent->left == subTreeRoot)
            subTreeRoot->parent->left = nullptr;
        else
            subTreeRoot->parent->right = nullptr;

        delete subTreeRoot;
        ++numRemoved;
        --numNodes;

        assert(numNodes >= 0);

        return numRemoved;
    }

    // The leftmost node of a subtree (of the complete tree, if
    // subTreeRoot == nullptr). Returns nullptr if the tree is empty.
    const RBTreeNode<TreeNodeValue>* minimalNode(
        const RBTreeNode<TreeNodeValue>* subTreeRoot = nullptr
    ) const {
        const RBTreeNode<TreeNodeValue>* x = subTreeRoot;
        if (x == nullptr)
            x = root();
        while (x != nullptr && x->left != nullptr)
            x = x->left;
        return x;
    }

    RBTreeNode<TreeNodeValue>* minimalNode(
        const RBTreeNode<TreeNodeValue>* subTreeRoot = nullptr
    ) {
        return const_cast<RBTreeNode<TreeNodeValue>*>(
            static_cast<const RBTree*>(this)->minimalNode(subTreeRoot)
        );
    }

    // The rightmost node of a subtree (of the complete tree, if
    // subTreeRoot == nullptr). Returns nullptr if the tree is empty.
    const RBTreeNode<TreeNodeValue>* maximalNode(
        const RBTreeNode<TreeNodeValue>* subTreeRoot = nullptr
    ) const {
        const RBTreeNode<TreeNodeValue>* x = subTreeRoot;
        if (x == nullptr)
            x = root();
        while (x != nullptr && x->right != nullptr)
            x = x->right;
        return x;
    }

    RBTreeNode<TreeNodeValue>* maximalNode(
        const RBTreeNode<TreeNodeValue>* subTreeRoot = nullptr
    ) {
        return const_cast<RBTreeNode<TreeNodeValue>*>(
            static_cast<const RBTree*>(this)->maximalNode(subTreeRoot)
        );
    }

    // In-order successor of a node.
    // The successor of the maximal node is the header; the successor
    // of the header is the minimal node (the header itself, if the
    // tree is empty), so the sequence is circular.
    const RBTreeNode<TreeNodeValue>* nextNode(const RBTreeNode<TreeNodeValue>* node) const {
        assert(node != nullptr);
        if (node == &header) {
            const RBTreeNode<TreeNodeValue>* first = minimalNode();
            return (first != nullptr) ? first : &header;
        }

        if (node->right != nullptr) {
            return minimalNode(node->right);
        } else if (node == node->parent->left) { // node is a left son
            return node->parent;
        } else {                                 // node is a right son
            // Climb until we leave a left subtree. The root is a left
            // son of the header, so the climb stops at the root at the
            // latest and never reads header.parent.
            const RBTreeNode<TreeNodeValue>* x = node->parent;
            while (x == x->parent->right)        // while x is a right son
                x = x->parent;
            return x->parent;
        }
    }

    RBTreeNode<TreeNodeValue>* nextNode(RBTreeNode<TreeNodeValue>* node) {
        return const_cast<RBTreeNode<TreeNodeValue>*>(
            static_cast<const RBTree*>(this)->nextNode(node)
        );
    }

    // In-order predecessor of a node.
    // The predecessor of the minimal node is the header; the
    // predecessor of the header is the maximal node (the header itself,
    // if the tree is empty).
    const RBTreeNode<TreeNodeValue>* previousNode(const RBTreeNode<TreeNodeValue>* node) const {
        assert(node != nullptr);
        if (node == &header) {
            // Handled separately: the header has no parent to inspect.
            const RBTreeNode<TreeNodeValue>* last = maximalNode();
            return (last != nullptr) ? last : &header;
        }

        if (node->left != nullptr) {
            return maximalNode(node->left);
        } else if (node == node->parent->right) { // node is a right son
            return node->parent;
        } else {                                  // node is a left son
            // Climb until we leave a right subtree. For the minimal
            // node the climb runs along the leftmost path and ends at
            // the header, which is exactly the required result.
            const RBTreeNode<TreeNodeValue>* x = node->parent;
            while (x->parent != nullptr && x == x->parent->left) // while x is a left son
                x = x->parent;
            if (x->parent != nullptr) {
                return x->parent;
            } else {
                assert(x == &header);
                return x;
            }
        }
    }

    RBTreeNode<TreeNodeValue>* previousNode(RBTreeNode<TreeNodeValue>* node) {
        // Dispatch to the const overload explicitly; calling
        // previousNode on a non-const "this" would recurse forever.
        return const_cast<RBTreeNode<TreeNodeValue>*>(
            static_cast<const RBTree*>(this)->previousNode(node)
        );
    }

private:
    // Copy the subtree rooted at src and attach the copy to dstParent
    // (as a left son if asLeftSon, else as a right son).
    // Every node is linked and counted as soon as it is created, so
    // the tree stays consistent (and can be cleared) if copying a
    // value throws half-way.
    void cloneSubtree(
        const RBTreeNode<TreeNodeValue>* src,
        RBTreeNode<TreeNodeValue>* dstParent,
        bool asLeftSon
    ) {
        if (src == nullptr)
            return;
        RBTreeNode<TreeNodeValue>* copy =
            new RBTreeNode<TreeNodeValue>(src->value);
        copy->red = src->red;
        copy->parent = dstParent;
        if (asLeftSon)
            dstParent->left = copy;
        else
            dstParent->right = copy;
        ++numNodes;
        cloneSubtree(src->left, copy, true);
        cloneSubtree(src->right, copy, false);
    }

    // Move all nodes of donor into this tree, which must be empty.
    // The root is re-parented to our own header; donor becomes empty.
    void takeNodesFrom(RBTree& donor) {
        assert(header.left == nullptr && numNodes == 0);
        header.left = donor.header.left;
        if (header.left != nullptr)
            header.left->parent = &header;
        numNodes = donor.numNodes;
        donor.header.left = nullptr;
        donor.numNodes = 0;
    }

public:
    // Bidirectional read-only iterator over the nodes in key order.
    // The end position is the header node of the tree.
    class const_iterator {
    protected:
        const RBTree* tree;
        const RBTreeNode<TreeNodeValue>* node;
    public:
        const_iterator(): tree(nullptr), node(nullptr) {}

        const_iterator(
            const RBTree* t,
            const RBTreeNode<TreeNodeValue>* n
        ):
            tree(t),
            node(n)
        {}

        bool operator==(const const_iterator& i) const {
            return (tree == i.tree && node == i.node);
        }

        bool operator!=(const const_iterator& i) const {
            return !operator==(i);
        }

        const_iterator& operator++() {
            node = tree->nextNode(node);
            return *this;
        }

        const_iterator& operator--() {
            node = tree->previousNode(node);
            return *this;
        }

        const_iterator operator++(int) { // Post-increment (don't use it!)
            const_iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        const_iterator operator--(int) { // Post-decrement (don't use it!)
            const_iterator tmp = *this;
            --(*this);
            return tmp;
        }

        const RBTreeNode<TreeNodeValue>& operator*() const { // Dereference
            return *node;
        }
        const RBTreeNode<TreeNodeValue>* operator->() const {
            return node;
        }
    };

    // Bidirectional iterator giving mutable access to the nodes.
    // Changing the key part of a value breaks the ordering of the tree.
    class iterator: public const_iterator {
    public:
        iterator():
            const_iterator()
        {}

        iterator(RBTree* t, RBTreeNode<TreeNodeValue>* n):
            const_iterator(t, n)
        {}

        // Increment/decrement are redeclared so that they yield an
        // iterator; the inherited versions yield a const_iterator.
        iterator& operator++() {
            ++static_cast<const_iterator&>(*this);
            return *this;
        }

        iterator& operator--() {
            --static_cast<const_iterator&>(*this);
            return *this;
        }

        iterator operator++(int) { // Post-increment (don't use it!)
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        iterator operator--(int) { // Post-decrement (don't use it!)
            iterator tmp = *this;
            --(*this);
            return tmp;
        }

        // The iterator was created from a non-const tree, so casting
        // the constness away is legitimate.
        RBTreeNode<TreeNodeValue>& operator*() const { // Dereference
            return const_cast<RBTreeNode<TreeNodeValue>&>(
                static_cast<const const_iterator*>(this)->operator*()
            );
        }
        RBTreeNode<TreeNodeValue>* operator->() const {
            return const_cast<RBTreeNode<TreeNodeValue>*>(
                static_cast<const const_iterator*>(this)->operator->()
            );
        }
    };

    // Position of the minimal node; equals end() if the tree is empty.
    const_iterator begin() const {
        const RBTreeNode<TreeNodeValue>* first = minimalNode();
        return const_iterator(this, (first != nullptr) ? first : &header);
    }
    iterator begin() {
        RBTreeNode<TreeNodeValue>* first = minimalNode();
        return iterator(this, (first != nullptr) ? first : &header);
    }

    // Position past the maximal node (the header).
    const_iterator end() const {
        return const_iterator(this, &header);
    }
    iterator end() {
        return iterator(this, &header);
    }
};

#endif /* RBTREE_H */
