带你手写红黑树(下)实现篇
上一篇 带你手写红黑树(上)原理篇 讲解了红黑树的原理,这是下篇,用 C++ 实现一棵红黑树。
包含内容如下:
- 使用 C++ 模板实现泛型
- 包含二叉搜索树的基本内容
- 递归和非递归实现先序、中序、后续遍历
- 层序遍历
- 递归和迭代实现查找指定键值的节点
- 查找最大键值最大或最小的节点
- 递归和迭代实现获取层数
- 获取树拥有的元素个数
- 判断树是否为空
- 清空树
- 插入指定键值
- 删除指定键值或节点指针的节点
- 判断一棵树是不是红黑树
- 按节点的相对位置和颜色在控制台打印红黑树
泛型的实现:
- 由于工程上,遍历很少是打印数据,所以为了更好的支持泛型,所有的遍历操作均额外提供一个参数:
const std::function<void(const _Ty&)>&
,用于实现更多的遍历操作,该参数可以是任何返回void
,且参数为const _Ty&
的函数符,_Ty
是红黑树键值类型
对于二叉搜索树的基本内容部分,就不再讲解。
关于左右旋,插入删除操作,你可以对比着我的原理篇和代码进行体会。
下面重点讲解一下:
- 如何判断我们写的红黑树是正确的,即如何判断一棵树是红黑树。
- 为了方便检查测试红黑树的状态,通过控制台打印数据,避免麻烦地一步一步调试查看内存。
红黑树正确性判断
很多时候,写完代码往往很难判断其正确性,很难判断正确是不是偶然。
手写红黑树更是如此,很有可能你写的红黑树暗藏很多 bug,导致红黑树其实并不绝对正确,某些时候这是致命的。
在进行讲解红黑树正确性判断前,先回顾一下红黑树的规则:
- 节点是红色或者黑色
- 根节点是黑色
- 每个叶子节点都是黑色的空节点
- 每个红节点的两个子节点都必须是黑色
- 从任意节点到其叶子节点的所有路径都包含相同数目的黑色节点
在我们手写红黑树中,需要检查的有:规则 2,4,5。
如何检查呢?
要检测 2:这很简单,不用再说。
要检测 4:只需要遍历每一个节点,判断节点和孩子是否同时为红。
要检测 5:我们只需要将叶子节点(这里指没有孩子的节点)保存起来,然后根据 parent 指针依次回溯到根节点,判断黑色节点数是否相同。
这就够了吗?下面看几种仅上面方法无法检测的情况:
-
红节点 p 只有一个孩子(当然是黑色),可见 p 违背规则 5
-
黑节点 p 仅有一个孩子且该孩子是黑色,可见 p 违背规则 5
-
黑节点 p 仅有一个孩子且该孩子是红色,但该红孩子有两个黑孩子,可见 p 违背规则 5
代码中定义接口 bool isRBTree()
用于判断红黑树是否正确,你可以结合上面的讲解和代码进行分析。
控制台打印红黑树
为了格式美观,数据均为两位数 10-99,空节点为 -1,如此,一个数据占据高宽均为两个英文字符占据的大小,由于控制台是黑色的(虽然可以改变),所以打印黑色节点时用绿色代替,void levelTest() const
是输出测试接口。
代码中使用的 \33[xxxm
转义,有兴趣的可以自行查询。
笔者使用环境:Win10 + VS2019,若读者使用环境不同,笔者不敢保证 \33[xxxm
转义有效。
看看效果:
最后贴出源码。
红黑树 C++ 源码
#include <iostream>
#include <ctime>
#include <stack>
#include <queue>
#include <functional>
// 不允许键值相同的节点同时存在
template<typename _Ty>
class RBTree
{
private:
enum Color :bool { RED = 0, BLACK };
struct Node
{
_Ty key;
Node* left = nullptr;
Node* right = nullptr;
Node* parent = nullptr;
Color color = RED; // 新插入节点为红色
Node(const _Ty& _key) :key(_key) {}
};
public:
RBTree() = default;
~RBTree() { clear(); }
// 递归先序遍历,中序遍历,后续遍历
void preorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
void inorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
void postorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
// 一个参数版本
void preorderTraversal(const std::function<void(const _Ty&)>& _func) const { preorderTraversal(root, _func); }
void inorderTraversal(const std::function<void(const _Ty&)>& _func) const { inorderTraversal(root, _func); }
void postorderTraversal(const std::function<void(const _Ty&)>& _func) const { postorderTraversal(root, _func); }
// 默认打印数据
void preorderTraversal() const { preorderTraversal(root, drawData); }
void inorderTraversal() const { inorderTraversal(root, drawData); }
void postorderTraversal() const { postorderTraversal(root, drawData); }
// 迭代先序遍历,中序遍历,后续遍历
void iterativePreorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
void iterativeInorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
void iterativePostorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
// 一个参数版本
void iterativePreorderTraversal(const std::function<void(const _Ty&)>& _func) const { iterativePreorderTraversal(root, _func); }
void iterativeInorderTraversal(const std::function<void(const _Ty&)>& _func) const { iterativeInorderTraversal(root, _func); }
void iterativePostorderTraversal(const std::function<void(const _Ty&)>& _func) const { iterativePostorderTraversal(root, _func); }
// 默认打印数据
void iterativePreorderTraversal() const { iterativePreorderTraversal(root, drawData); }
void iterativeInorderTraversal() const { iterativeInorderTraversal(root, drawData); }
void iterativePostorderTraversal() const { iterativePostorderTraversal(root, drawData); }
// 层序遍历
void levelTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const;
// 一个参数版本
void levelTraversal(const std::function<void(const _Ty&)>& _func) const { levelTraversal(root, _func); }
// 默认打印数据
void levelTraversal() const { levelTraversal(root, drawData); }
// 查找
Node* find(const _Ty& _key, Node* _root) const;
Node* iterativeFind(const _Ty& _key, Node* _root) const;
// 一个参数版本
Node* find(const _Ty& _key) const { return find(_key, root); }
Node* iterativeFind(const _Ty& _key) const { return iterativeFind(_key, root); }
// 查找最大或最小值
Node* findMaximum(Node* _root) const;
Node* findMinimum(Node* _root) const;
// 无参数版本
Node* findMaximum() const { return findMaximum(root); }
Node* findMinimum() const { return findMinimum(root); }
// 获取层数
size_t getLevel(const Node* _root) const;
size_t iterativeGetLevel(const Node* _root) const;
// 无参数版本
size_t getLevel() const { return getLevel(root); }
size_t iterativeGetLevel() const { return iterativeGetLevel(root); }
// 获取元素个数
size_t size() const { return size_n; }
// 是否为空
bool empty() const { return size_n == 0; }
// 插入数据(返回 pair,second 指示是否插入成功,为 true 代表成功返回其节点指针,为 false 代表失败返回 nullptr)
// 返回类型可简写用 auto 代替
std::pair<Node*, bool> insert(const _Ty& _key);
// 删除
bool erase(const _Ty& _key);
void erase(Node* _node);
// 清空树
void clear() noexcept { clear(root); root = nullptr; size_n = 0; }
// 打印数据
static void drawData(const _Ty& val) { std::cout << val << " "; };
// 测试
void levelTest() const
{
if (root == nullptr) return;
std::queue<const Node*> nodePointers;
nodePointers.push(root);
int level = getLevel();
char buff[256] = { 0 };
while (!nodePointers.empty() && level--)
{
const Node* levelBegin = nodePointers.front();
const Node* levelEnd = nodePointers.back();
const Node* cur = levelBegin;
bool first = true;
while (true)
{
cur->left != nullptr ? nodePointers.push(cur->left) : nodePointers.push(new Node(-1));
cur->right != nullptr ? nodePointers.push(cur->right) : nodePointers.push(new Node(-1));
if (first)
{
sprintf_s(buff, "\33[%dC", static_cast<int>(pow(2, level)) * 2);
std::cout << buff;
first = false;
}
else
{
sprintf_s(buff, "\33[%dC", static_cast<int>(pow(2, level + 1)) * 2);
std::cout << buff << "\b\b";
}
if (cur == levelEnd) break;
std::cout << (cur->key == -1 ? "\33[1;30;40m" : (cur->color == RED ? "\33[1;31;40m" : "\33[1;32;40m")) << cur->key << "\33[0m";
if (nodePointers.front()->key == -1) delete nodePointers.front();
nodePointers.pop();
cur = nodePointers.front();
}
std::cout << (cur->key == -1 ? "\33[1;30;40m" : (cur->color == RED ? "\33[1;31;40m" : "\33[1;32;40m")) << cur->key << "\33[0m" << std::endl << std::endl;
if (nodePointers.front()->key == -1) delete nodePointers.front();
nodePointers.pop();
}
}
// 检测 RB 树是否正确
bool isRBTree()
{
if (root == nullptr) return true;
else if (root->color == RED) return false;
bool isRight = true;
std::stack<const Node*> nodePointers;
std::stack<const Node*> nodePointers2;
const Node* cur = root;
while (cur != nullptr || !nodePointers.empty())
{
if (cur->color == RED)
{
if (cur->left != nullptr && cur->left->color == RED) return false;
else if (cur->right != nullptr && cur->right->color == RED) return false;
else if (cur->left != nullptr && cur->right == nullptr
|| cur->left == nullptr && cur->right != nullptr) return false;
}
else if (cur->left != nullptr && cur->left->color == BLACK && cur->right == nullptr) return false;
else if (cur->right != nullptr && cur->right->color == BLACK && cur->left == nullptr) return false;
else if (cur->right == nullptr && cur->left != nullptr && cur->left->color == RED && cur->left->left != nullptr && cur->left->right != nullptr) return false;
else if (cur->left == nullptr && cur->right != nullptr && cur->right->color == RED && cur->right->left != nullptr && cur->right->right != nullptr) return false;
if (cur->left == nullptr && cur->right == nullptr)
nodePointers2.push(cur);
nodePointers.push(cur);
cur = cur->left;
while (cur == nullptr && !nodePointers.empty())
{
cur = nodePointers.top()->right;
nodePointers.pop();
}
}
size_t blackNodes = 0;
while (!nodePointers2.empty())
{
cur = nodePointers2.top();
nodePointers2.pop();
size_t tempNums = 0;
while (cur != root)
{
if (cur->color == BLACK) ++tempNums;
cur = cur->parent;
}
if (blackNodes == 0) blackNodes = tempNums;
else if (tempNums != blackNodes) return false;
}
return isRight;
}
private:
// _node 指失衡节点
Node* rightRotate(Node* _node); // 左左_右旋
Node* leftRotate(Node* _node); // 右右_左旋
// 清空树递归调用
void clear(Node* _root) noexcept;
private:
Node* root = nullptr;
size_t size_n = 0;
};
template<typename _Ty>
void RBTree<_Ty>::preorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
_func(_root->key);
preorderTraversal(_root->left, _func);
preorderTraversal(_root->right, _func);
}
template<typename _Ty>
void RBTree<_Ty>::inorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
inorderTraversal(_root->left, _func);
_func(_root->key);
inorderTraversal(_root->right, _func);
}
template<typename _Ty>
void RBTree<_Ty>::postorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
postorderTraversal(_root->left, _func);
postorderTraversal(_root->right, _func);
_func(_root->key);
}
template<typename _Ty>
void RBTree<_Ty>::iterativePreorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
std::stack<const Node*> nodePointers;
const Node* cur = _root;
while (cur != nullptr || !nodePointers.empty())
{
_func(cur->key);
nodePointers.push(cur);
cur = cur->left;
while (cur == nullptr && !nodePointers.empty())
{
cur = nodePointers.top()->right;
nodePointers.pop();
}
}
}
template<typename _Ty>
void RBTree<_Ty>::iterativeInorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
std::stack<const Node*> nodePointers;
const Node* cur = _root;
while (cur != nullptr || !nodePointers.empty())
{
if (cur->left != nullptr)
{
nodePointers.push(cur);
cur = cur->left;
}
else
{
_func(cur->key);
cur = cur->right;
while (cur == nullptr && !nodePointers.empty())
{
cur = nodePointers.top();
nodePointers.pop();
_func(cur->key);
cur = cur->right;
}
}
}
}
template<typename _Ty>
void RBTree<_Ty>::iterativePostorderTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
std::stack<const Node*> nodePointers;
nodePointers.push(_root);
const Node* last = nullptr;
const Node* cur = nullptr;
while (!nodePointers.empty())
{
cur = nodePointers.top();
if ((cur->left == nullptr && cur->right == nullptr) || last != nullptr && (last == cur->left || last == cur->right))
{
_func(cur->key);
nodePointers.pop();
last = cur;
}
else
{
if (cur->right != nullptr) nodePointers.push(cur->right);
if (cur->left != nullptr) nodePointers.push(cur->left);
}
}
}
template<typename _Ty>
void RBTree<_Ty>::levelTraversal(const Node* _root, const std::function<void(const _Ty&)>& _func) const
{
if (_root == nullptr) return;
std::queue<const Node*> nodePointers;
nodePointers.push(_root);
const Node* cur = nullptr;
while (!nodePointers.empty())
{
cur = nodePointers.front();
_func(cur->key);
if (cur->left != nullptr) nodePointers.push(cur->left);
if (cur->right != nullptr) nodePointers.push(cur->right);
nodePointers.pop();
}
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::find(const _Ty& _key, Node* _root) const
{
if (_root == nullptr || _key == _root->key) return _root;
if (_key < _root->key) return find(_key, _root->left);
else return find(_key, _root->right);
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::iterativeFind(const _Ty& _key, Node* _root) const
{
while (_root != nullptr && _root->key != _key)
_key < _root->key ? _root = _root->left : _root = _root->right;
return _root;
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::findMaximum(Node* _root) const
{
if (_root == nullptr) return _root;
while (_root->right != nullptr)
_root = _root->right;
return _root;
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::findMinimum(Node* _root) const
{
if (_root == nullptr) return _root;
while (_root->left != nullptr)
_root = _root->left;
return _root;
}
template<typename _Ty>
size_t RBTree<_Ty>::getLevel(const Node* _root) const
{
if (_root == nullptr) return 0;
if (_root->left == nullptr && _root->right == nullptr) return 1;
return 1 + getLevel(_root->left) > 1 + getLevel(_root->right) ? 1 + getLevel(_root->left) : 1 + getLevel(_root->right);
}
template<typename _Ty>
size_t RBTree<_Ty>::iterativeGetLevel(const Node* _root) const
{
if (_root == nullptr) return 0;
size_t level = 0;
std::queue<const Node*> nodePointers;
nodePointers.push(_root);
while (!nodePointers.empty())
{
const Node* levelBegin = nodePointers.front();
const Node* levelEnd = nodePointers.back();
const Node* cur = levelBegin;
while (true)
{
if (cur->left != nullptr) nodePointers.push(cur->left);
if (cur->right != nullptr) nodePointers.push(cur->right);
if (cur == levelEnd) break;
nodePointers.pop();
cur = nodePointers.front();
}
nodePointers.pop();
++level;
}
return level;
}
template<typename _Ty>
std::pair<typename RBTree<_Ty>::Node*, bool> RBTree<_Ty>::insert(const _Ty& _key)
{
// 以二叉查找树的方式插入
std::pair<Node*, bool> ret(nullptr, true);
if (root == nullptr)
{
++size_n;
root = new Node(_key);
root->color = BLACK;
ret.first = root;
return ret;
}
Node* cur = root;
while (true)
{
if (_key == cur->key)
{
ret.second = false;
return ret;
}
else if (_key < cur->key && cur->left != nullptr) cur = cur->left;
else if (_key > cur->key && cur->right != nullptr) cur = cur->right;
else break;
}
++size_n;
Node* newNode = new Node(_key);
_key < cur->key ? cur->left = newNode : cur->right = newNode;
newNode->parent = cur;
ret.first = newNode;
cur = newNode;
if (cur->parent->color == BLACK) return ret;
// 调整
while (cur->parent != nullptr && cur->parent->color == RED && cur->parent != root)
{
bool parIsRight = cur->parent == cur->parent->parent->right ? true : false;
auto uncle = parIsRight ? cur->parent->parent->left : cur->parent->parent->right;
if (uncle != nullptr && uncle->color == RED)
{
cur->parent->color = BLACK;
uncle->color = BLACK;
cur = cur->parent->parent;
cur->color = RED;
}
else
{
if (parIsRight)
{
if (cur == cur->parent->left) cur = rightRotate(cur->parent)->right;
cur = leftRotate(cur->parent->parent);
cur->left->color = RED;
cur->color = BLACK;
}
else
{
if (cur == cur->parent->right) cur = leftRotate(cur->parent)->left;
cur = rightRotate(cur->parent->parent);
cur->right->color = RED;
cur->color = BLACK;
}
return ret;
}
}
root->color = BLACK;
return ret;
}
template<typename _Ty>
bool RBTree<_Ty>::erase(const _Ty& _key)
{
bool succeed = false;
Node* del = find(_key);
if (del != nullptr)
{
erase(del);
succeed = true;
}
return succeed;
}
template<typename _Ty>
void RBTree<_Ty>::erase(Node* _node)
{
if (_node == nullptr) return;
--size_n;
if (_node->color == RED)
{
if (_node->left == nullptr && _node->right == nullptr)
{
_node == _node->parent->left ? _node->parent->left = nullptr : _node->parent->right = nullptr;
delete _node;
}
else
{
auto del = findMinimum(_node->right);
_node->key = del->key;
++size_n;
erase(del);
}
}
else
{
if (_node->right == nullptr && _node->left == nullptr)
{
Node* par = _node->parent;
if (par == nullptr)
{
delete root;
root = nullptr;
return;
}
bool isLeft = _node == par->left ? true : false;
isLeft ? par->left = nullptr : par->right = nullptr;
delete _node;
while (true)
{
Node* bro = isLeft ? par->right : par->left;
if (bro->color == RED)
{
isLeft ? leftRotate(par) : rightRotate(par);
par->color = RED;
bro->color = BLACK;
if (bro->parent == nullptr) root = bro;
bro = isLeft ? par->right : par->left;
}
if (bro->color == BLACK)
{
if ((isLeft ? bro->left : bro->right) != nullptr && (isLeft ? bro->left : bro->right)->color == RED)
{
Node* temp = isLeft ? bro->left : bro->right;
isLeft ? rightRotate(bro) : leftRotate(bro);
isLeft ? leftRotate(par) : rightRotate(par);
temp->color = par->color;
par->color = BLACK;
return;
}
else if ((isLeft ? bro->right : bro->left) != nullptr && (isLeft ? bro->right : bro->left)->color == RED)
{
Node* temp = isLeft ? bro->right : bro->left;
isLeft ? leftRotate(par) : rightRotate(par);
bro->color = par->color;
par->color = BLACK;
temp->color = BLACK;
return;
}
else if (par->color == RED)
{
par->color = BLACK;
bro->color = RED;
return;
}
else
{
bro->color = RED;
if (par->parent != nullptr) isLeft = (par == par->parent->left ? true : false);
else return;
par = par->parent;
}
}
}
}
else if (!(_node->right != nullptr && _node->left != nullptr))
{
Node* del;
_node->right != nullptr ? del = _node->right : del = _node->left;
_node->right != nullptr ? _node->right = nullptr : _node->left = nullptr;
_node->key = del->key;
delete del;
}
else
{
auto del = findMinimum(_node->right);
_node->key = del->key;
++size_n;
erase(del);
}
}
}
template<typename _Ty>
void RBTree<_Ty>::clear(Node* _root) noexcept
{
if (_root == nullptr) return;
clear(_root->left);
clear(_root->right);
delete _root;
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::rightRotate(Node* _node)
{
Node* leftNode = _node->left;
if (_node->parent != nullptr)
_node->parent->left == _node ? _node->parent->left = leftNode : _node->parent->right = leftNode;
leftNode->parent = _node->parent;
_node->left = leftNode->right;
if (leftNode->right != nullptr) leftNode->right->parent = _node;
_node->parent = leftNode;
leftNode->right = _node;
if (leftNode->parent == nullptr) root = leftNode;
return leftNode;
}
template<typename _Ty>
typename RBTree<_Ty>::Node* RBTree<_Ty>::leftRotate(Node* _node)
{
Node* rightNode = _node->right;
if (_node->parent != nullptr)
_node->parent->right == _node ? _node->parent->right = rightNode : _node->parent->left = rightNode;
rightNode->parent = _node->parent;
_node->right = rightNode->left;
if (rightNode->left != nullptr) rightNode->left->parent = _node;
_node->parent = rightNode;
rightNode->left = _node;
if (rightNode->parent == nullptr) root = rightNode;
return rightNode;
}
int main()
{
srand((unsigned)time(nullptr));
std::cout.setf(std::ios_base::boolalpha);
// 随机生成元素用于插入
constexpr int LEN = 20;
int arr[LEN];
for (auto& val : arr) val = rand() % 90 + 10;
std::cout << "待插入数据(将跳过重复数据):" << std::endl << std::endl;
for (auto& val : arr) std::cout << val << ",";
std::cout << std::endl << std::endl << std::endl;
// 创建空树
RBTree<int> avlt;
// 插入
for (auto& val : arr)
{
if (avlt.find(val) != nullptr) continue;
avlt.insert(val);
std::cout << "Insert:" << val << std::endl << std::endl;
avlt.levelTest();
bool isRight = avlt.isRBTree();
std::cout << "Level = " << avlt.getLevel() << ", Size = " << avlt.size() << ", IsRBTree:" << isRight << std::endl << std::endl << std::endl;
if (!isRight) return -1;
}
// 删除
for (auto& val : arr)
{
if (avlt.find(val) == nullptr) continue;
avlt.erase(val);
std::cout << "Erase:" << val << std::endl << std::endl;
avlt.levelTest();
bool isRight = avlt.isRBTree();
std::cout << "Level = " << avlt.getLevel() << ", Size = " << avlt.size() << ", IsRBTree:" << isRight << std::endl << std::endl << std::endl;
if (!isRight) return -2;
}
return 0;
}