Day-10-高级数据结构(Trie树、并查集、线段树) LeetCode-208, 211, 547, 307


//Trie 树节点表示(字典树)
#include <stdio.h>
# define TRIE_MAX_CHAR_NUM 26
struct TrieNode {
       TrieNode* child[TRIE_MAX_CHAR_NUM];  // Trie 树的指针数组
       bool is_end;
       TrieNode() : is_end(false) {
              for (int i = 0; i < TRIE_MAX_CHAR_NUM; i++) {
                     child[i] = 0;
              }
       }
};
// Prints the sub-trie below `node` in pre-order, children in alphabetical order.
// Each edge is printed as its letter, prefixed by "---" once per `layer`, with
// "(end)" appended when the child it leads to terminates a word.
void preorder_trie(TrieNode* node, int layer) {
       for (int letter = 0; letter < TRIE_MAX_CHAR_NUM; ++letter) {
              TrieNode* next = node->child[letter];
              if (!next) {
                     continue;
              }
              for (int depth = 0; depth < layer; ++depth) {
                     printf("---");
              }
              printf("%c", letter + 'a');
              if (next->is_end) {
                     printf("(end)");
              }
              printf("\n");
              preorder_trie(next, layer + 1);
       }
}
// Demo: hand-builds a trie holding "abc", "abcd", "abd", "b", "bcd" and "efg"
// out of stack-allocated nodes, then prints it in pre-order.
int main() {
       TrieNode root;
       TrieNode n1, n2, n3, n4, n5, n6, n7, n8, n9, n10, n11;

       // Branch "a": a -> b -> {c, d}, and c -> d.
       root.child['a' - 'a'] = &n1;
       n1.child['b' - 'a'] = &n4;
       n4.child['c' - 'a'] = &n7;
       n4.child['d' - 'a'] = &n8;
       n7.child['d' - 'a'] = &n11;
       n7.is_end = true;   // "abc"
       n8.is_end = true;   // "abd"
       n11.is_end = true;  // "abcd"

       // Branch "b": b -> c -> d.
       root.child['b' - 'a'] = &n2;
       n2.child['c' - 'a'] = &n5;
       n5.child['d' - 'a'] = &n9;
       n2.is_end = true;   // "b"
       n9.is_end = true;   // "bcd"

       // Branch "e": e -> f -> g.
       root.child['e' - 'a'] = &n3;
       n3.child['f' - 'a'] = &n6;
       n6.child['g' - 'a'] = &n10;
       n10.is_end = true;  // "efg"

       preorder_trie(&root, 0);  // pre-order dump of the whole trie
       return 0;
}


// Trie 树获取全部单词
#include <stdio.h>
#include <string>
#include <vector>
#define TIRE_MAX_CHAR_NUM 26
struct TrieNode {
       TrieNode* child[TIRE_MAX_CHAR_NUM];
       bool is_end;
       TrieNode() : is_end(false) {
              for (int i = 0; i < TIRE_MAX_CHAR_NUM; i++) {
                     child[i] = 0;
              }
       }
};
// Depth-first collection of every word stored below `node`.
// `word` holds the letters on the path from the root to `node`; it is taken by value,
// so the caller's string is never modified. Each complete word found is appended to
// `word_list`; because letters are visited in order and a word is recorded before its
// extensions, the list comes out in lexicographic order.
void get_all_word_from_trie(TrieNode* node, std::string word, std::vector<std::string>& word_list) {
       for (int letter = 0; letter < TIRE_MAX_CHAR_NUM; ++letter) {
              TrieNode* next = node->child[letter];
              if (!next) {
                     continue;
              }
              word += static_cast<char>('a' + letter);  // extend the current path
              if (next->is_end) {
                     word_list.push_back(word);
              }
              get_all_word_from_trie(next, word, word_list);
              word.pop_back();  // backtrack: drop the letter just appended
       }
}
int main() {
       TrieNode root;
       TrieNode n1;
       TrieNode n2;
       TrieNode n3;
       root.child['a' - 'a'] = &n1;
       root.child['b' - 'a'] = &n2;
       root.child['e' - 'a'] = &n3;
       n2.is_end = true;
       TrieNode n4;
       TrieNode n5;
       TrieNode n6;
       n1.child['b' - 'a'] = &n4;
       n2.child['c' - 'a'] = &n5;
       n3.child['f' - 'a'] = &n6;
       TrieNode n7;
       TrieNode n8;
       TrieNode n9;
       TrieNode n10;
       n4.child['c' - 'a'] = &n7;
       n4.child['d' - 'a'] = &n8;
       n5.child['d' - 'a'] = &n9;
       n6.child['g' - 'a'] = &n10;
       n7.is_end = true;
       n8.is_end = true;
       n9.is_end = true;
       n10.is_end = true;
       TrieNode n11;
       n7.child['d' - 'a'] = &n11;
       n11.is_end = true;
       std::vector<std::string> word_list;
       std::string word;
       get_all_word_from_trie(&root, word, word_list);  // root 要加' & '
       for (int i = 0; i < word_list.size(); i++) {
              printf("%s\n", word_list[i].c_str());
       }
       return 0;
}


// Trie 树的整体功能 1. 将word插入trie 2. 搜索trie中是否存在word 3. 确认trie中是否有前缀为prefix的单词
#include <stdio.h>
#include <vector>
#include <string>
#define TRIE_MAX_CHAR_NUM 26
struct TrieNode {
       TrieNode* child[TRIE_MAX_CHAR_NUM];
       bool is_end;
       TrieNode() : is_end(false) {
              for (int i = 0; i < TRIE_MAX_CHAR_NUM; i++) {
                     child[i] = 0;
              }
       }
};
class TrieTree {
public:
       TrieTree() {}
       ~TrieTree() {
              for (int i = 0; i < _node_vec.size(); i++) {
                     delete _node_vec[i];
              }
       }
       void insert(const char* word) {  // Trie 插入 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {  // word 是地址,*word是地址里的内容
                           ptr->child[pos] = new_node();
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              ptr->is_end = true;
       }
       bool search(const char* word) {  // Trie 树搜索 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              return ptr->is_end;
       }
       bool startsWith(const char* prefix) {  // Trie 树前缀查询
              TrieNode* ptr = &_root;
              while (*prefix) {
                     int pos = *prefix - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     prefix++;
              }
              return true;
       }
       TrieNode* root() {
              return &_root;
       }
private:
       TrieNode* new_node() {
              TrieNode* node = new TrieNode();
              _node_vec.push_back(node);
              return node;
       }
       std::vector<TrieNode*> _node_vec;
       TrieNode _root;
};
// Prints the sub-trie below `node` in pre-order, children in alphabetical order.
// Each edge is shown as its letter, indented with "---" per `layer`, and marked
// "(end)" when the child it leads to terminates a word.
void preorder_trie(TrieNode* node, int layer) {
       for (int slot = 0; slot < TRIE_MAX_CHAR_NUM; ++slot) {
              TrieNode* sub = node->child[slot];
              if (sub == nullptr) {
                     continue;
              }
              for (int indent = 0; indent < layer; ++indent) {
                     printf("---");
              }
              printf("%c", slot + 'a');
              if (sub->is_end) {
                     printf("(end)");
              }
              printf("\n");
              preorder_trie(sub, layer + 1);
       }
}
// Depth-first collection of every word stored below `node`.
// `word` is the prefix spelled by the path from the root to `node` (taken by value,
// so the caller's copy is untouched). Complete words are appended to `word_list`
// in lexicographic order.
void get_all_word_from_trie(TrieNode* node, std::string word, std::vector<std::string>& word_list) {
       for (int slot = 0; slot < TRIE_MAX_CHAR_NUM; ++slot) {
              TrieNode* sub = node->child[slot];
              if (sub == nullptr) {
                     continue;
              }
              word += static_cast<char>('a' + slot);  // step down one edge
              if (sub->is_end) {
                     word_list.push_back(word);
              }
              get_all_word_from_trie(sub, word, word_list);
              word.pop_back();  // step back up before trying the next letter
       }
}
// Demo: builds a trie from six words, prints its pre-order layout and the list of
// stored words, then runs a few search() and startsWith() queries.
int main() {
       TrieTree trie_tree;
       const char* const words[] = {"abcd", "abc", "abd", "b", "bcd", "efg"};
       for (const char* w : words) {
              trie_tree.insert(w);
       }

       printf("preorder_trie:\n");
       preorder_trie(trie_tree.root(), 0);
       printf("\n");

       printf("All words:\n");
       std::vector<std::string> word_list;
       get_all_word_from_trie(trie_tree.root(), std::string(), word_list);
       for (const std::string& w : word_list) {
              printf("%s\n", w.c_str());
       }
       printf("\n");

       // Exact-word membership.
       printf("search:\n");
       printf("abc: %d\n", trie_tree.search("abc"));
       printf("abcd: %d\n", trie_tree.search("abcd"));
       printf("bc: %d\n", trie_tree.search("bc"));
       printf("b: %d\n", trie_tree.search("b"));
       printf("\n");

       // Prefix membership.
       printf("ab: %d\n", trie_tree.startsWith("ab"));
       printf("abc: %d\n", trie_tree.startsWith("abc"));
       printf("bc: %d\n", trie_tree.startsWith("bc"));
       printf("fg: %d\n", trie_tree.startsWith("fg"));
       return 0;
}

例一:LeetCode208

/**
 Implement a trie with insert, search, and startsWith methods.
 */
#include <stdio.h>
#include <vector>
#include <string>
#define TRIE_MAX_CHAR_NUM 26
struct TrieNode {
       TrieNode* child[TRIE_MAX_CHAR_NUM];
       bool is_end;
       TrieNode() : is_end(false) {
              for (int i = 0; i < TRIE_MAX_CHAR_NUM; i++) {
                     child[i] = 0;
              }
       }
};
class TrieTree {
public:
       TrieTree() {}
       ~TrieTree() {
              for (int i = 0; i < _node_vec.size(); i++) {
                     delete _node_vec[i];
              }
       }
       // 传入 const char * 最快,代表了string的首地址,或者 string &word也快
       void insert(const char* word) {  // Trie 插入 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {  // word 是地址,*word是地址里的内容
                           ptr->child[pos] = new_node();
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              ptr->is_end = true;
       }
       bool search(const char* word) {  // Trie 树搜索 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              return ptr->is_end;
       }
       bool startsWith(const char* prefix) {  // Trie 树前缀查询
              TrieNode* ptr = &_root;
              while (*prefix) {
                     int pos = *prefix - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     prefix++;
              }
              return true;
       }
       TrieNode* root() {
              return &_root;
       }
private:
       TrieNode* new_node() {
              TrieNode* node = new TrieNode();
              _node_vec.push_back(node);
              return node;
       }
       std::vector<TrieNode*> _node_vec;
       TrieNode _root;
};
// LeetCode 208 adapter: exposes the std::string interface the judge calls and
// forwards to TrieTree, which walks NUL-terminated C strings.
class Trie {
public:
       Trie() {}
       // Strings are taken by const reference so no copy is made per call; callers
       // passing std::string objects or string literals compile unchanged.
       void insert(const std::string& word) {
              _trie_tree.insert(word.c_str());
       }
       // True iff `word` was inserted as a complete word.
       bool search(const std::string& word) {
              return _trie_tree.search(word.c_str());
       }
       // True iff some inserted word begins with `prefix`.
       bool startsWith(const std::string& prefix) {
              return _trie_tree.startsWith(prefix.c_str());
       }
private:
       TrieTree _trie_tree;
};
// Demo for LeetCode 208: after inserting "abcde", prints 1 (exact word), 1 (prefix
// "abc"), 0 ("abcdef" is longer than any word) and 1 (prefix equal to a word).
int main() {
       Trie trie;
       trie.insert("abcde");
       const bool answers[] = {
              trie.search("abcde"),
              trie.startsWith("abc"),
              trie.startsWith("abcdef"),
              trie.startsWith("abcde"),
       };
       for (bool answer : answers) {
              printf("%d\n", answer);
       }
       return 0;
}

例二:LeetCode211



/**
 Design a data structure that supports the following two operations:
void addWord(word)
bool search(word)
search(word) can search a literal word or a regular expression string
containing only letters a-z or '.'. A '.' means it can represent any one letter.
 */
#include <stdio.h>
#include <vector>
#include <string>
#define TRIE_MAX_CHAR_NUM 26
struct TrieNode {
       TrieNode* child[TRIE_MAX_CHAR_NUM];
       bool is_end;
       TrieNode() : is_end(false) {
              for (int i = 0; i < TRIE_MAX_CHAR_NUM; i++) {
                     child[i] = 0;
              }
       }
};
class TrieTree {
public:
       TrieTree() {}
       ~TrieTree() {
              for (int i = 0; i < _node_vec.size(); i++) {
                     delete _node_vec[i];
              }
       }
       void insert(const char* word) {  // Trie 插入 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {  // word 是地址,*word是地址里的内容
                           ptr->child[pos] = new_node();
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              ptr->is_end = true;
       }
       bool search(const char* word) {  // Trie 树搜索 word
              TrieNode* ptr = &_root;
              while (*word) {
                     int pos = *word - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     word++;
              }
              return ptr->is_end;
       }
       bool search_trie(TrieNode* node, const char* word) {
              if (*word == '\0') {
                     if (node->is_end) {
                           return true;
                     }
                     return false;
              }
              if (*word == '.') {
                     for (int i = 0; i < TRIE_MAX_CHAR_NUM; i++) {
                           if (node->child[i] && search_trie(node->child[i], word + 1)) {
                                  return true;
                           }
                     }
              }
              else {
                     int pos = *word - 'a';
                     if (node->child[pos] && search_trie(node->child[pos], word + 1)) {
                           return true;
                     }
              }
              return false;
       }
       bool startsWith(const char* prefix) {  // Trie 树前缀查询
              TrieNode* ptr = &_root;
              while (*prefix) {
                     int pos = *prefix - 'a';
                     if (!ptr->child[pos]) {
                           return false;
                     }
                     ptr = ptr->child[pos];
                     prefix++;
              }
              return true;
       }
       TrieNode* root() {
              return &_root;
       }
private:
       TrieNode* new_node() {
              TrieNode* node = new TrieNode();
              _node_vec.push_back(node);
              return node;
       }
       std::vector<TrieNode*> _node_vec;
       TrieNode _root;
};
class WordDictionary {
public:
       WordDictionary() {}
       void addWord(std::string word) {
              _tire_tree.insert(word.c_str());
       }
       bool search(std::string word) {
              return _tire_tree.search_trie(_tire_tree.root(), word.c_str());  // c_str() 返回的是字符串的首字符地址
       }
private:
       TrieTree _tire_tree;
};
// Demo for LeetCode 211: adds three words, then queries one miss ("pad"), one exact
// hit ("bad") and two wildcard patterns. Expected output: 0, 1, 1, 1.
int main() {
       WordDictionary word_dictionary;
       const char* const added[] = {"abc", "bad", "mad"};
       for (const char* w : added) {
              word_dictionary.addWord(w);
       }
       const char* const queries[] = {"pad", "bad", ".ad", "b.."};
       for (const char* q : queries) {
              printf("%d\n", word_dictionary.search(q));
       }
       return 0;
}

例三:LeetCode547


class Solution {
public:
       int findCircleNum(std::vector<std::vector<int>>& M) {
              std::vector<int> visit(M.size(), 0);
              int count = 0;
              for (int i = 0; i < M.size(); i++) {
                     if (visit[i] == 0) {
                           DFS_graph(i, M, visit);
                           count++;
                     }
              }
              return count;
       }
private:
       void DFS_graph(int u, std::vector<std::vector<int>>& graph, std::vector<int>& visit) {
              visit[u] = 1;
              for (int i = 0; i < graph[u].size(); i++) {
                     if (visit[i] == 0 && graph[u][i] == 1) {
                           DFS_graph(i, graph, visit);
                     }
              }
       }
};



#include <stdio.h>
#include <vector>
#include <string>
// // 方法一:深度搜索算法
// class Solution{
// public:
//     int findCircleNum(std::vector<std::vector<int>> &M){
//         std::vector<int> visit(M.size(), 0);
//         int count = 0;
//         for(int i = 0; i < M.size(); i++){
//             if(visit[i] == 0){
//                 DFS_graph(i, M, visit);
//                 count++;
//             }
//         }
//         return count;
//     }
// private:
//     void DFS_graph(int u, std::vector<std::vector<int>> &graph, std::vector<int> &visit){
//         visit[u] = 1;
//         for(int i = 0; i < graph[u].size(); i++) {
//             if (visit[i] == 0 && graph[u][i] == 1) {
//                 DFS_graph(i, graph, visit);
//             }
//         }
//     }
// };
// 数组实现并查集
// Quick-find disjoint set over elements 0..n-1: `_id[x]` stores the label of x's
// set directly, so find() is O(1) while union_() relabels in O(n).
class DisjoinSet {
public:
       DisjoinSet(int n) {
              // Every element starts in its own singleton set, labelled by itself.
              for (int label = 0; label < n; ++label) {
                     _id.push_back(label);
              }
       }
       // Returns the label of the set containing p.
       int find(int p) {
              return _id[p];
       }
       // Merges p's set into q's set by relabelling every member of p's set with q's
       // label (O(n)). Does nothing if both are already in the same set.
       void union_(int p, int q) {
              const int from = find(p);
              const int to = find(q);
              if (from != to) {
                     for (size_t k = 0; k < _id.size(); ++k) {
                            if (_id[k] == from) {
                                   _id[k] = to;
                            }
                     }
              }
       }
       // Prints the element indices on one line and their set labels on the next.
       void print_set() {
              const int total = static_cast<int>(_id.size());
              printf("元素:");
              for (int e = 0; e < total; ++e) {
                     printf("%d ", e);
              }
              printf("\n");
              printf("集合: ");
              for (int e = 0; e < total; ++e) {
                     printf("%d ", _id[e]);
              }
              printf("\n");
       }
private:
       std::vector<int> _id;  // _id[x] = label of the set containing x
};
// Demo of the quick-find disjoint set: 8 singletons, then union(0,5), union(2,4)
// and union(0,4), printing the label table after each step.
int main() {
       DisjoinSet disjoint_set(8);
       disjoint_set.print_set();

       printf("union(0, 5):\n");
       disjoint_set.union_(0, 5);
       disjoint_set.print_set();
       int f0 = disjoint_set.find(0);
       int f5 = disjoint_set.find(5);
       printf("Find(0) = %d, Find(5) = %d\n", f0, f5);
       int f2 = disjoint_set.find(2);
       f5 = disjoint_set.find(5);
       printf("Find(2) = %d, Find(5) = %d\n", f2, f5);

       disjoint_set.union_(2, 4);
       disjoint_set.print_set();
       disjoint_set.union_(0, 4);
       disjoint_set.print_set();
       f2 = disjoint_set.find(2);
       f5 = disjoint_set.find(5);
       printf("Find(2) = %d, Find(5) = %d\n", f2, f5);
       return 0;
}



// 并查集的实现
#include <stdio.h>
#include <vector>
#include <string>
 // // 方法一:深度搜索算法
 // class Solution{
 // public:
 //     int findCircleNum(std::vector<std::vector<int>> &M){
 //         std::vector<int> visit(M.size(), 0);
 //         int count = 0;
 //         for(int i = 0; i < M.size(); i++){
 //             if(visit[i] == 0){
 //                 DFS_graph(i, M, visit);
 //                 count++;
 //             }
 //         }
 //         return count;
 //     }
 // private:
 //     void DFS_graph(int u, std::vector<std::vector<int>> &graph, std::vector<int> &visit){
 //         visit[u] = 1;
 //         for(int i = 0; i < graph[u].size(); i++) {
 //             if (visit[i] == 0 && graph[u][i] == 1) {
 //                 DFS_graph(i, graph, visit);
 //             }
 //         }
 //     }
 // };
 // // 数组实现并查集
 // class DisjoinSet{
 // public:
 //     DisjoinSet(int n){
 //         for(int i = 0; i < n; i++){
 //             _id.push_back(i);
 //         }
 //     }
 //     int find(int p){
 //         return _id[p];
 //     }
 //     void union_(int p, int q){  // 合并 O(n)
 //         int pid = find(p);
 //         int qid = find(q);
 //         if(pid == qid){
 //             return;
 //         }
 //         for(int i = 0; i < _id.size(); i++){
 //             if(_id[i] == pid){
 //                 _id[i] = qid;
 //             }
 //         }
 //     }
 //     void print_set(){
 //         printf("元素:");
 //         for(int i = 0; i < _id.size(); i++){
 //             printf("%d ", i);
 //         }
 //         printf("\n");
 //         printf("集合: ");
 //         for(int i = 0; i < _id.size(); i++){
 //             printf("%d ", _id[i]);
 //         }
 //         printf("\n");
 //     }
 // private:
 //     std::vector<int> _id;
 // };
 // 森林实现并查集
// Disjoint-set forest (union-find) over elements 0..n-1, using union by size and
// path halving. `_count` tracks how many disjoint sets (friend circles) remain.
class DisjoinSet {
public:
       DisjoinSet(int n) : _count(n > 0 ? n : 0) {
              for (int i = 0; i < n; i++) {
                     _id.push_back(i);    // each element starts as the root of its own tree
                     _size.push_back(1);
              }
       }
       // Returns the root of p's tree. Path halving: each node visited on the way up
       // is re-pointed at its grandparent, flattening the tree for later queries.
       int find(int p) {
              while (p != _id[p]) {
                     _id[p] = _id[_id[p]];
                     p = _id[p];
              }
              return p;
       }
       // Merges the sets containing p and q by attaching the root of the smaller tree
       // under the root of the larger one (on a tie, q's root goes under p's root).
       void union_(int p, int q) {
              int i = find(p);
              int j = find(q);
              if (i == j) {
                     return;
              }
              if (_size[i] < _size[j]) {
                     _id[i] = j;
                     _size[j] += _size[i];
              }
              else {
                     _id[j] = i;
                     _size[i] += _size[j];
              }
              _count--;
       }
       // Number of disjoint sets currently present.
       int count() const {
              return _count;
       }
       // Prints the element indices on one line and their parent links on the next.
       void print_set() {
              printf("元素:");
              for (size_t i = 0; i < _id.size(); i++) {
                     printf("%d ", static_cast<int>(i));
              }
              printf("\n");
              printf("集合: ");
              for (size_t i = 0; i < _id.size(); i++) {
                     printf("%d ", _id[i]);
              }
              printf("\n");
       }
private:
       std::vector<int> _id;    // parent link; a root satisfies _id[x] == x
       std::vector<int> _size;  // number of elements in the tree rooted at x (valid for roots)
       int _count;              // number of disjoint sets
};
// Demo of the union-find forest: 8 singletons, then union(0,5), union(2,4) and
// union(0,4), printing the parent table after each step.
int main() {
       DisjoinSet disjoint_set(8);
       disjoint_set.print_set();

       printf("union(0, 5):\n");
       disjoint_set.union_(0, 5);
       disjoint_set.print_set();
       int root0 = disjoint_set.find(0);
       int root5 = disjoint_set.find(5);
       printf("Find(0) = %d, Find(5) = %d\n", root0, root5);
       int root2 = disjoint_set.find(2);
       root5 = disjoint_set.find(5);
       printf("Find(2) = %d, Find(5) = %d\n", root2, root5);

       disjoint_set.union_(2, 4);
       disjoint_set.print_set();
       disjoint_set.union_(0, 4);
       disjoint_set.print_set();
       root2 = disjoint_set.find(2);
       root5 = disjoint_set.find(5);
       printf("Find(2) = %d, Find(5) = %d\n", root2, root5);
       return 0;
}

/**
 There are N students in a class. Some of them are friends, while some are not.
  Their friendship is transitive in nature. For example, if A is a direct friend
   of B, and B is a direct friend of C, then A is an indirect friend of C. And we
   defined a friend circle is a group of students who are direct or indirect friends.
Given an N*N matrix representing the friend relationship between students in the
class. If M[i][j] = 1, then the ith and jth students are direct friends with each
 other, otherwise not. And you have to output the total number of friend circles
 among all the students.
 */
 
#include <stdio.h>
#include <vector>
#include <string>
 
// // 方法一:深度搜索算法
// class Solution{
// public:
//     int findCircleNum(std::vector<std::vector<int>> &M){
//         std::vector<int> visit(M.size(), 0);
//         int count = 0;
//         for(int i = 0; i < M.size(); i++){
//             if(visit[i] == 0){
//                 DFS_graph(i, M, visit);
//                 count++;
//             }
//         }
//         return count;
//     }
 
// private:
//     void DFS_graph(int u, std::vector<std::vector<int>> &graph, std::vector<int> &visit){
//         visit[u] = 1;
//         for(int i = 0; i < graph[u].size(); i++) {
//             if (visit[i] == 0 && graph[u][i] == 1) {
//                 DFS_graph(i, graph, visit);
//             }
//         }
//     }
// };
 
// Union-find forest over elements 0..n-1 with union by size and path halving;
// `_count` is the number of disjoint sets (friend circles) still present.
class DisjoinSet {
public:
    DisjoinSet(int n) : _count(n) {
        for (int v = 0; v < n; ++v) {
            _id.push_back(v);   // every element is initially its own root
            _size.push_back(1);
        }
    }
    // Root of p's tree; each node passed on the way up is re-linked to its
    // grandparent (path halving).
    int find(int p) {
        for (; _id[p] != p; p = _id[p]) {
            _id[p] = _id[_id[p]];
        }
        return p;
    }
    // Joins the sets of p and q: the smaller tree's root is hung under the larger
    // tree's root (ties put q's root under p's root), and the set count drops by one.
    void union_(int p, int q) {
        const int root_p = find(p);
        const int root_q = find(q);
        if (root_p == root_q) {
            return;
        }
        const bool p_is_smaller = _size[root_p] < _size[root_q];
        const int child = p_is_smaller ? root_p : root_q;
        const int parent = p_is_smaller ? root_q : root_p;
        _id[child] = parent;
        _size[parent] += _size[child];
        --_count;
    }
    // Prints the element indices on one line and their parent links on the next.
    void print_set() {
        const int total = static_cast<int>(_id.size());
        printf("元素:");
        for (int e = 0; e < total; ++e) {
            printf("%d ", e);
        }
        printf("\n");
        printf("集合: ");
        for (int e = 0; e < total; ++e) {
            printf("%d ", _id[e]);
        }
        printf("\n");
    }

    // Number of disjoint sets currently present.
    int count(){
        return _count;
    }
private:
    std::vector<int> _id;    // parent link; a root satisfies _id[x] == x
    std::vector<int> _size;  // tree size, meaningful for roots
    int _count;              // number of disjoint sets (friend circles)
};
 
// LeetCode 547 via union-find: every friendship M[row][col] (row < col) merges the
// two students' sets; the number of sets left is the number of friend circles.
// Only the upper triangle is read, so M is assumed to be symmetric.
class Solution{
public:
    int findCircleNum(std::vector<std::vector<int>> &M){
        const int n = M.size();
        DisjoinSet disjoint_set(n);
        for (int row = 0; row < n; ++row) {
            for (int col = row + 1; col < n; ++col) {
                if (M[row][col] != 0) {
                    disjoint_set.union_(row, col);
                }
            }
        }
        return disjoint_set.count();
    }
};
 
// Demo: students 0 and 1 are friends and student 2 is alone, so this prints 2.
int main(){
    int test[][3] = {{1, 1, 0}, {1, 1, 0}, {0, 0, 1}};
    std::vector<std::vector<int>> M;
    for (int r = 0; r < 3; ++r) {
        M.push_back(std::vector<int>(test[r], test[r] + 3));
    }
    Solution solve;
    printf("%d\n", solve.findCircleNum(M));
    return 0;
}

例四:LeetCode307






/**
 Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
 */
#include <stdio.h>
#include <vector>
// LeetCode 307: range-sum queries with point updates over a segment tree.
// The tree lives in `_value` in heap order: node `pos` has children 2*pos+1 and
// 2*pos+2, and the root (pos 0) covers indices [0, _right_end].
class NumArray {
public:
       NumArray(std::vector<int> nums) : _right_end(-1) {
              if (nums.empty()) {
                     return;  // `_right_end` stays -1: sums are 0 and updates are ignored
              }
              // A recursively halved segment tree never needs more than 4*n slots.
              _value.assign(nums.size() * 4, 0);
              _right_end = static_cast<int>(nums.size()) - 1;
              build_segment_tree(_value, nums, 0, 0, _right_end);
       }
       // Sets element `i` to `val`. Indices outside [0, n-1] are ignored: for them the
       // descent never reaches a matching leaf and would recurse without bound.
       void update(int i, int val) {
              if (i < 0 || i > _right_end) {
                     return;
              }
              update_segment_tree(_value, 0, 0, _right_end, i, val);
       }
       // Returns the sum of the elements whose indices lie in [i, j] intersected with
       // [0, n-1]; an empty intersection (or an empty array) sums to 0.
       int sumRange(int i, int j) {
              if (_right_end < 0) {
                     return 0;
              }
              return sum_range_segment_tree(_value, 0, 0, _right_end, i, j);
       }
private:
       std::vector<int> _value;  // segment-tree node sums
       int _right_end;           // last valid index (n - 1), or -1 when the array is empty
       // Fills node `pos`, which covers nums[left..right], and all of its descendants.
       void build_segment_tree(std::vector<int>& value, std::vector<int>& nums, int pos, int left, int right) {
              if (left == right) {
                     value[pos] = nums[left];
                     return;
              }
              int mid = (left + right) / 2;
              build_segment_tree(value, nums, pos * 2 + 1, left, mid);
              build_segment_tree(value, nums, pos * 2 + 2, mid + 1, right);
              value[pos] = value[pos * 2 + 1] + value[pos * 2 + 2];
       }
       // Debug helper: prints each node's range, slot and sum in pre-order, indented
       // with "---" per level.
       void print_segment_tree(std::vector<int>& value, int pos, int left, int right, int layer) {
              for (int i = 0; i < layer; i++) {
                     printf("---");
              }
              printf("[%d %d][%d]: %d\n", left, right, pos, value[pos]);
              if (left == right) {
                     return;
              }
              int mid = (left + right) / 2;
              print_segment_tree(value, pos * 2 + 1, left, mid, layer + 1);
              print_segment_tree(value, pos * 2 + 2, mid + 1, right, layer + 1);
       }
       // Sum over [qleft, qright] intersected with this node's range [left, right]:
       // disjoint ranges contribute 0, fully covered nodes contribute their stored sum.
       int sum_range_segment_tree(std::vector<int>& value, int pos, int left, int right, int qleft, int qright) {
              if (qleft > right || qright < left) {
                     return 0;
              }
              if (qleft <= left && qright >= right) {
                     return value[pos];
              }
              int mid = (left + right) / 2;
              return sum_range_segment_tree(value, pos * 2 + 1, left, mid, qleft, qright) +
                     sum_range_segment_tree(value, pos * 2 + 2, mid + 1, right, qleft, qright);
       }
       // Writes `new_value` at `index` (which must lie in [left, right]) and refreshes
       // the sums on the path back up to `pos`.
       void update_segment_tree(std::vector<int>& value, int pos, int left, int right, int index, int new_value) {
              if (left == right && left == index) {
                     value[pos] = new_value;
                     return;
              }
              int mid = (left + right) / 2;
              if (index <= mid) {
                     update_segment_tree(value, pos * 2 + 1, left, mid, index, new_value);
              }
              else {
                     update_segment_tree(value, pos * 2 + 2, mid + 1, right, index, new_value);
              }
              value[pos] = value[pos * 2 + 1] + value[pos * 2 + 2];
       }
};
// Demo for LeetCode 307 on [1, 3, 5]: prints 9, then 8 after setting index 1 to 2.
int main() {
       int init[] = {1, 3, 5};
       std::vector<int> nums(init, init + 3);
       NumArray num_array(nums);
       printf("%d\n", num_array.sumRange(0, 2));
       num_array.update(1, 2);
       printf("%d\n", num_array.sumRange(0, 2));
       return 0;
}

猜你喜欢

转载自www.cnblogs.com/lihello/p/11520935.html