树链剖分学习笔记

树链剖分是一个很常见的处理树上问题的算法,但之前一直没学 /kk,现在终于把这个坑补了。

这里只讲了重链剖分,不涉及其他剖分方式。

概念

树链剖分是一种把树上问题转化为序列问题的算法。它可以将树上的一条路径所经过的点用序列中的若干条区间表示。

首先我们需要知道几个概念:重儿子、轻儿子、重边、轻边。我们定义对于一个节点 \(u\) 的儿子为 \(v\),我们定义以 \(v\) 为根的子树中节点最多的子树所对应的 \(v\) 为重儿子,其余的 \(v\) 为轻儿子。父亲连向重儿子的边为重边,父亲连向轻儿子的边为轻边。如果同时有多个以 \(v\) 为根的子树节点数相同且最大,则选取任意一个 \(v\) 为重儿子,其余为轻儿子。这样对于每个非叶节点,我们都有一个重儿子和一个重边。

接下来以下列的图举例(图片引用自 OI Wiki):

重链及其性质

同时我们定义重链为由重边组成的极大的链,即以一条重边延伸出只有重边且最大的链。如上图中绿框所示。我们可以很容易地得出树上每个点属于且仅属于一条重链。于是我们可以将树上问题转化成重链上的问题。

我们可以发现,对于每个轻儿子,其子树大小至多为父亲的 \(\frac{1}{2}\)。又通过 LCA 的思想,我们可以得出重链的一个性质:对于树上任意一条路径,所经过的点可被拆分成不超过 \(\log n\) 条重链(头尾的重链可不完整)。于是我们可以将路径问题转化为 \(\log n\) 条序列上的问题。

接下来我们需要将数转化为序列,这里我们可以使用 DFS 序,且遍历时优先遍历重儿子,这样每条重链就可保证在序列中是连续的。

实现方式

我们将求出树上所有重链的操作叫做剖分。

节点定义

首先对于每个节点,我们需要维护它的子树大小,DFS 序,深度以及所在的重链、父节点和重儿子。

1
2
3
4
5
6
struct Node {
std::vector<struct Edge> e;
struct Chain *chain;
int size, dfn, depth;
Node *fa, *ch;
};

剖分

对于树链剖分问题,我们可以先将这棵树进行两次 DFS 以进行剖分。第一次 DFS 求出每颗子树的大小,第二次 DFS 则求出每个点的重儿子。由于我们可以得出每个轻儿子必为重链的开头,于是我们存重链时仅需存储其深度最小的节点,且每遍历到一个轻节点就新建重链,每遍历到一个重节点就将其归入父亲的重链。于是我们可以在第二次 DFS 时同时将树进行剖分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
struct Chain {
Node *top;

Chain(Node *top) : top(top) {}
};

void dfs1(Node *v, Node *fa = nullptr) {
v->size = 1;

for (Edge &e : v->e) {
if (e.t == fa) continue;
e.t->fa = v;
e.t->depth = v->depth + 1;
dfs1(e.t, v);
v->size += e.t->size;
if (!v->ch || v->ch->size < e.t->size) v->ch = e.t;
}
}

void dfs2(Node *v) {
static int ts = 0;
v->dfn = ++ts;

if (!v->fa || v != v->fa->ch) v->chain = new Chain(v);
else v->chain = v->fa->chain;

if (v->ch) dfs2(v->ch);
for (Edge &e : v->e) {
if (e.t->fa == v && e.t != v->ch) {
dfs2(e.t);
}
}
}

inline void split(Node *v) {
v->depth = 1;
dfs1(v);
dfs2(v);
}

维护序列

接下来我们就可在 DFS 序上操作了。对于每个点 \(u\),在序列 \(a\) 上对应的是 \(a_{\text{dfn}_u}\)\(\text{dfn}_u\)\(u\) 的 DFS 序)。我们这里可以选择线段树维护(实际上选择其它任意数据结构均可)。

这时我们求路径的时候可以使用类似于倍增求 LCA 的思想:看两个重链开头的深度,然后更深的往上跳到父亲的重链,同时将这条重链纳入统计。最后跳到同一重链后将两点之间路径纳入统计即可。时间复杂度为 \(O(n \log n)\)。代码演示如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
inline void update(Node *u, Node *v, int w) {
while (u->chain != v->chain) {
if (u->chain->top->depth < v->chain->top->depth) std::swap(u, v);
segment->update(u->chain->top->dfn, u->dfn, w);
u = u->chain->top->fa;
}

if (u->depth > v->depth) std::swap(u, v);
segment->update(u->dfn, v->dfn, w);
}

inline int query(Node *u, Node *v) {
int res = 0;
while (u->chain != v->chain) {
if (u->chain->top->depth < v->chain->top->depth) std::swap(u, v);
res += segment->query(u->chain->top->dfn, u->dfn);
u = u->chain->top->fa;
}

if (u->depth > v->depth) std::swap(u, v);
res += segment->query(u->dfn, v->dfn);

return res;
}

求 LCA

用树链剖分也可以求 LCA,且比倍增求 LCA 更快,常数小。方法和上述维护序列类似,这就不赘述了。

例题

洛谷 P3384「模板」轻重链剖分/树链剖分

这道题是树链剖分模板,同时加了子树操作。在 DFS 序中子树中所有节点是连续的,且开头必为根节点,所以我们只需要查询或修改线段树中的子树所在区间即可。即为 segment->query(nodes[v].dfn, nodes[v].dfn + nodes[v].size - 1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#include <cstdio>
#include <vector>

struct Node {
std::vector<struct Edge> e;
struct Chain *chain;
int size, dfn, depth;
Node *fa, *ch;
};

struct Edge {
Node *s, *t;

Edge(Node *s, Node *t) : s(s), t(t) {}
};

struct Chain {
Node *top;

Chain(Node *top) : top(top) {}
};

inline void addEdge(Node *u, Node *v) {
u->e.push_back(Edge(u, v));
v->e.push_back(Edge(v, u));
}

void dfs1(Node *v, Node *fa = nullptr) {
v->size = 1;

for (Edge &e : v->e) {
if (e.t == fa) continue;
e.t->fa = v;
e.t->depth = v->depth + 1;
dfs1(e.t, v);
v->size += e.t->size;
if (!v->ch || v->ch->size < e.t->size) v->ch = e.t;
}
}

void dfs2(Node *v) {
static int ts = 0;
v->dfn = ++ts;

if (!v->fa || v != v->fa->ch) v->chain = new Chain(v);
else v->chain = v->fa->chain;

if (v->ch) dfs2(v->ch);
for (Edge &e : v->e) {
if (e.t->fa == v && e.t != v->ch) {
dfs2(e.t);
}
}
}

inline void split(Node *v) {
v->depth = 1;
dfs1(v);
dfs2(v);
}

struct SegT {
int l, r;
SegT *lc, *rc;
long long val, tag;

SegT(int l, int r, SegT *lc, SegT *rc) : l(l), r(r), lc(lc), rc(rc), val(0), tag(0) {}

void cover(const long long delta) {
val += (r - l + 1) * delta;
tag += delta;
}

void pushDown() {
if (tag) {
lc->cover(tag);
rc->cover(tag);
tag = 0;
}
}

void update(const int l, const int r, const long long delta) {
if (l > this->r || r < this->l) return;
else if (l <= this->l && r >= this->r) cover(delta);
else {
pushDown();
lc->update(l, r, delta);
rc->update(l, r, delta);
val = lc->val + rc->val;
}
}

long long query(const int l, const int r) {
if (l > this->r || r < this->l) return 0;
else if (l <= this->l && r >= this->r) return val;
else {
pushDown();
return lc->query(l, r) + rc->query(l, r);
}
}

static SegT *build(const int l, const int r) {
if (l > r) return nullptr;
else if (l == r) return new SegT(l, r, nullptr, nullptr);
else {
const int mid = l + (r - l) / 2;
return new SegT(l, r, build(l, mid), build(mid + 1, r));
}
}
} *segment;

inline void update(Node *u, Node *v, long long w) {
while (u->chain != v->chain) {
if (u->chain->top->depth < v->chain->top->depth) std::swap(u, v);
segment->update(u->chain->top->dfn, u->dfn, w);
u = u->chain->top->fa;
}

if (u->depth > v->depth) std::swap(u, v);
segment->update(u->dfn, v->dfn, w);
}

inline long long query(Node *u, Node *v) {
long long res = 0;
while (u->chain != v->chain) {
if (u->chain->top->depth < v->chain->top->depth) std::swap(u, v);
res += segment->query(u->chain->top->dfn, u->dfn);
u = u->chain->top->fa;
}

if (u->depth > v->depth) std::swap(u, v);
res += segment->query(u->dfn, v->dfn);

return res;
}

int main() {
int n, m, r, p;
scanf("%d %d %d %d", &n, &m, &r, &p);

std::vector<Node> nodes(n + 1);
std::vector<long long> val(n + 1);
for (int i = 1; i <= n; i++) scanf("%lld", &val[i]);
for (int i = 1; i <= n - 1; i++) {
int u, v;
scanf("%d %d", &u, &v);
addEdge(&nodes[u], &nodes[v]);
}

split(&nodes[r]);

segment = SegT::build(1, n);
for (int i = 1; i <= n; i++) segment->update(nodes[i].dfn, nodes[i].dfn, val[i]);

while (m--) {
int op;
scanf("%d", &op);

if (op == 1) {
int u, v;
long long w;
scanf("%d %d %lld", &u, &v, &w);
update(&nodes[u], &nodes[v], w);
} else if (op == 2) {
int u, v;
scanf("%d %d", &u, &v);
printf("%lld\n", query(&nodes[u], &nodes[v]) % p);
} else if (op == 3) {
int v;
long long w;
scanf("%d %lld", &v, &w);
segment->update(nodes[v].dfn, nodes[v].dfn + nodes[v].size - 1, w);
} else if (op == 4) {
int v;
scanf("%d", &v);
printf("%lld\n", segment->query(nodes[v].dfn, nodes[v].dfn + nodes[v].size - 1) % p);
}
}

return 0;
}