/* Copyright (c) 2009, Markus Peloquin
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#ifndef REDBLACK_HPP
#define REDBLACK_HPP

/* The algorithm is from _Introduction to Algorithms, 2nd ed._ by Cormen et
 * al. (eleven core functions, with the main deviation in RB-Remove()).
 * The STLifying was done by me. The interface mirrors the (C++03)
 * std::multimap interface, plus a few minor extras (can cast to bool,
 * front()/back(), pop_front()/pop_back(), swap_front(), cmp()).
 */

/* REDBLACK_EXTRA provides the validate() and to_string() functions and some
 * assertions.
 */
//#define REDBLACK_EXTRA

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <queue>
#include <utility>

#ifdef REDBLACK_EXTRA
# include <cassert>
# include <iostream>
# include <list>
# include <set>
# include <sstream>
# include <string>
#endif

/* A multimap implemented as a red-black tree.
 *
 * Nodes are never relocated: insertion invalidates no iterators, and erase()
 * invalidates only iterators to the erased element. Iterators remember the
 * tree object that produced them, so after swap() iterators obtained before
 * the swap must not be advanced.
 */
template <class Key, class T, class Cmp = std::less<Key>,
    class A = std::allocator<std::pair<const Key, T> > >
class rb_tree {
public:
	// types
	typedef Key key_type;
	typedef T mapped_type;
	typedef std::pair<const Key, T> value_type;
	typedef Cmp key_compare;
	typedef A allocator_type;

	// Orders value_type objects by key using a key_compare instance.
	class value_compare {
		friend class rb_tree;
	public:
		typedef bool result_type;
		typedef value_type first_argument_type;
		typedef value_type second_argument_type;

		value_compare() : comp() {}
		explicit value_compare(const Cmp &c) : comp(c) {}

		bool operator()(const value_type &x, const value_type &y) const
		{
			return comp(x.first, y.first);
		}
	protected:
		Cmp comp;
	};

	typedef typename A::size_type size_type;
	typedef typename A::difference_type difference_type;
	typedef typename A::reference reference;
	typedef typename A::const_reference const_reference;
	typedef typename A::pointer pointer;
	typedef typename A::const_pointer const_pointer;

private:
	enum color_type { RED, BLACK };

	// Nodes come from new/delete; only the values they point to are
	// obtained from the allocator.
	struct node {
		node() : val(0), l(0), r(0), p(0), color(BLACK) {}

		pointer val;
		node *l;
		node *r;
		node *p;
		color_type color;
	};

public:
	// types contd.
	class const_iterator;

	class iterator {
		friend class rb_tree;
		friend class const_iterator;
	public:
		typedef std::bidirectional_iterator_tag iterator_category;
		typedef typename rb_tree::value_type value_type;
		typedef typename rb_tree::difference_type difference_type;
		typedef typename rb_tree::pointer pointer;
		typedef typename rb_tree::reference reference;

		iterator() : t(0), x(0) {}

		// segfault if *this==end()
		reference operator*() const { return *(x->val); }
		// leads to segfault if *this==end()
		pointer operator->() const { return x->val; }

		// incrementing end() leaves it at end()
		iterator &operator++()
		{
			x = t->successor(x);
			return *this;
		}
		// --end() is the last element; will not go past beginning
		iterator &operator--()
		{
			node *y = x == t->_nil ? t->_back : t->predecessor(x);
			if (y != t->_nil)
				x = y;
			return *this;
		}
		iterator operator++(int)
		{
			iterator j = *this;
			++*this;
			return j;
		}
		iterator operator--(int)
		{
			iterator j = *this;
			--*this;
			return j;
		}

		bool operator==(const iterator &j) const { return j.x == x; }
		bool operator!=(const iterator &j) const { return j.x != x; }
	private:
		iterator(rb_tree *tree, node *n) : t(tree), x(n) {}

		rb_tree *t;
		node *x;
	};

	class const_iterator {
		friend class rb_tree;
	public:
		typedef std::bidirectional_iterator_tag iterator_category;
		typedef typename rb_tree::value_type value_type;
		typedef typename rb_tree::difference_type difference_type;
		typedef typename rb_tree::const_pointer pointer;
		typedef typename rb_tree::const_reference reference;

		const_iterator() : t(0), x(0) {}
		// implicit conversion from a mutable iterator
		const_iterator(const iterator &j) : t(j.t), x(j.x) {}

		// segfault if *this==end()
		reference operator*() const { return *(x->val); }
		// leads to segfault if *this==end()
		pointer operator->() const { return x->val; }

		// incrementing end() (or a default-constructed iterator) is a no-op
		const_iterator &operator++()
		{
			if (x)
				x = t->successor(x);
			return *this;
		}
		// --end() is the last element; will not go past beginning
		const_iterator &operator--()
		{
			node *y = x == t->_nil ? t->_back : t->predecessor(x);
			if (y != t->_nil)
				x = y;
			return *this;
		}
		const_iterator operator++(int)
		{
			const_iterator j = *this;
			++*this;
			return j;
		}
		const_iterator operator--(int)
		{
			const_iterator j = *this;
			--*this;
			return j;
		}

		bool operator==(const const_iterator &j) const { return j.x == x; }
		bool operator!=(const const_iterator &j) const { return j.x != x; }
	private:
		const_iterator(const rb_tree *tree, const node *n) :
		    t(const_cast<rb_tree *>(tree)), x(const_cast<node *>(n))
		{}

		rb_tree *t;
		node *x;
	};

	typedef std::reverse_iterator<iterator> reverse_iterator;
	typedef std::reverse_iterator<const_iterator> const_reverse_iterator;

	// constructors

	// O(1)
	explicit rb_tree(const Cmp &cmp=Cmp(), const A &alloc=A()) :
	    _alloc(alloc), _lt(cmp), _root(new node), _front(_root),
	    _back(_root), _nil(_root), _sz(0)
	{
		init_nil();
	}

	// O(n lg n); releases everything already inserted if an insert throws
	template <class In>
	rb_tree(In begin, In end, const Cmp &cmp=Cmp(), const A &alloc=A()) :
	    _alloc(alloc), _lt(cmp), _root(new node), _front(_root),
	    _back(_root), _nil(_root), _sz(0)
	{
		init_nil();
		try {
			while (begin != end)
				insert(*begin++);
		} catch (...) {
			clear();
			delete _nil;
			throw;
		}
	}

	// O(n); releases everything already copied if an insert throws
	rb_tree(const rb_tree &t) :
	    _alloc(t._alloc), _lt(t._lt), _root(new node), _front(_root),
	    _back(_root), _nil(_root), _sz(0)
	{
		init_nil();
		try {
			// since input is ordered, save time by giving hints
			iterator last = end();
			for (const_iterator src = t.begin(), src_end = t.end();
			    src != src_end; ++src)
				last = insert(last, *src);
		} catch (...) {
			clear();
			delete _nil;
			throw;
		}
	}

	// O(n)
	~rb_tree()
	{
		clear();
		delete _nil;
	}

	// O(n); copy-and-swap
	rb_tree &operator=(const rb_tree &t)
	{
		if (this != &t) {
			rb_tree tmp(t);
			swap(tmp);
		}
		return *this;
	}

	// iterators

	// O(1)
	iterator begin() { return iterator(this, _front); }
	const_iterator begin() const { return const_iterator(this, _front); }
	// O(1)
	iterator end() { return iterator(this, _nil); }
	const_iterator end() const { return const_iterator(this, _nil); }
	// O(1)
	reverse_iterator rbegin() { return reverse_iterator(end()); }
	const_reverse_iterator rbegin() const
	{ return const_reverse_iterator(end()); }
	// O(1)
	reverse_iterator rend() { return reverse_iterator(begin()); }
	const_reverse_iterator rend() const
	{ return const_reverse_iterator(begin()); }

	// element access

	// O(1), segfault on empty tree
	reference front() { return *(_front->val); }
	const_reference front() const { return *(_front->val); }
	reference back() { return *(_back->val); }
	const_reference back() const { return *(_back->val); }

	// dequeue operations

	// O(1) amortized; no-ops on an empty tree
	void pop_front() { erase(begin()); }
	void pop_back() { erase(iterator(this, _back)); }

	// list operations

	// O(lg n); equal keys are placed after the existing ones
	iterator insert(const value_type &val)
	{
		node *n = create_node(val);
		try {
			insert_node(n);
		} catch (...) {
			// only the comparator can throw, before n is linked in
			destroy_node(n);
			throw;
		}
		firstlast_fix_insert(n);
		++_sz;
		return iterator(this, n);
	}

	// beware: if hint is chosen randomly, this function is O(n) instead
	// of O(lg n); the point is to give a guess that's off by
	// at most k (3 or so), thus giving O(1) complexity;
	// a hint of end() is treated as the last element
	iterator insert(iterator hint, const value_type &val)
	{
		node *n = create_node(val);
		try {
			if (!_sz)
				insert_node(n);
			else
				insert_node_near(hint.x == _nil ? _back : hint.x, n);
		} catch (...) {
			// only the comparator can throw, before n is linked in
			destroy_node(n);
			throw;
		}
		firstlast_fix_insert(n);
		++_sz;
		return iterator(this, n);
	}

	// O(m lg(n + m)), m=end-begin
	template <class In>
	void insert(In begin, In end)
	{
		while (begin != end)
			insert(*begin++);
	}

	// O(lg n); returns the element after i (erasing end() is a no-op)
	iterator erase(iterator i)
	{
		if (i.x == _nil)
			return i;
		node *next = successor(i.x);
		if (i.x == _front) {
			if (next == _nil) // must be both first and last
				_back = _nil;
			_front = next;
		} else if (i.x == _back)
			_back = predecessor(i.x);
		destroy_node(erase_node(i.x));
		--_sz;
		return iterator(this, next);
	}

	// O(lg n + m), m=count(k); returns the number of elements erased
	size_type erase(const key_type &k)
	{
		std::pair<iterator, iterator> range = equal_range(k);
		size_type n = static_cast<size_type>(
		    std::distance(range.first, range.second));
		erase(range.first, range.second);
		return n;
	}

	// O(m lg(m + n)), m=end-begin
	void erase(iterator begin, iterator end)
	{
		while (begin != end)
			begin = erase(begin);
	}

	// O(n); post-order walk along parent pointers, so no allocation
	void clear()
	{
		node *n = _root;
		while (n != _nil) {
			if (n->l != _nil)
				n = n->l;
			else if (n->r != _nil)
				n = n->r;
			else {
				// n is a leaf: detach it from its parent and free it
				node *parent = n->p;
				if (parent != _nil) {
					if (parent->l == n)
						parent->l = _nil;
					else
						parent->r = _nil;
				}
				destroy_node(n);
				n = parent;
			}
		}
		_sz = 0;
		_front = _back = _root = _nil;
	}

	// map operations

	// O(lg n); returns some element with key k, or end()
	iterator find(const key_type &k)
	{
		return iterator(this, search(_root, k));
	}
	const_iterator find(const key_type &k) const
	{
		return const_iterator(this, self()->search(_root, k));
	}

	// O(m + lg n), m=count(k)
	size_type count(const key_type &k) const
	{
		return static_cast<size_type>(
		    std::distance(lower_bound(k), upper_bound(k)));
	}

	// O(lg n); first element whose key is not less than k
	iterator lower_bound(const key_type &k)
	{
		return iterator(this, lower_bound_node(k));
	}
	const_iterator lower_bound(const key_type &k) const
	{
		return const_iterator(this, self()->lower_bound_node(k));
	}

	// O(lg n); first element whose key is greater than k
	iterator upper_bound(const key_type &k)
	{
		return iterator(this, upper_bound_node(k));
	}
	const_iterator upper_bound(const key_type &k) const
	{
		return const_iterator(this, self()->upper_bound_node(k));
	}

	// O(lg n); [lower_bound(k), upper_bound(k))
	std::pair<iterator, iterator> equal_range(const key_type &k)
	{
		return std::pair<iterator, iterator>(
		    lower_bound(k), upper_bound(k));
	}
	std::pair<const_iterator, const_iterator> equal_range(
	    const key_type &k) const
	{
		return std::pair<const_iterator, const_iterator>(
		    lower_bound(k), upper_bound(k));
	}

	// capacity
	operator bool() const { return _sz != 0; }
	size_type size() const { return _sz; }
	bool empty() const { return !_sz; }

	// other

	// O(1)
	// exchanges the tree positions of the first two elements (values and
	// iterators stay attached to their elements); this is important for
	// certain algorithms with comparison functions that depend on time
	void swap_front()
	{
		if (_sz < 2)
			return;
		node *a = _front;
		node *b = successor(a);
		exchange_positions(a, b);
		_front = b;
		if (_back == b)
			_back = a;
	}

	// O(1)
	void swap(rb_tree &t)
	{
		std::swap(_alloc, t._alloc);
		std::swap(_lt, t._lt);
		std::swap(_root, t._root);
		std::swap(_front, t._front);
		std::swap(_back, t._back);
		std::swap(_nil, t._nil);
		std::swap(_sz, t._sz);
	}

	key_compare key_comp() const { return _lt; }
	value_compare value_comp() const { return value_compare(_lt); }
	allocator_type get_allocator() const { return _alloc; }

	// O(n); lexicographic comparison of the key sequences (mapped values
	// are not compared); returns -1, 0, or 1
	int cmp(const rb_tree &t) const
	{
		const_iterator i = begin();
		const_iterator j = t.begin();
		const_iterator i_ = end();
		const_iterator j_ = t.end();
		for (; i != i_ && j != j_; ++i, ++j) {
			if (_lt(i->first, j->first))
				return -1;
			if (_lt(j->first, i->first))
				return 1;
		}
		if (i != i_) return 1; // i > j
		if (j != j_) return -1; // i < j
		return 0;
	}

#ifdef REDBLACK_EXTRA
	// O(n)
	std::string to_string() const { return to_string(_root); }
	// O(slow); reports red-black and linkage violations on std::cerr
	void validate() const;
#endif

private:
#ifdef REDBLACK_EXTRA
	bool validate_heights(const node *n) const;
	bool validate_connections() const;
	std::string to_string(node *n) const;
#endif

	// lets const lookups reuse the non-const node helpers
	rb_tree *self() const { return const_cast<rb_tree *>(this); }

	// the sentinel points to itself in every direction
	void init_nil() { _nil->r = _nil->l = _nil->p = _nil; }

	// allocates a node holding a copy of val; leaks nothing on throw
	node *create_node(const value_type &val)
	{
		pointer v = _alloc.allocate(1);
		try {
			_alloc.construct(v, val);
		} catch (...) {
			_alloc.deallocate(v, 1);
			throw;
		}
		node *n = 0;
		try {
			n = new node;
		} catch (...) {
			_alloc.destroy(v);
			_alloc.deallocate(v, 1);
			throw;
		}
		n->val = v;
		return n;
	}

	// releases a node and the value it owns
	void destroy_node(node *n)
	{
		_alloc.destroy(n->val);
		_alloc.deallocate(n->val, 1);
		delete n;
	}

	// after inserting n: a new minimum always ends up as _front->l and a
	// new maximum as _back->r
	void firstlast_fix_insert(node *n)
	{
		if (_front == _nil)
			_front = _back = n;
		else if (_front->l != _nil)
			_front = n; // (n == front->l)
		else if (_back->r != _nil)
			_back = n; // (n == back->r)
	}

	node *minimum(node *x)
	{
		while (x->l != _nil)
			x = x->l;
		return x;
	}

	node *maximum(node *x)
	{
		while (x->r != _nil)
			x = x->r;
		return x;
	}

	node *predecessor(node *x)
	{
		if (x->l != _nil)
			return maximum(x->l);
		node *y = x->p;
		while (y != _nil && x == y->l) {
			x = y;
			y = y->p;
		}
		return y;
	}

	node *successor(node *x)
	{
		if (x == _nil)
			return x;
		if (x->r != _nil)
			return minimum(x->r);
		node *y = x->p;
		while (y != _nil && x == y->r) {
			x = y;
			y = y->p;
		}
		return y;
	}

	// some node with key k in the subtree rooted at x, or _nil
	node *search(node *x, const key_type &k)
	{
		while (x != _nil) {
			if (_lt(x->val->first, k))
				x = x->r;
			else if (_lt(k, x->val->first))
				x = x->l;
			else
				return x;
		}
		return _nil;
	}

	// leftmost node whose key is not less than k, or _nil
	node *lower_bound_node(const key_type &k)
	{
		node *result = _nil;
		node *x = _root;
		while (x != _nil) {
			if (_lt(x->val->first, k))
				x = x->r;
			else {
				result = x;
				x = x->l;
			}
		}
		return result;
	}

	// leftmost node whose key is greater than k, or _nil
	node *upper_bound_node(const key_type &k)
	{
		node *result = _nil;
		node *x = _root;
		while (x != _nil) {
			if (_lt(k, x->val->first)) {
				result = x;
				x = x->l;
			} else
				x = x->r;
		}
		return result;
	}

	// Swaps the tree positions (links and colors) of two distinct nodes;
	// values stay in their nodes, so iterators follow their elements.
	void exchange_positions(node *a, node *b)
	{
		// if the nodes are adjacent, make a the parent
		if (a->p == b)
			std::swap(a, b);
		node *ap = a->p, *al = a->l, *ar = a->r;
		node *bp = b->p, *bl = b->l, *br = b->r;
		const bool a_is_left = ap != _nil && ap->l == a;
		const bool b_is_left = bp != _nil && bp->l == b;

		// colors belong to positions, not elements
		std::swap(a->color, b->color);

		b->p = ap;
		if (bp == a) {
			// b was a's child; a becomes b's child on the same side
			if (al == b) {
				b->l = a;
				b->r = ar;
			} else {
				b->l = al;
				b->r = a;
			}
			a->p = b;
		} else {
			b->l = al;
			b->r = ar;
			a->p = bp;
			if (bp == _nil)
				_root = a;
			else if (b_is_left)
				bp->l = a;
			else
				bp->r = a;
		}
		a->l = bl;
		a->r = br;

		if (ap == _nil)
			_root = b;
		else if (a_is_left)
			ap->l = b;
		else
			ap->r = b;

		// re-parent the children of both nodes
		if (b->l != _nil) b->l->p = b;
		if (b->r != _nil) b->r->p = b;
		if (a->l != _nil) a->l->p = a;
		if (a->r != _nil) a->r->p = a;
	}

	void left_rotate(node *x);
	void right_rotate(node *x);
	void insert_node(node *z);
	void insert_node_near(node *hint, node *z);
	void insert_fixup(node *z);
	node *erase_node(node *z);
	void erase_fixup(node *x);

	// _front,_back,_sz maintained together
	allocator_type _alloc;
	key_compare _lt;
	node *_root;
	node *_front;
	node *_back;
	node *_nil;
	size_type _sz;
};

#ifdef REDBLACK_EXTRA
template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::validate() const
{
	if (_root != _nil && _root->color != BLACK)
		std::cerr << "root node not black\n";
	if (_nil->color != BLACK)
		std::cerr << "nil not black\n";
	for (const_iterator i = begin(), j = end(); i != j; ++i) {
		const node *n = i.x;
		if (n->color == RED) {
			if (n->l->color != BLACK)
				std::cerr << "left of red is not black\n";
			if (n->r->color != BLACK)
				std::cerr << "right of red is not black\n";
		}
		if (!validate_heights(n))
			std::cerr << "invalid heights\n";
	}
	validate_connections();
}

// true iff every path from n down to a nil leaf has the same black count
template <class Key, class T, class Cmp, class A>
bool rb_tree<Key, T, Cmp, A>::validate_heights(const node *n) const
{
	std::queue<const node *> q;
	std::queue<int> heights;
	std::list<int> leaf_heights;
	if (n != _nil) {
		q.push(n);
		heights.push(0);
	}
	while (!q.empty()) {
		const node *x = q.front();
		int height = heights.front();
		q.pop();
		heights.pop();
		if (x->color == BLACK)
			height++;
		if (x->r == _nil || x->l == _nil)
			leaf_heights.push_back(height);
		if (x->l != _nil) {
			q.push(x->l);
			heights.push(height);
		}
		if (x->r != _nil) {
			q.push(x->r);
			heights.push(height);
		}
	}
	if (leaf_heights.empty())
		return true;
	std::list<int>::const_iterator i = leaf_heights.begin();
	std::list<int>::const_iterator j = leaf_heights.end();
	for (++i; i != j; ++i)
		if (*i != leaf_heights.front())
			return false;
	return true;
}

// true iff every parent/child link is mirrored and no node repeats
template <class Key, class T, class Cmp, class A>
bool rb_tree<Key, T, Cmp, A>::validate_connections() const
{
	std::queue<const node *> q;
	std::set<const node *> seen;
	bool good = true;
	if (_root != _nil)
		q.push(_root);
	while (!q.empty()) {
		const node *n = q.front();
		q.pop();
		if (!seen.insert(n).second) {
			std::cerr << "repeated node " << n << '\n';
			good = false;
			continue; // do not walk a cycle forever
		}
		if (n != _root && n->p->l != n && n->p->r != n) {
			std::cerr << "node is not child of its parent\n";
			good = false;
		}
		if (n->l != _nil && n->l->p != n) {
			std::cerr << "node is not parent of left child\n";
			good = false;
		}
		if (n->r != _nil && n->r->p != n) {
			std::cerr << "node is not parent of right child\n";
			good = false;
		}
		if (n->l != _nil)
			q.push(n->l);
		if (n->r != _nil)
			q.push(n->r);
	}
	return good;
}

template <class Key, class T, class Cmp, class A>
std::string rb_tree<Key, T, Cmp, A>::to_string(node *n) const
{
	if (n == _nil)
		return "-";
	std::ostringstream out;
	out << '(' << to_string(n->l) << ' ' << n->val->first <<
	    (n->color == RED ? 'R' : 'B') << ' ' << to_string(n->r) << ')';
	return out.str();
}
#endif

template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::left_rotate(node *x)
{
#ifdef REDBLACK_EXTRA
	assert(x != _nil);
	assert(x->r != _nil);
#endif
	// set y
	node *y = x->r;

	// turn y's left subtree into x's right subtree
	x->r = y->l;
	if (y->l != _nil)
		y->l->p = x;

	// link x's parent to y
	y->p = x->p;
	if (x->p == _nil)
		_root = y;
	else if (x == x->p->l)
		x->p->l = y;
	else
		x->p->r = y;

	// put x on y's left
	y->l = x;
	x->p = y;
}

template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::right_rotate(node *x)
{
#ifdef REDBLACK_EXTRA
	assert(x != _nil);
	assert(x->l != _nil);
#endif
	// set y
	node *y = x->l;

	// turn y's right subtree into x's left subtree
	x->l = y->r;
	if (y->r != _nil)
		y->r->p = x;

	// link x's parent to y
	y->p = x->p;
	if (x->p == _nil)
		_root = y;
	else if (x == x->p->r)
		x->p->r = y;
	else
		x->p->l = y;

	// put x on y's right
	y->r = x;
	x->p = y;
}

// CLRS RB-Insert; equal keys go right, i.e. after existing equal keys.
// z is linked into the tree only after every comparison has been made.
template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::insert_node(node *z)
{
	node *y = _nil;
	node *x = _root;
	bool left = false;
	while (x != _nil) {
		y = x;
		left = _lt(z->val->first, x->val->first);
		x = left ? x->l : x->r;
	}
	z->p = y;
	if (y == _nil)
		_root = z;
	else if (left)
		y->l = z;
	else
		y->r = z;
	z->l = z->r = _nil;
	z->color = RED;
	insert_fixup(z);
}

// Like insert_node(), but walks in-order from hint (a real node) instead
// of descending from the root.
template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::insert_node_near(node *hint, node *z)
{
	node *y = _nil;
	node *x = hint;
	if (_lt(z->val->first, x->val->first)) {
		// move left
		do {
			y = x;
			x = predecessor(x);
		} while (x != _nil && _lt(z->val->first, x->val->first));
		// key[x] <= key[z] < key[y], x may be nil
		if (x == _nil || y->l == _nil) {
			y->l = z;
			z->p = y;
		} else {
			// y->l is occupied, so x = max(y->l) has a free right
			x->r = z;
			z->p = x;
		}
	} else {
		// move right
		do {
			y = x;
			x = successor(x);
		} while (x != _nil && _lt(x->val->first, z->val->first));
		// key[y] <= key[z] <= key[x], x may be nil
		if (x == _nil || y->r == _nil) {
			y->r = z;
			z->p = y;
		} else {
			// y->r is occupied, so x = min(y->r) has a free left
			x->l = z;
			z->p = x;
		}
	}
	z->l = z->r = _nil;
	z->color = RED;
	insert_fixup(z);
}

template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::insert_fixup(node *z)
{
	node *y;
	while (z->p->color == RED) {
		if (z->p == z->p->p->l) {
			y = z->p->p->r;
			if (y->color == RED) {
				z->p->color = BLACK;
				y->color = BLACK;
				z->p->p->color = RED;
				z = z->p->p;
			} else {
				if (z == z->p->r) {
					z = z->p;
					left_rotate(z);
				}
				z->p->color = BLACK;
				z->p->p->color = RED;
				right_rotate(z->p->p);
			}
		} else {
			y = z->p->p->l;
			if (y->color == RED) {
				z->p->color = BLACK;
				y->color = BLACK;
				z->p->p->color = RED;
				z = z->p->p;
			} else {
				if (z == z->p->l) {
					z = z->p;
					right_rotate(z);
				}
				z->p->color = BLACK;
				z->p->p->color = RED;
				left_rotate(z->p->p);
			}
		}
	}
	_root->color = BLACK;
}

// CLRS RB-Delete; unlinks z and returns it (the caller frees it).
template <class Key, class T, class Cmp, class A>
typename rb_tree<Key, T, Cmp, A>::node *
rb_tree<Key, T, Cmp, A>::erase_node(node *z)
{
	node *x;
	node *y;

	if (z->l == _nil || z->r == _nil)
		y = z;
	else
		y = successor(z);

	if (y->l != _nil)
		x = y->l;
	else
		x = y->r;

	x->p = y->p;

	if (y->p == _nil)
		_root = x;
	else if (y == y->p->l)
		y->p->l = x;
	else
		y->p->r = x;

	if (y != z) {
		// main deviation from Cormen's implementation; the
		// goal is to not invalidate existing iterators; at this
		// point, y is detached, so make things pointing to z point
		// to y, make y point to what z points to, and copy color
		y->p = z->p;
		y->l = z->l;
		y->r = z->r;
		std::swap(y->color, z->color);
		if (y->p == _nil)
			_root = y;
		else if (y->p->l == z)
			y->p->l = y;
		else
			y->p->r = y;
		y->l->p = y;
		y->r->p = y;
	}

	// z->color now holds the color of the position that was removed
	if (z->color == BLACK)
		erase_fixup(x);
	return z;
}

template <class Key, class T, class Cmp, class A>
void rb_tree<Key, T, Cmp, A>::erase_fixup(node *x)
{
	node *w;
	while (x != _root && x->color == BLACK) {
		if (x == x->p->l) {
			w = x->p->r;
			if (w->color == RED) {
				w->color = BLACK;
				x->p->color = RED;
				left_rotate(x->p);
				w = x->p->r;
			}
			if (w->l->color == BLACK && w->r->color == BLACK) {
				w->color = RED;
				x = x->p;
			} else {
				if (w->r->color == BLACK) {
					w->l->color = BLACK;
					w->color = RED;
					right_rotate(w);
					w = x->p->r;
				}
				w->color = x->p->color;
				x->p->color = BLACK;
				w->r->color = BLACK;
				left_rotate(x->p);
				x = _root;
			}
		} else {
			w = x->p->l;
			if (w->color == RED) {
				w->color = BLACK;
				x->p->color = RED;
				right_rotate(x->p);
				w = x->p->l;
			}
			if (w->r->color == BLACK && w->l->color == BLACK) {
				w->color = RED;
				x = x->p;
			} else {
				if (w->l->color == BLACK) {
					w->r->color = BLACK;
					w->color = RED;
					left_rotate(w);
					w = x->p->l;
				}
				w->color = x->p->color;
				x->p->color = BLACK;
				w->l->color = BLACK;
				right_rotate(x->p);
				x = _root;
			}
		}
	}
	x->color = BLACK;
}

// O(1) non-member swap, found by argument-dependent lookup
template <class Key, class T, class Cmp, class A>
inline void swap(rb_tree<Key, T, Cmp, A> &x, rb_tree<Key, T, Cmp, A> &y)
{
	x.swap(y);
}

// comparisons are lexicographic on keys (see rb_tree::cmp)
template <class Key, class T, class Cmp, class A>
inline bool operator==(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) == 0;
}

template <class Key, class T, class Cmp, class A>
inline bool operator!=(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) != 0;
}

template <class Key, class T, class Cmp, class A>
inline bool operator<(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) < 0;
}

template <class Key, class T, class Cmp, class A>
inline bool operator>(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) > 0;
}

template <class Key, class T, class Cmp, class A>
inline bool operator<=(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) <= 0;
}

template <class Key, class T, class Cmp, class A>
inline bool operator>=(
    const rb_tree<Key, T, Cmp, A> &x, const rb_tree<Key, T, Cmp, A> &y)
{
	return x.cmp(y) >= 0;
}

#endif