C++数据结构——二叉树之三:红黑树

本次实现的红黑树基于上篇实现的普通二叉树。普通二叉树的实现已在 https://blog.csdn.net/qq811299838/article/details/104038745 这篇文章中给出,此处不再重复列出。

由于网上已经有很多红黑树的算法介绍,本文将不再介绍算法,只提供代码实现以及测试代码,如有问题,欢迎指出!

关于结点删除:若被删除结点的左子树深度大于右子树深度,则选择左子树中的最大结点作为替换结点;否则,选择右子树中的最小结点作为替换结点。

编译环境:GCC 7.3、VS 2005

红黑树的代码如下:

#ifndef __RBTREE_H__
#define __RBTREE_H__

#if __cplusplus >= 201103L
#include <type_traits>
#include <utility> // std::forward, std::move
#endif

#if __cplusplus >= 201103L
#define null nullptr
#else
#define null NULL
#endif 

#include "btree.h"

/* Default three-way comparator used by RBTree.
 *
 * Returns  1 when a1 < a2  (a1 belongs in the left subtree of a2),
 *         -1 when a2 < a1  (a1 belongs in the right subtree),
 *          0 when neither is less (treated as equal).
 * Only operator< of _Tp is required.
 */
template<typename _Tp>
struct Comparator
{
    // const so that const comparator objects (and const contexts) can call it
    int operator()(const _Tp &a1, const _Tp &a2) const
    {
        if(a1 < a2) return 1;
        if(a2 < a1) return -1;
        return 0;
    }
};

template<typename _Tp, typename _Compare = Comparator<_Tp>>
class RBTree
{
public:
    typedef _Tp                  value_type;
    typedef _Tp &                reference;
    typedef _Tp *                pointer;
    typedef const _Tp &          const_reference;
    typedef unsigned long        size_type;
    typedef _Compare             compare_type;

#if __cplusplus >= 201103L
    typedef _Tp &&        rvalue_reference;
#endif

private:
    enum COLOR
    {
        RED,
        BLACK
    };

public:
    typedef BinaryTree<value_type>        tree_type;
    typedef typename tree_type::node_type node_type;
    typedef typename node_type::color_type color_type;

private:
    void _M_swap_color(node_type *node1, node_type *node2)
    {
        if(null == node1 && null != node2)
        { node2->_M_color = COLOR::BLACK; }
        else if(null == node2 && null != node1)
        { node1->_M_color = COLOR::BLACK; }
        else if(null != node1 && null != node2)
        {
            color_type c1 = node1->color();
            node1->_M_color = node2->color();
            node2->_M_color = c1;
        }
    }

public:
    typedef node_type*(*iterator_func)(node_type*);
    
    template<iterator_func _Next, iterator_func _Prev>
    struct iterator_impl
    {
        node_type *_M_node;

        iterator_impl(node_type *node = null)
         : _M_node(node) { }
        
        iterator_impl operator++()
        { return iterator_impl(_M_node = _Next(_M_node)); }

        iterator_impl operator++(int)
        { 
            iterator_impl ret(_M_node);
            _M_node = _Next(_M_node);
            return ret;
        }

        iterator_impl operator--()
        { return iterator_impl(_M_node = _Prev(_M_node)); }

        iterator_impl operator--(int)
        {
            iterator_impl ret(_M_node);
            _M_node = _Prev(_M_node);
            return ret;
        }

        reference operator*()
		{ return *_M_node->value(); }

		pointer operator->()
		{ return _M_node->value(); }

        bool operator==(const iterator_impl &it) const 
		{ return _M_node == it._M_node; }

		bool operator!=(const iterator_impl &it) const
		{ return _M_node != it._M_node; }
    };
    
    template<iterator_func _Next, iterator_func _Prev>
    struct const_iterator_impl
    {
        const node_type *_M_node;

        const_iterator_impl(const node_type *node = null)
         : _M_node(node) { }

        const_iterator_impl(const iterator_impl<_Next, _Prev> &it)
         : _M_node(it._M_node) { }
        
        const_iterator_impl operator++()
        { return const_iterator_impl(_M_node = _Next(const_cast<node_type*>(_M_node))); }

        const_iterator_impl operator++(int)
        { 
            const_iterator_impl ret(_M_node);
            _M_node = _Next(const_cast<node_type*>(_M_node));
            return ret;
        }

        const_iterator_impl operator--()
        { return const_iterator_impl(_M_node = _Prev(const_cast<node_type*>(_M_node))); }

        const_iterator_impl operator--(int)
        {
            const_iterator_impl ret(_M_node);
            _M_node = _Prev(const_cast<node_type*>(_M_node));
            return ret;
        }

        reference operator*()
		{ return *_M_node->value(); }

		pointer operator->()
		{ return _M_node->value(); }

        bool operator==(const const_iterator_impl &it) const 
		{ return _M_node == it._M_node; }

		bool operator!=(const const_iterator_impl &it) const
		{ return _M_node != it._M_node; }
    };

public:
    typedef iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> iterator;
    typedef iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> reverse_iterator;
    typedef const_iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> const_iterator;
    typedef const_iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> const_reverse_iterator;

public:
    RBTree() { }

    RBTree(const RBTree &t)
     : _M_tree(t._M_tree) { }

#if __cplusplus >= 201103L
    RBTree(RBTree &&t)
     : _M_tree(std::move(t._M_tree)) { }
#endif

    size_type size() const 
    { return _M_tree.size(); }

    size_type depth() const 
    { return _M_tree.depth(); }

    const tree_type& get_tree() const 
    { return _M_tree; }

    bool empty() const 
    { return size() == 0; }

    iterator begin()
    { return iterator(_M_tree.left_child_under(_M_tree.root())); }

    iterator end() 
    { return iterator(); }

    const_iterator begin() const 
    { return const_iterator(_M_tree.left_child_under(_M_tree.root())); }

    const_iterator end() const 
    { return const_iterator(); }

    const_iterator cbegin() const 
    { return begin(); }

    const_iterator cend() const 
    { return end(); }

    reverse_iterator rbegin() 
    { return reverse_iterator(_M_tree.right_child_under(_M_tree.root())); }

    reverse_iterator rend() 
    { return reverse_iterator(); }

    const_reverse_iterator rbegin() const 
    { return const_reverse_iterator(_M_tree.right_child_under(_M_tree.root())); }

    const_reverse_iterator rend() const 
    { return const_reverse_iterator(); }

    const_reverse_iterator crbegin() const 
    { return rbegin(); }

    const_reverse_iterator crend() const 
    { return rend(); }

    void erase(const_iterator it)
    { 
        if(it == end())
        { return; }

        node_type *node = const_cast<node_type*>(it._M_node);
        node_type *reserve = _M_tree.get_erase_reserve(node);
        node->swap(reserve);
        if(reserve->color() != COLOR::RED)
        { 
            _M_erase_adjust(reserve); 
            _M_swap_color(node, reserve);
        }
        
        _M_tree.erase_leaf(reserve); 
    }
        
    iterator insert(const_reference v)
    { return _M_insert(v); }

#if __cplusplus >= 201103L
    iterator insert(rvalue_reference v)
    { return _M_insert(std::move(v)); }
#endif

    iterator find(const_reference v)
    { return _M_find<value_type, compare_type>(v); }

    template<typename _CompareType>
    const_iterator find(const_reference v) const
    { return _M_find<value_type, _CompareType>(v); }

private:
    iterator _M_insert(const_reference v)
    {
        iterator found;
        if(_M_find_and_insert<value_type, compare_type>(v, found))
        { *found = v; }
        return found;
    }
#if __cplusplus >= 201103L
    iterator _M_insert(rvalue_reference v)
    {
        iterator found;
        if(_M_find_and_insert<value_type, compare_type>(std::move(v), found))
        { *found = v; }
        return found;
    }
#endif

    template<typename _InputType, typename _CompareType>
    iterator _M_find(const _InputType &input)
    {
        node_type *node = _M_tree.root();
        while(null != node)
        {
            int res = _CompareType()(input, *node->value());
            if(res == 0)
            { return iterator(node); }
            if(res > 0)
            { node = node->left_child(); }
            else
            { node = node->right_child(); }
        }
        return iterator();
    }

    void _M_erase_adjust(node_type *node)
    {
        // 父结点、兄弟结点、远侄子、近侄子
        // 当删除结点是左孩子时,兄弟结点的左孩子就是近侄子,兄弟结点的右孩子就是远侄子
        node_type *parent = null, *brother = null, *far_nephew = null, *near_nephew = null;

        while(null != node && node->color() != COLOR::RED)
        {
            // 先找到关系
            parent = node->parent();
            if(null == parent) { }
            else if(parent->left_child() == node)
            {
                brother = parent->right_child();
                if(null != brother)
                {
                    far_nephew = brother->right_child();
                    near_nephew = brother->left_child();
                }
            }
            else
            {
                brother = parent->left_child();
                if(null != brother)
                {
                    far_nephew = brother->left_child();
                    near_nephew = brother->right_child();
                }
            }

            // 当兄弟结点为红色结点时
            if(null != brother && brother->color() == COLOR::RED)
            {
                brother->_M_color = COLOR::BLACK;
                parent->_M_color = COLOR::RED;
                if(parent->left_child() == node)
                { _M_tree.left_rotate(parent); }
                else
                { _M_tree.right_rotate(parent); }
            }
            // 当远侄子结点为红色结点时
            else if(null != far_nephew && far_nephew->color() == COLOR::RED)
            {
                brother->_M_color = parent->color();
                parent->_M_color = COLOR::BLACK;
                far_nephew->_M_color = COLOR::BLACK;
                if((null == parent ? null : parent->left_child()) == node)
                { _M_tree.left_rotate(parent); }
                else
                { _M_tree.right_rotate(parent); }
                break;
            }
            // 当近侄子结点为红色结点时
            else if(null != near_nephew && near_nephew->color() == COLOR::RED)
            {
                brother->_M_color = COLOR::RED;
                near_nephew->_M_color = COLOR::BLACK;
                if(parent->left_child() == node)
                { _M_tree.right_rotate(brother); }
                else
                { _M_tree.left_rotate(brother); }
            }
            // 当兄弟结点的孩子结点都是黑色结点时
            else
            {
                if(null != brother)
                { brother->_M_color = COLOR::RED; }
                if(null != parent && parent->color() == COLOR::RED)
                { 
                    parent->_M_color = COLOR::BLACK; 
                    break;
                }
                node = parent;
            }
        }
    }

    void _M_insert_adjust(node_type *node)
    {
        while(null != node)
        {
            node_type *parent = node->parent();
            node_type *grand_parent = null == parent ? null : parent->parent();
            node_type *uncle = null;
            if((null == grand_parent ? null : grand_parent->left_child()) == parent)
            { uncle = null == grand_parent ? null : grand_parent->right_child(); }
            else
            { uncle = null == grand_parent ? null : grand_parent->left_child(); }
            // 当父结点是红色结点时
            if(null != parent && parent->color() == COLOR::RED)
            {
                // 当叔叔结点是红色结点时
                if(null != uncle && uncle->color() == COLOR::RED)
                {
                    parent->_M_color = COLOR::BLACK;
                    uncle->_M_color = COLOR::BLACK;
                    grand_parent->_M_color = COLOR::RED;
                    node = grand_parent;
                    continue;
                }
                // 当新结点是左孩子时
                if((null == parent ? null : parent->left_child()) == node)
                {
                    // 父结点是左孩子
                    if((null == grand_parent ? null : grand_parent->left_child()) == parent)
                    {
                        _M_swap_color(parent, grand_parent);
                        _M_tree.right_rotate(grand_parent);
                        break;
                    }
                    else
                    {
                        _M_tree.right_rotate(parent);
                        node = parent;
                        continue;
                    }
                }
                // 新结点是右孩子
                else
                {
                    // 父结点是左孩子
                    if((null == grand_parent ? null : grand_parent->left_child()) == parent)
                    {
                        _M_tree.left_rotate(parent);
                        node = parent;
                        continue;
                    }
                    else
                    {
                        _M_swap_color(parent, grand_parent);
                        _M_tree.left_rotate(grand_parent);
                        continue;
                    }
                }
            }
            break;
        }
        _M_tree.root()->_M_color = COLOR::BLACK;
    }


    /* 插入结点,如果结点不存在则插入新结点
     * @input  插入的值
     * @result  结点的迭代器
     * @return  如果结点本来已存在,则返回true
     */
    template<typename _InputType, typename _CompareType>
    bool _M_find_and_insert(const _InputType &input, iterator &result)
    {
        if(empty())
        { 
            result._M_node = _M_tree.append_root(input);
            return false; 
        }

        node_type *node = _M_tree.root();
        while(true)
        {
            int res = _CompareType()(input, *node->value());
            if(res == 0)
            { 
                result._M_node = node;
                return true; 
            }
            if(res > 0)
            {
                if(null == node->left_child())
                {
                    node = _M_tree.append_left(node, input);
                    break;
                }
                node = node->left_child();
            }
            else
            {
                if(null == node->right_child())
                {
                    node = _M_tree.append_right(node, input);
                    break;
                }
                node = node->right_child();
            }
        }
        _M_insert_adjust(node);
        result._M_node = node;
        return false;
    }

#if __cplusplus >= 201103L
    template<typename _InputType, typename _CompareType>
    bool _M_find_and_insert(_InputType &&input, iterator &result)
    {
        if(empty())
        { 
            result._M_node = _M_tree.append_root(input);
            return false; 
        }

        node_type *node = _M_tree.root();
        while(true)
        {
            int res = _CompareType()(std::forward<value_type>(input), *node->value());
            if(res == 0)
            { 
                result._M_node = node;
                return true; 
            }
            if(res > 0)
            {
                if(null == node->left_child())
                {
                    node = _M_tree.append_left(node, std::move(input));
                    break;
                }
                node = node->left_child();
            }
            else
            {
                if(null == node->right_child())
                {
                    node = _M_tree.append_right(node, std::move(input));
                    break;
                }
                node = node->right_child();
            }
        }
        _M_insert_adjust(node);
        result._M_node = node;
        return false;
    }
#endif

private:
    tree_type _M_tree;
};

#endif // __RBTREE_H__

测试代码:

#include <iostream>
#include <list>
#include "rbtree.h"

#ifdef _WIN32
#include <windows.h>
#endif

#if __cplusplus < 201103L
#include <sstream>
#endif

#define MAX_NUMBER_BIT 5

/* Converts v to decimal and centres it in a field MAX_NUMBER_BIT characters
 * wide, so every printed tree slot has the same width. When the padding is
 * odd, the extra space goes on the right. Numbers whose text is already
 * MAX_NUMBER_BIT or more characters long are returned unpadded.
 */
static std::string get_string(int v)
{
#if __cplusplus < 201103L
	std::stringstream ss;
	ss << v;
	std::string digits = ss.str();
#else
	std::string digits = std::to_string(v);
#endif
	const std::size_t width = MAX_NUMBER_BIT;
	// Compare before subtracting: width - size() is unsigned and would wrap
	// around to a huge value for long numbers.
	const std::size_t pad = digits.size() < width ? width - digits.size() : 0;
	const std::size_t left = pad / 2;
	return std::string(left, ' ') + digits + std::string(pad - left, ' ');
}
// Switches the console foreground colour to red (used for red tree nodes).
// On platforms other than Windows and Linux this does nothing.
static void set_red()
{
#if defined(_WIN32)
	HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
	SetConsoleTextAttribute(console, FOREGROUND_RED);
#elif defined(__linux__)
	static const char kRedEscape[] = "\033[31m";
	std::cout << kRedEscape;
#endif
}

// Colour used for "black" tree nodes and empty placeholders. Blue is used
// rather than black, presumably so the text stays readable on a dark console.
// On platforms other than Windows and Linux this does nothing.
static void set_black()
{
#if defined(_WIN32)
	HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
	SetConsoleTextAttribute(console, FOREGROUND_BLUE);
#elif defined(__linux__)
	static const char kBlueEscape[] = "\033[34m";
	std::cout << kBlueEscape;
#endif
}

// Restores the program's normal text colour (green).
// On platforms other than Windows and Linux this does nothing.
static void set_default()
{
#if defined(_WIN32)
	HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
	SetConsoleTextAttribute(console, FOREGROUND_GREEN);
#elif defined(__linux__)
	static const char kGreenEscape[] = "\033[32m";
	std::cout << kGreenEscape;
#endif
}

// Test payload: wraps an int and is ordered by that int.
struct T
{
	T(int v = 0) : value(v) { }

	// Writes the value, centred by get_string, to std::cout (no newline).
	// const because printing does not modify the object.
	void print() const
	{ std::cout << get_string(value); }

	bool operator<(const T& t) const
	{ return value < t.value; }

	int value;
};

typedef RBTree<T> TestTree;              // red-black tree under test
typedef RBTree<T>::node_type NodeType;   // node type of that tree

static void print_tree(const BinaryTree<T> &tree)
{
	std::list<const NodeType*> s;
	s.push_back(tree.root());
	bool break_flag = false;
	while(!break_flag)
	{
		break_flag = true;
		std::size_t count = s.size();
		int print_count = 0;
		while(count-- > 0)
		{
			const NodeType *t = s.front();
			s.pop_front();
			if(null == t)
			{ 
				set_black();
				T().print();
				set_default();
				s.push_back(null);
				s.push_back(null); 
			}
			else 
			{ 
				if(t->color() == 0)
				{ set_red(); }
				else 
				{ set_black(); }
				t->value()->print();
				set_default();
				s.push_back(t->left_child());
				s.push_back(t->right_child());
				if(break_flag)
				{ break_flag = null == t->left_child() && null == t->right_child(); }
			}
			if(++print_count % 2 == 0)
			{ std::cout << "|"; }
		}
		std::cout << std::endl;
	}
}
void main_func()
{
	TestTree avl;
	std::cout << "insert: ---->" << avl.insert(T(10))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(30))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(1))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(31))->value << std::endl;
	print_tree(avl.get_tree());
	TestTree::iterator it1 = avl.insert(T(32));
	std::cout << "insert: ---->" << it1->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(33))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(34))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(35))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(36))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(37))->value << std::endl;
	print_tree(avl.get_tree());
	std::cout << "insert: ---->" << avl.insert(T(38))->value << std::endl;

	std::cout << "size: " << avl.size() << std::endl;
	std::cout << "depth: " << avl.depth() << std::endl;
	print_tree(avl.get_tree());

	std::cout << "--------------erase node 32----------------" << std::endl;
	avl.erase(it1);
	print_tree(avl.get_tree());
	std::cout << "--------------erase node 35----------------" << std::endl;
	avl.erase(avl.find(T(35)));
	print_tree(avl.get_tree());

	std::cout << std::endl << "--------------iterator visit-----------" << std::endl;
	for(TestTree::const_iterator it = avl.begin(); it != avl.end(); ++it)
	{
		std::cout << it->value << ' ';
	}
	std::cout << std::endl;
}

// Entry point: switch to the default colour and run the test driver.
int main()
{
	set_default();
	main_func();
#ifdef _WIN32
	// Keep the console window open when the program is launched from Explorer
	// or Visual Studio. "pause" is a cmd.exe built-in and does not exist on
	// other platforms (there it only produces a "not found" error).
	system("pause");
#endif
	return 0;
}

测试结果:

发布了19 篇原创文章 · 获赞 1 · 访问量 2773

猜你喜欢

转载自blog.csdn.net/qq811299838/article/details/104291405