树状数组学习笔记

概念

树状数组是一种维护线性数组前缀和的特殊方法,使用了类似于快速幂等的二进制分组的优势,对数组进行分段,从而可以快速求出该数组的前缀和的树形数据结构。

这一点与线段树类似。树状数组能干的事情线段树都能干,线段树能干的事情树状数组不一定能干。但树状数组相对线段树代码更短更好写,故在只涉及单点修改的时候树状数组更常用。

树状数组的工作原理如下图(图片引用自 OI Wiki):

\(a\) 数组分为图示若干段,其中 \(c\) 数组存的是区间和。

前置知识

详见之前《位运算学习笔记》博客

这里想补充一点,就是 lowbit 计算。

lowbit 计算的目的是求出一个数在二进制下最后一位 1 和之后的所有 0 所组成的数。

具体求法为 x = x & -x,接下我将会利用数字 20 解释其原理。

首先将 20 的二进制表示出来,为 010100。

然后将 20 的负数表示出来,在补码的表示方法为将 20 按位取反,再 +1,即为 101100。

再将两个数进行与运算,得到二进制数 000100,十进制数为 4。该数即为 20 在二进制下的最后一个 1 和之后的 0 所组成的数。

lowbit 计算为树状数组的基础运算。

实现方法

那么怎么实现对数组的分段?这里就需要使用之前介绍的 \(\text{lowbit}\) 计算。

1
2
3
inline int lowbit(int x) {
return x & -x;
}

树状数组支持两种操作:单点修改,前缀求和。

初始化

首先在使用树状数组之间,我们需要将其初始化。

1
2
int c[MAXN + 1];    //存储区间和
memset(c, 0, sizeof(c));

单点修改

树状数组可以将单点单独加一个值,例如修改 \(a_5\) 工作原理如下(对应上图序号)(画图丑请见谅):

graph LR
a5 --> c6 --> c8

即为逐级上升,在经过的结点都加上这个值即可。

关于如何求上升后的坐标,使用 \(\text{lowbit}\) 即可。

\[ \begin{cases} 5 + \text{lowbit}(5) = 6 & a_5 \rightarrow c_6 \\ 6 + \text{lowbit}(6) = 8 & c_6 \rightarrow c_8 \end{cases} \]

代码如下:

1
2
3
4
5
6
void add(int x, int k) {
while (x <= n) { // 防止数组越界
c[x] = c[x] + k;
x = x + lowbit(x);
}
}

前缀查询

树状数组最主要的功能就是求前缀和,例如查询 \(a_7\) 工作原理如下:

graph LR
a7 --> c6 --> c4

即为不停地查询上一层的前一个,同时将值加入最终结果即可。

求前一层的前一个也可以使用前缀和:

\[ \begin{cases} 7 - \text{lowbit}(7) = 6 & a_7 \rightarrow c_6 \\ 6 - \text{lowbit}(6) = 4 & c_6 \rightarrow c_4 \end{cases} \]

代码如下:

1
2
3
4
5
6
7
8
int getsum(int x) {
int ans = 0;
while (x >= 1) {
ans = ans + c[x];
x = x - lowbit(x);
}
return ans;
}

以上为树状数组最基础的操作

例题

洛谷 P3374 【模板】树状数组 1

简单的模板题,结合上面的知识即可 AC。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <cstring>

const int MAXN = 500000;

int c[MAXN + 1], n;

inline int lowbit(int x) {
return x & -x;
}

inline void add(int x, int k) {
while (x <= n) {
c[x] = c[x] + k;
x = x + lowbit(x);
}
}

inline int getsum(int x) {
int ans = 0;
while (x >= 1) {
ans = ans + c[x];
x = x - lowbit(x);
}
return ans;
}

int main() {
int m;
memset(c, 0, sizeof(c));

scanf("%d%d", &n, &m);

for (int i = 1, x; i <= n; i++) {
scanf("%d", &x);
add(i, x);
}

while (m--) {
int m, x, y;
scanf("%d%d%d", &m, &x, &y);

if (m == 1) add(x, y);
else printf("%d\n", getsum(y) - getsum(x - 1));
}

return 0;
}