[数据结构]树状数组详解

前言

之前由于树状数组和线段树的修改和查询操作复杂度都是 O ( l o g 2 n ) O(log_2n) O(log2n),并且树状数组还只能同时支持单点修改和区间查询(其实骚操作还蛮多,不过个人以为dark不必),无法像线段树那样同时支持区间修改和区间查询。所以博主在高中竞赛期间一直使用线段树代替树状数组,偶尔被卡常就直接照抄一个树状数组模板,没有对这个数据结构进行过深刻的理解和学习,遂于今日还债,争取这个寒假Python学习和竞赛复习两开花好吧。

为什么是树状数组?

上文写的都是为什么不想学树状数组,但既然我开了这个坑,说明树状数组还是有其价值所在,在某些时候有着线段树不具有的优势。

极小的常数

虽然树状数组与线段树的理论时间复杂度都是 O ( l o g 2 n ) O(log_2n) O(log2n),但是树状数组的代码由于足够简单、操作较少,在实际运行过程中比线段树更具效率,在相同情况下运行时间往往远低于线段树,在各种卡常数场景下极具应用价值。

线性的空间复杂度

这也是树状数组相比线段树独特的优势所在,由于树状数组空间复杂度仅为 O ( n ) O(n) O(n),仅占用了线性的空间,所以被广泛应用在嵌套类的数据结构中。此类数据结构由于空间复杂度的叠加往往会遭遇空间不足的窘境,此时线性复杂度的树状数组就成为了嵌套的首选。

简短的代码

树状数组的原理非常简洁(但不一定好理解),虽然功能也相应的比较简单,但是代码实现也变得非常容易。即使没有搞清楚树状数组的原理,直接硬背模板也不失为一个选择。同时对于已经理解了的选手,实现树状数组也比实现线段树要更快。

原理详解

化一维为树状的储存策略

树状数组的独特在于人为地将一维的数组排布为树状,并让一个节点储存更多的信息:

在这里插入图片描述
以大小为 8 8 8的树状数组为例: 1 , 3 , 5 , 7 1,3,5,7 1,3,5,7这四个节点只储存了自己对应位置的数组的信息; 2 2 2号节点储存了 1 , 2 1,2 1,2两个位置的信息, 6 6 6号节点储存了 5 , 6 5,6 5,6两个位置的信息; 4 4 4号节点储存了 1 , 2 , 3 , 4 1,2,3,4 1,2,3,4四个位置的信息; 8 8 8号节点则储存了 1 ∼ 8 1\sim 8 18所有位置的信息。

到此为止,我们可以观察出树状数组的信息储存策略:

  1. 单个节点储存的位置信息数量等于于节点编号的最大的、是二的次幂的约数。如 6 6 6的约数有 1 , 2 , 3 , 6 1,2,3,6 1,2,3,6,其中是二的次幂的只有 2 2 2,所以 6 6 6号节点储存了两个位置的信息;对于 8 8 8号节点,它的约数有 1 , 2 , 4 , 8 1,2,4,8 1,2,4,8,这些都是二的次幂,其中 8 8 8是最大的,所以 8 8 8号节点储存 8 8 8个位置的信息。

  2. 每个节点只会储存编号小于等于自己编号的位置的信息。

  3. 每个节点储存信息的位置编号是连续的。

根据以上三点,我们就可以求出任意编号的节点应该储存哪些位置的信息,即每个节点应该储存包括自己对应的位置在内的、编号小于自己且连续的、数量为自己节点编号的最大二次幂约数的位置的信息。

但是仅仅凭借这个规则,我们并不能快速求出每个节点的储存范围,而如果修改了一个位置的信息,该向上更新哪些节点也并不好求出,所以我们需要更进一步的挖掘树状数组储存策略的内涵。

二进制下的质变

lowbit函数

如果将上述规则放在二进制下解读,这个规则或许没有那么复杂:

在这里插入图片描述
可以发现,每个节点储存的位置信息数等于 2 2 2的节点编号末尾连续 0 0 0的个数次幂,如 110 110 110末尾有 1 1 1 0 0 0,那么 6 6 6号节点就储存了 2 1 2^1 21 2 2 2个位置的信息。

再结合第二、三条规则,可以发现对于一个二进制的编号为 A 1 00 ⋯ 0 ⏟ n 个 0 , n ≥ 0 A1\underbrace{00\cdots0}_{n个0,n\ge0} A1n0,n0 000 A A A为任意 01 01 01串)的节点,它会储存从编号再 [ A 00 ⋯ 0 ⏟ n 个 0 1 , A 1 00 ⋯ 0 ⏟ n 个 0 ] [A\underbrace{00\cdots0}_{n个0}1, A1\underbrace{00\cdots0}_{n个0}] [An0 0001,A1n0 000]区间的所有位置的信息。比如 10 10 10号节点(二进制下为 1010 1010 1010 A A A部分为前两位 10 10 10,同时以 10 10 10结尾)就会储存编号在 1001 ∼ 1010 1001\sim 1010 10011010(即 9 ∼ 10 9\sim 10 910)的位置信息。

所以,问题的关键在于快速求出一个数二进制下末尾的 100 ⋯ 0 100\cdots 0 1000串对应的数值。

在计算机里,通过二进制位运算 i & ( − i ) i \& (-i) i&(i)就可以 O ( 1 ) O(1) O(1)求出对应节点记录位置信息的数量。这个操作得以实现是基于计算机内对负数按位取反再加一的储存方式。

对于一个二进制正整数 A 1 00 ⋯ 0 ⏟ n 个 0 , n ≥ 0 A1\underbrace{00\cdots0}_{n个0,n\ge0} A1n0,n0 000,其在计算机内的存储为:
0   A 1 00 ⋯ 0 ⏟ n 个 0 0\ A1\underbrace{00\cdots0}_{n个0} 0 A1n0 000(首位的 0 0 0表示这是一个正数)

而它的相反数则是先对 0   A 1 00 ⋯ 0 ⏟ n 个 0 0\ A1\underbrace{00\cdots0}_{n个0} 0 A1n0 000按位取反得到:
1   i n v ( A ) 0 11 ⋯ 1 ⏟ n 个 1 1\ inv(A)0\underbrace{11\cdots1}_{n个1} 1 inv(A)0n1 111 i n v ( A ) inv(A) inv(A)表示 01 01 01 A A A按位取反后得到的串)

再加上 1 1 1得到:
1   i n v ( A ) 1 00 ⋯ 0 ⏟ n 个 0 1\ inv(A)1\underbrace{00\cdots0}_{n个0} 1 inv(A)1n0 000

不难发现,一个数和它的相反数在计算机中储存时恰好只有我们要求的那部分末尾是相同的,而前面的部分每一位都是取反关系。所以,只需要将一个数和它的相反数“ & \& &”起来,就能得到我们想要求的数值。这个函数有个特有的名称,叫做 l o w b i t lowbit lowbit函数,即:
l o w b i t ( x ) = x & ( − x ) lowbit(x)=x\&(-x) lowbit(x)=x&(x)

单点修改

从上一部分我们可以看到,编号为 A 1 00 ⋯ 0 ⏟ n 个 0 , n ≥ 0 A1\underbrace{00\cdots0}_{n个0,n\ge0} A1n0,n0 000的节点,会记录 A 00 ⋯ 0 ⏟ n 个 0 1 ∼ A 1 00 ⋯ 0 ⏟ n 个 0 A\underbrace{00\cdots0}_{n个0}1\sim A1\underbrace{00\cdots0}_{n个0} An0 0001A1n0 000之间所有位置的数组的信息。同样的,当我们修改了某一个点的信息,就需要上溯每一个储存了该位置信息的节点,做出相应修改。

我们先尝试找到某个节点上溯的第一个节点:
按照上述结论反推,对于一个编号二进制为 A 01 B A01B A01B A A A为任意 01 01 01串, B B B为形如 11 ⋯ 1 ⏟ n 个 1 , n ≥ 0   00 ⋯ 0 ⏟ m 个 0 , m ≥ 0 \underbrace{11\cdots 1}_{n个1,n\ge0}\ \underbrace{00\cdots 0}_{m个0,m\ge 0} n1,n0 111 m0,m0 000 01 01 01串)的节点,包含了该节点信息的节点的编号一定大于 A 01 B A01B A01B,且编号的二进制末尾一定为 1 00 ⋯ 0 ⏟ p 个 0 , p ≥ n + m + 1 1\underbrace{00\cdots 0}_{p个0,p\ge n+m+1} 1p0,pn+m+1 000(因为末尾要能够大于 01 B 01B 01B这个串)。而且上溯的第一个节点编号一定是满足这些条件的编号中最小的一个,所以对应的节点编号显然为 A 1 00 ⋯ 0 ⏟ n + m + 1 个 0 A1\underbrace{00\cdots 0}_{n+m+1个0} A1n+m+10 000,同时可以观察到:
A 1 00 ⋯ 0 ⏟ n + m + 1 个 0 = A 01 B + 1 00 ⋯ 0 ⏟ m 个 0 = A 01 B + l o w b i t ( A 01 B ) A1\underbrace{00\cdots 0}_{n+m+1个0}=A01B+1\underbrace{00\cdots0}_{m个0}=A01B+lowbit(A01B) A1n+m+10 000=A01B+1m0 000=A01B+lowbit(A01B)
原来,编号为 v v v的节点上溯的第一个节点的编号就是节点自己的编号加上 l o w b i t lowbit lowbit值,即 v + l o w b i t ( v ) v+lowbit(v) v+lowbit(v),那么只要不断地加上当前节点的 l o w b i t lowbit lowbit值,就可以不断上溯,完成更新操作。

所以我们可以得到下面的执行单点修改函数(例子中为给位置为 v v v的数加上 Δ \Delta Δ):

void add(int v,int delta){
    
    for(;v<=n;v+=lb(v))num[v]+=delta;}

由于每次上溯后,编号的二进制位末尾的 100 ⋯ 0 100\cdots 0 1000串中 0 0 0的个数都会至少增加一个,在 l o g 2 n log_2n log2n次运算之内,该节点编号就会大于 n n n,所以该操作的时间复杂度为 O ( l o g 2 n ) O(log_2 n) O(log2n)

区间查询

如果我们想要完成对区间 [ 1 , v ] [1,v] [1,v]的查询,最简单粗暴的办法就是用for循环将 1 ∼ v 1\sim v 1v中的所有位置都遍历一遍,但是很显然这样没有利用上我们辛辛苦苦搭建起来的树状数组,考虑怎样利用上那些记录了多个位置信息的节点。

根据树状数组的存储策略,编号为 A 1 00 ⋯ 0 ⏟ n 个 0 , n ≥ 0 A1\underbrace{00\cdots0}_{n个0,n\ge 0} A1n0,n0 000的节点,会记录 A 00 ⋯ 0 ⏟ n 个 0 1 ∼ A 1 00 ⋯ 0 ⏟ n 个 0 A\underbrace{00\cdots0}_{n个0}1\sim A1\underbrace{00\cdots0}_{n个0} An0 0001A1n0 000之间所有位置的数组的信息。所以我们只需要 A 1 00 ⋯ 0 ⏟ n 个 0 A1\underbrace{00\cdots0}_{n个0} A1n0 000这一个节点就可以得到 [ A 00 ⋯ 0 ⏟ n 个 0 1 , A 1 00 ⋯ 0 ⏟ n 个 0 ] [A\underbrace{00\cdots0}_{n个0}1,A1\underbrace{00\cdots0}_{n个0}] [An0 0001,A1n0 000]这个区间里数组的信息,那么只要找到 [ 1 , v ] [1,v] [1,v]区间中储存了多个位置的数组信息且不重复的节点,将它们储存的信息整合在一起,就能得到最终结果。

假设我们要查询的区间为 [ 1 , A 1 00 ⋯ 0 ⏟ n 个 0 , n ≥ 0 1 00 ⋯ 0 ⏟ m 个 0 , m ≥ 0 ] [1,A1\underbrace{00\cdots0}_{n个0,n\ge 0}1\underbrace{00\cdots0}_{m个0,m\ge0}] [1,A1n0,n0 0001m0,m0 000],首先这个节点本身就会储存 [ A 1 00 ⋯ 0 ⏟ n + m 个 0 1 , A 1 00 ⋯ 0 ⏟ n 个 0 1 00 ⋯ 0 ⏟ m 个 0 ] [A1\underbrace{00\cdots0}_{n+m个0}1,A1\underbrace{00\cdots0}_{n个0}1\underbrace{00\cdots0}_{m个0}] [A1n+m0 0001,A1n0 0001m0 000]这个区间的数组信息,所以问题就转换为求 [ 1 , A 1 00 ⋯ 0 ⏟ n + m + 1 个 0 ] [1,A1\underbrace{00\cdots0}_{n+m+1个0}] [1,A1n+m+10 000]这个区间内的信息,我们成功的将问题的范围从 A 1 00 ⋯ 0 ⏟ n 个 0 1 00 ⋯ 0 ⏟ m 个 0 A1\underbrace{00\cdots0}_{n个0}1\underbrace{00\cdots0}_{m个0} A1n0 0001m0 000缩小到了 A 1 00 ⋯ 0 ⏟ n + m + 1 个 0 A1\underbrace{00\cdots0}_{n+m+1个0} A1n+m+10 000。同时,不难发现:
A 1 00 ⋯ 0 ⏟ n + m + 1 个 0 = A 1 00 ⋯ 0 ⏟ n 个 0 1 00 ⋯ 0 ⏟ m 个 0 − l o w b i t ( A 1 00 ⋯ 0 ⏟ n 个 0 1 00 ⋯ 0 ⏟ m 个 0 ) A1\underbrace{00\cdots0}_{n+m+1个0}=A1\underbrace{00\cdots0}_{n个0}1\underbrace{00\cdots0}_{m个0}-lowbit(A1\underbrace{00\cdots0}_{n个0}1\underbrace{00\cdots0}_{m个0}) A1n+m+10 000=A1n0 0001m0 000lowbit(A1n0 0001m0 000)
因此,只需要不断地减去当前查询区间右端点的 l o w b i t lowbit lowbit值,就能快速地缩小查询范围,并从右端点对应地节点直接获取部分信息,最终组成答案,完成查询操作。

所以我们得到了下面的区间查询函数(例子中为查询区间和):

int ask(int v){
    
    int re=0;for(;v;v-=lb(v))re+=num[v];return re;}

由于每次减去自身的 l o w b i t lowbit lowbit后,右端点的二进制都会少掉最右端的 1 1 1,在 l o g 2 n log_2n log2n次运算之内右端点就会变成 0 0 0,所以树状数组可以以 O ( l o g 2 n ) O(log_2n) O(log2n)的复杂度完成对任意以 1 1 1为左端点的区间 [ 1 , v ] [1,v] [1,v]的查询,两次查询的结果相减就可以求出任意区间。

例题

Luogu3374 【模板】树状数组 1

原题链接:https://www.luogu.com.cn/problem/P3374

单点修改+区间查询,对应上文讲解中使用的例子,这里直接给出代码:

#include<bits/stdc++.h>
#define lb(i) (i&(-i))
using namespace std;
const int M=5e5+5;
int n,m,num[M];
void add(int v,int delta){
    
    for(;v<=n;v+=lb(v))num[v]+=delta;}
int ask(int v){
    
    int re=0;for(;v;v-=lb(v))re+=num[v];return re;}
void in()
{
    
    
    scanf("%d%d",&n,&m);
    for(int i=1,a;i<=n;++i)scanf("%d",&a),add(i,a);
}
void ac()
{
    
    
    for(int i=1,a,b,c;i<=m;++i)
    {
    
    
        scanf("%d%d%d",&a,&b,&c);
        if(a-1)printf("%d\n",ask(c)-ask(b-1));
        else add(b,c);
    }
}
int main()
{
    
    
    in(),ac();
    system("pause");
}

Luogu3368 【模板】树状数组 2

原题链接:https://www.luogu.com.cn/problem/P3368

区间修改+单点查询,通过差分就可以转换成单点修改+区间查询。

差分就是将原数列 { a i } \{a_i\} { ai}修改为原来的数与前一个位置的数之差 b i = a i − a i − 1 b_i=a_i-a_{i-1} bi=aiai1得到数列 { b i } \{b_i\} { bi},这样新数列的前缀和就等于原数列对应位置的数,即:
∑ i = 1 k b i = a k \sum_{i=1}^{k}b_i=a_k i=1kbi=ak

当我们想要对区间 [ x , y ] [x,y] [x,y]中的每一个数都加上 Δ \Delta Δ时,只需要对差分数列的第 x x x项加上 Δ \Delta Δ,第 y + 1 y+1 y+1项减去 Δ \Delta Δ就完成了操作。

代码如下:

#include<bits/stdc++.h>
#define lb(i) (i&(-i))
using namespace std;
const int M=5e5+5;
int n,m,num[M];
void add(int v,int delta){
    
    for(;v<=n;v+=lb(v))num[v]+=delta;}
int ask(int v){
    
    int re=0;for(;v;v-=lb(v))re+=num[v];return re;}
void in()
{
    
    
    scanf("%d%d",&n,&m);
    for(int i=1,a=0,b;i<=n;++i)scanf("%d",&b),add(i,b-a),a=b;
}
void ac()
{
    
    
    for(int i=1,a,b,c,d;i<=m;++i)
    {
    
    
        scanf("%d%d",&a,&b);
        if(a-1)printf("%d\n",ask(b));
        else
        {
    
    
            scanf("%d%d",&c,&d);
            add(b,d),add(c+1,-d);
        }
    }
}
int main()
{
    
    
    in(),ac();
    system("pause");
}

猜你喜欢

转载自blog.csdn.net/ShadyPi/article/details/113091393