最近公共祖先学习笔记

概念

最近公共祖先(英文 Lowest Common Ancestor,简称 LCA)定义为树种两个点的公共祖先里面,离根最远的一个。

实现方法

朴素算法

先算出两点的深度,取深度最小值,不停寻找父结点直到该深度。之后两点同时找父结点,直到两者的父结点相同,该父结点即为最近公共祖先。易得时间复杂度为 \(O(n)\)

倍增算法

预处理

记录一个数组 fa[u][i],记录的是结点 \(u\) 向上寻找 \(2^i\) 次父结点所达到的结点,如果不存在则为 \(0\)

可通过二进制拆分 \(u\) 的深度(拆分过程详见之前《位运算学习笔记》博客中的“快速幂”章节 ),在 \(O(\log n)\) 的时间复杂度下算出 fa[u][i]。利用该方法可以利用递归在 \(O(n \log n)\) 的时间复杂度内预处理完 fa 数组。

下列为预处理 fa 的代码:

1
2
3
4
5
6
7
8
9
void getFa(int u, int f) {
fa[u][0] = f;
dep[u] = dep[f] + 1;
for (int i = 1; i <= 20; i++) fa[u][i] = fa[ fa[u][i - 1] ][i - 1];
for (int v : e[u]) {
if (v == f) continue;
getFa(v, u);
}
}

寻找 LCA

与朴素算法类似,不同之处是利用二进制和 fa 数组将复杂度从 \(O(n)\) 优化到 \(O(\log n)\)

原理是从远及近遍历 fa 数组,直到两结点同深度的祖先不相同,此时 \(LCA\) 必定在两祖先之上,然后在祖先之上再次使用 fa 数组遍历,直到 fa[u][0] == fa[v][0] 时,\(LCA\)\(u\)\(v\) 的父结点。该过程相当于将 \(LCA\)\(u\)\(v\) 的深度差拆分成若干个 \(2\) 的幂,减少了运算次数,加快运行效率。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int lca(int u, int v) {
if (dep[u] < dep[v]) std::swap(u, v);
for (int i = 20; i >= 0; i--) {
if (dep[ fa[u][i] ] >= dep[v] ) {
u = fa[u][i];
}
}
if (u == v) return u;
for (int i = 20; i >= 0; i--) {
if (fa[u][i] != fa[v][i]) {
u = fa[u][i];
v = fa[v][i];
}
}
return fa[u][0];
}

算法复杂度统计

算法 预处理 q 次查询
朴素算法 \(O(n)\) \(O(qn)\)
倍增算法 \(O(n \log n)\) \(O(q \log n)\)
欧拉序转 RMQ \(O(n \log n)\) \(O(q)\)
Tarjan \(O(n)\) \(O(n + q)\)
树链剖分 \(O(n)\) \(O(q \log n)\)

例题

LibreOJ #10130. 「一本通 4.4 例 1」点的距离

本题由于是树,可利用树的特性做题。

这里可以使用 \(LCA\)。因 \(x\)\(y\) 的最短路径种必定会经过 \(LCA\),故可以将该题转化为求 \(LCA\) 分别与点 \(x\) 和点 \(y\) 的距离。

先求出点 \(x\) 与点 \(y\)\(LCA\),再将 \(LCA\) 与点 \(x\)、点 \(y\) 的深度差相加即为答案。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cstdio>
#include <iostream>
#include <vector>

const int MAXN = 100000;

int n;
std::vector<int> e[MAXN + 1];
int fa[MAXN + 1][25], dep[MAXN + 1];

void getFa(int u, int f) {
fa[u][0] = f;
dep[u] = dep[f] + 1;
for (int i = 1; i <= 20; i++) fa[u][i] = fa[ fa[u][i - 1] ][i - 1];
for (int v : e[u]) {
if (v == f) continue;
getFa(v, u);
}
}

int lca(int u, int v) {
if (dep[u] < dep[v]) std::swap(u, v);
for (int i = 20; i >= 0; i--) {
if (dep[ fa[u][i] ] >= dep[v] ) {
u = fa[u][i];
}
}
if (u == v) return u;
for (int i = 20; i >= 0; i--) {
if (fa[u][i] != fa[v][i]) {
u = fa[u][i];
v = fa[v][i];
}
}
return fa[u][0];
}

int main() {
std::cin >> n;
for (int i = 0, x, y; i < n - 1; i++) {
scanf("%d%d", &x, &y);
e[x].push_back(y);
e[y].push_back(x);
}

getFa(1, 0);

int q;
std::cin >> q;
while (q--) {
int x, y;
scanf("%d%d", &x, &y);

int root = lca(x, y);

printf("%d\n", dep[x] + dep[y] - 2 * dep[root]);
}

return 0;
}