Splay 学习笔记

一文概括 Splay

长文预警

平衡树

Splay 是一种平衡树,可以解决很多序列问题。而平衡树是一种二叉搜索树。

二叉搜索树则是一种数据结构,对于每一个节点,权值满足左儿子小于根节点小于右儿子,且整棵树的中序遍历是排序后的序列。这样二叉搜索树就可以解决很多问题,其时间复杂度取决于树的深度。

于是这里可以引出一种调试平衡树的常见方法:输出这棵树的中序遍历,看中序遍历是否为排序后序列。

然而在特殊数据下,二叉搜索树很容易退化成一颗链,于是就有了一种优化版本:平衡树。平衡树可以让二叉搜索树随时变化,以求深度保持相对较小,即“平衡”。而 Splay 就是一种平衡树。

如左图是一颗退化成链的二叉搜索树,右图为一颗平衡树。

平衡操作

旋转

Splay 有很多操作,其中旋转操作最为基本的操作之一。旋转分为左旋和右旋。左旋是将一个节点和它的左儿子互换位置,右旋是将一个点和它右儿子互换位置;同时保持中序遍历不变且仍然是一颗排序二叉树。方式如下图(接下来的图片都是使用 Windows 画图鼠绘,画的不好请见谅):

合理运用旋转可以让整棵树保持平衡。

Splay 操作

顾名思义,Splay 操作是 Splay 最重要的操作之一,也是最具有特点的操作。

Splay 操作为将一个点通过旋转移到指定的点的儿子。假设我们需要将 \(x\) 旋转至 \(y\) 的儿子(若不指定 \(y\) 则为将 \(x\) 旋转至根节点),同时 \(x\)\(y\) 之间间隔节点 \(A\)(距离 \(x\) 更近)、\(B\)(距离 \(y\) 更近),则有 \(4\) 种情况:

  1. \(A\)\(y\) 的左儿子,\(x\)\(y\) 的左儿子;
  2. \(A\)\(y\) 的右儿子,\(x\)\(y\) 的右儿子;
  3. \(A\)\(y\) 的左儿子,\(x\)\(y\) 的右儿子;
  4. \(A\)\(y\) 的右儿子,\(x\)\(y\) 的左儿子。

由于 \(1\)\(2\)\(3\)\(4\) 情况基本相同,操作对称。于是我们只需要将 Splay 分为操作 \(1\)\(3\) 两类。Splay 具体操作如下:

  • \(1\) 种情况:\(A\) 向上旋转,\(x\) 向上旋转。这种操作称作 zig-zag 操作。
  • \(3\) 种情况:\(x\) 向上旋转,\(x\) 向上旋转。这种操作称作 zig-zig 操作。

若想将 \(x\) 转至 \(y\) 的儿子,只需要不停进行该操作即可。

为了让整颗树尽量保持平衡,我们需要在每次点操作后都将其 Splay 操作旋转至根节点

已经过证明:使用 Splay 操作旋转节点,可将该平衡树操作时间复杂度降至 \(O(\log n)\),证明略。

二叉搜索树操作

接下来将介绍 Splay 在二叉搜索树种的应用。模板:洛谷 P3369 普通平衡树

为了方便,我们可以先加入 \(+\infty\)\(-\infty\) 两点。

插入操作

由于是二叉搜索树,我们可从根节点开始遍历。每到一个点就看若小于这个节点,就进入左儿子;大于这个节点,就进入右儿子。直到找到的儿子为空,就将其插入至该节点。

插入后为保持平衡,需要将插入的节点 Splay 操作至根节点。

查询排名操作

查询排名使用 Splay 有一种很方便的方法。我们可以把要查询的节点 Splay 至根,排名即为左子树大小。

查询第 \(k\) 大操作

根开始遍历。每到一个点就看这个点的左子树大小:如果大于 \(k\),就进入左子树;如果小于 \(k\),就先将 \(k\) 减去左子树大小和根节点大小,若 \(k \le 0\) 就返回该点,否则进入右子树查询第 \(k\) 大。

查询完后,将查询到的点 Splay 到根。

查询前驱 / 后继

查询前驱 / 后继当然有其他方法,但在 Splay 中有特殊的方法。

我们可以把将要查询的点 Splay 操作至根节点。查询前驱即查询左子树的最右端点,查询后继则是查询右子树的最左端点。

查完前驱 / 后继后,将查询到的点 Splay 到根。

删除

由于 Splay 的特性,我们可以利用特殊方法删除节点。

我们可以先求出这个点的前驱和后继,然后将前驱 Splay 操作至根节点,然后将后继 Splay 操作至前驱的儿子。显然现在前驱的右儿子是后继,左子树只有要删除的一个点。于是接下来直接删除该点即可。

二叉搜索树代码实现

节点操作

对于每个节点,我们需要存储它的儿子、父亲、值、子树大小、值的出现次数以及根节点。

首先我们可以写一些像是更新子树大小这类常用函数备用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
struct Node {
Node *ch[2], *fa, **root;
T x;
int size, count;

Node(Node *fa, Node **root, T x) : fa(fa), root(root), x(x), count(1) {
ch[0] = ch[1] = nullptr;
}

~Node() {
if (ch[0]) delete ch[0];
if (ch[1]) delete ch[1];
}

void maintain() {
size = (ch[0] ? ch[0]->size : 0) + (ch[1] ? ch[1]->size : 0) + count;
}

int relation() {
return this == fa->ch[0] ? 0 : 1;
}

int lSize() {
return ch[0] ? ch[0]->size : 0;
}
} *root;

旋转

按照上述的旋转思路直接模拟即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void rotate() {
Node *o = fa;
int r = relation();

fa = o->fa;
if (fa) fa->ch[o->relation()] = this;

o->ch[r] = ch[r ^ 1];
if (ch[r ^ 1]) ch[r ^ 1]->fa = o;

ch[r ^ 1] = o;
o->fa = this;

if (!fa) *root = this;
o->maintain();
maintain();
}

Splay 操作

按上述方法直接模拟。

1
2
3
4
5
6
7
8
9
10
11
12
void splay(Node *targetFa = nullptr) {
while (fa != targetFa) {
if (fa->fa == targetFa) rotate();
else if (fa->relation() == relation()) {
fa->rotate();
rotate();
} else {
rotate();
rotate();
}
}
}

查询节点的前驱 / 后继

直接将查询的点 Splay 至根即可。

注意查询后将查询到的点 Splay 到根,否则可能会被卡常。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Node *pred() {
splay();
Node *v = ch[0];
while (v->ch[1]) v = v->ch[1];
v->splay();
return v;
}

Node *succ() {
splay();
Node *v = ch[1];
while (v->ch[0]) v = v->ch[0];
v->splay();
return v;
}

二叉搜索树操作

建树 / 删树

建树时插入两 \(\infty\) 点即可。

1
2
3
4
5
6
7
8
Splay() : root(nullptr) {
insert(INF);
insert(-INF);
}

~Splay() {
delete root;
}

查找指定值所在节点

就像二叉搜索树这样搜索即可。从根节点开始遍历。每到一个点就看若小于这个节点,就进入左儿子;大于这个节点,就进入右儿子,直到找到该值即可。

注意找到后将找到的点 Splay 至根。

1
2
3
4
5
6
7
8
9
10
11
12
Node *find(T x) {
Node *v = root;
while (v && x != v->x) {
if (x < v->x) v = v->ch[0];
else v = v->ch[1];
}

if (!v) return nullptr;

v->splay();
return v;
}

插入节点

查找节点,如果查找到节点,则出现次数 \(+1\),否则就用上述方法插入该节点。

最后将插入的节点 Splay 至根。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Node *insert(T x) {
Node *v = find(x);
if (v) {
v->count++;
v->maintain();
return v;
}

Node **target = &root, *fa = nullptr;

while (*target) {
fa = *target;
fa->size++;
if (x < fa->x) target = &fa->ch[0];
else target = &fa->ch[1];
}

*target = new Node(fa, &root, x);
(*target)->splay();

return root;
}

删除节点

若出现次数大于 \(1\),则出现次数 \(-1\)。否则指定节点就用上述方法删除即可。

若指定值则可先查询后删除。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
void erase(T x) {
erase(find(x));
}

void erase(Node *v) {
if (v->count != 1) {
v->splay();
v->count--;
v->maintain();
return;
}

Node *pred = v->pred();
Node *succ = v->succ();

pred->splay();
succ->splay(pred);

delete succ->ch[0];
succ->ch[0] = nullptr;

succ->maintain();
pred->maintain();
}

查询值排名 / 查询第 \(k\)

若有值,则查询值排名只需将该点 Splay 到根,返回左子树大小。若无值,则可先插入该点,查询后删除。

查询第 \(k\) 大则使用上述方法遍历即可。

查完后将查询到的值 Splay 至根。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int rank(T x) {
Node *v = find(x);
if (v) {
v->splay();
return v->lSize();
} else {
v = insert(x);
int ans = v->lSize();
erase(v);
return ans;
}
}

Node *select(int k) {
k++;
Node *v = root;
while (!(v->lSize() + 1 <= k && v->lSize() + v->count >= k)) {
if (k < v->lSize() + 1) v = v->ch[0];
else {
k -= v->lSize() + v->count;
v = v->ch[1];
}
}
v->splay();
return v;
}

查询值的前驱 / 后继

同查询排名,若有值则直接按上述方法模拟,若无值则插入,查询后删除即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
const T &pred(T x) {
Node *v = find(x);
if (v) return v->pred()->x;
else {
v = insert(x);
const T &ans = v->pred()->x;
erase(v);
return ans;
}
}

const T &succ(T x) {
Node *v = find(x);
if (v) return v->succ()->x;
else {
v = insert(x);
const T &ans = v->succ()->x;
erase(v);
return ans;
}
}

二叉搜索树例题

洛谷 P6136 普通平衡树(数据加强版)(数据比 洛谷 P3369 普通平衡树 强不少)

平衡树板子,直接敲上去就行了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#include <cstdio>
#include <climits>

template <typename T, T INF>
struct Splay {
struct Node {
Node *ch[2], *fa, **root;
T x;
int size, count;

Node(Node *fa, Node **root, T x) : fa(fa), root(root), x(x), count(1) {
ch[0] = ch[1] = nullptr;
}

~Node() {
if (ch[0]) delete ch[0];
if (ch[1]) delete ch[1];
}

void maintain() {
size = (ch[0] ? ch[0]->size : 0) + (ch[1] ? ch[1]->size : 0) + count;
}

int relation() {
return this == fa->ch[0] ? 0 : 1;
}

void rotate() {
Node *o = fa;
int r = relation();

fa = o->fa;
if (fa) fa->ch[o->relation()] = this;

o->ch[r] = ch[r ^ 1];
if (ch[r ^ 1]) ch[r ^ 1]->fa = o;

ch[r ^ 1] = o;
o->fa = this;

if (!fa) *root = this;
o->maintain();
maintain();
}

void splay(Node *targetFa = nullptr) {
while (fa != targetFa) {
if (fa->fa == targetFa) rotate();
else if (fa->relation() == relation()) {
fa->rotate();
rotate();
} else {
rotate();
rotate();
}
}
}

Node *pred() {
splay();
Node *v = ch[0];
while (v->ch[1]) v = v->ch[1];
v->splay();
return v;
}

Node *succ() {
splay();
Node *v = ch[1];
while (v->ch[0]) v = v->ch[0];
v->splay();
return v;
}

int lSize() {
return ch[0] ? ch[0]->size : 0;
}
} *root;

Splay() : root(nullptr) {
insert(INF);
insert(-INF);
}

~Splay() {
delete root;
}

Node *find(T x) {
Node *v = root;
while (v && x != v->x) {
if (x < v->x) v = v->ch[0];
else v = v->ch[1];
}

if (!v) return nullptr;

v->splay();
return v;
}

Node *insert(T x) {
Node *v = find(x);
if (v) {
v->count++;
v->maintain();
return v;
}

Node **target = &root, *fa = nullptr;

while (*target) {
fa = *target;
fa->size++;
if (x < fa->x) target = &fa->ch[0];
else target = &fa->ch[1];
}

*target = new Node(fa, &root, x);
(*target)->splay();

return root;
}

void erase(T x) {
erase(find(x));
}

void erase(Node *v) {
if (v->count != 1) {
v->splay();
v->count--;
v->maintain();
return;
}

Node *pred = v->pred();
Node *succ = v->succ();

pred->splay();
succ->splay(pred);

delete succ->ch[0];
succ->ch[0] = nullptr;

succ->maintain();
pred->maintain();
}

int rank(T x) {
Node *v = find(x);
if (v) {
v->splay();
return v->lSize();
} else {
v = insert(x);
int ans = v->lSize();
erase(v);
return ans;
}
}

Node *select(int k) {
k++;
Node *v = root;
while (!(v->lSize() + 1 <= k && v->lSize() + v->count >= k)) {
if (k < v->lSize() + 1) v = v->ch[0];
else {
k -= v->lSize() + v->count;
v = v->ch[1];
}
}
v->splay();
return v;
}

const T &pred(T x) {
Node *v = find(x);
if (v) return v->pred()->x;
else {
v = insert(x);
const T &ans = v->pred()->x;
erase(v);
return ans;
}
}

const T &succ(T x) {
Node *v = find(x);
if (v) return v->succ()->x;
else {
v = insert(x);
const T &ans = v->succ()->x;
erase(v);
return ans;
}
}
};

int main() {
int n, m;
scanf("%d %d", &n, &m);

Splay<int, INT_MAX> splay;
for (int i = 0; i < n; i++) {
int x;
scanf("%d", &x);
splay.insert(x);
}

int last = 0, ans = 0;
while (m--) {
int op, x;
scanf("%d %d", &op, &x);
x ^= last;
if (op == 1) splay.insert(x);
else if (op == 2) splay.erase(x);
else if (op == 3) last = splay.rank(x);
else if (op == 4) last = splay.select(x)->x;
else if (op == 5) last = splay.pred(x);
else if (op == 6) last = splay.succ(x);
if (op >= 3 && op <= 6) ans ^= last;
}

printf("%d\n", ans);

return 0;
}

序列操作

由于 Splay 的形态特殊,我们可以用它来解决序列问题。这时 Splay 的中序遍历即为序列的各元素顺序。所以说 Splay 不一定是有序的,它的中序遍历是可自定义的,可根据情况调整用途

序列插入

假设我们需要将序列 \(a\) 插入到 \(x\)\(y\) 之间,只需要做下列 \(3\) 步:

  • \(x\) Splay 操作至根;
  • \(y\) Splay 操作至 \(x\) 的右儿子;
  • 由于 \(y\)\(x\) 的后继,故 \(y\) 的左子树必为空。此时只需将 \(a\) 插入 \(y\) 的左子树即可。

序列删除

与序列插入相似,假设我们需要需将 \(x\)\(y\) 之间的序列 \(a\) 删除,依然只需 \(3\) 步:

  • \(x\) Splay 操作至根;
  • \(y\) Splay 操作至 \(x\) 的右儿子;
  • \(y\) 的左子树即为 \(x\)\(y\) 之间的序列,即序列 \(a\)。于是我们只需删除 \(y\) 的左子树即可。

区间翻转

这里可能区间翻转多次,所以可以像线段树一样作懒标记。然后将左右子树交换即可。最终传到叶子节点的时候由于每个都交换,即可达成翻转区间。

遍历到区间的时候就把懒标记下传,然后交换左右子树即可。

序列例题

洛谷 P3391 文艺平衡树

区间翻转板子,敲就完事了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <cstdio>
#include <algorithm>

const int MAXN = 1e5;

template <typename T>
struct Splay {
struct Node {
Node *ch[2], *fa, **root;
T x;
int size;
bool rev, bound;

Node(Node *fa, Node **root, T x, bool bound = false) : fa(fa), root(root), x(x), size(1), rev(false), bound(bound) {
ch[0] = ch[1] = nullptr;
}

void maintain() {
size = (ch[0] ? ch[0]->size : 0) + (ch[1] ? ch[1]->size : 0) + 1;
}

void pushDown() {
if (rev) {
std::swap(ch[0], ch[1]);
if (ch[0]) ch[0]->rev ^= true;
if (ch[1]) ch[1]->rev ^= true;
rev = false;
}
}

int relation() {
return this == fa->ch[0] ? 0 : 1;
}

void rotate() {
fa->pushDown();

Node *o = fa;
int r = relation();

fa = o->fa;
if (fa) fa->ch[o->relation()] = this;

o->ch[r] = ch[r ^ 1];
if (ch[r ^ 1]) ch[r ^ 1]->fa = o;

ch[r ^ 1] = o;
o->fa = this;

if (!fa) *root = this;
o->maintain();
maintain();
}

void splay(Node *targetFa = nullptr) {
while (fa != targetFa) {
if (fa->fa) fa->fa->pushDown();
fa->pushDown();

if (fa->fa == targetFa) rotate();
else if (fa->relation() == relation()) {
fa->rotate();
rotate();
} else {
rotate();
rotate();
}
}
}

int lSize() {
pushDown();
return ch[0] ? ch[0]->size : 0;
}
} *root;

Splay() : root(nullptr) {}

Node *build(int *l, int *r, Node *fa) {
if (l > r) return nullptr;

int *mid = l + (r - l) / 2;
Node *v = new Node(fa, &root, *mid);
v->ch[0] = build(l, mid - 1, v);
v->ch[1] = build(mid + 1, r, v);

v->maintain();
return v;
}

void build(int *a, int n) {
root = new Node(nullptr, &root, -1, true);
root->ch[1] = new Node(root, &root, -1, true);
root->maintain();

root->ch[1]->ch[0] = build(a + 1, a + n, root->ch[1]);

root->ch[1]->maintain();
root->maintain();
}

Node *select(int k) {
Node *v = root;
while (k != v->lSize()) {
if (k < v->lSize()) v = v->ch[0];
else k -= v->lSize() + 1, v = v->ch[1];
}
v->splay();
return v;
}

Node *select(int l, int r) {
Node *pred = select(l - 1), *succ = select(r + 1);
pred->splay();
succ->splay(pred);
return succ->ch[0];
}

void reverse(int l, int r) {
select(l, r)->rev ^= true;
}

void fetch(int *a) {
int *p = a + 1;
dfs(root, p);
}

void dfs(Node *v, int *&p) {
if (!v) return;
v->pushDown();
dfs(v->ch[0], p);
if (!v->bound) *p++ = v->x;
dfs(v->ch[1], p);
}
};

int main() {
int n;
scanf("%d", &n);

static int a[MAXN + 1];
for (int i = 1; i <= n; i++) a[i] = i;

Splay<int> splay;
splay.build(a, n);

int m;
scanf("%d", &m);
while (m--) {
int l, r;
scanf("%d %d", &l, &r);
splay.reverse(l, r);
}

splay.fetch(a);

for (int i = 1; i <= n; i++) printf("%d ", a[i]);

return 0;
}