「YbtOJ CSP-S 模拟赛」砍苹果树 - 树上差分

题意描述

YbtOJ 链接(需付费)

小 K 有一棵 \(n\) 个点,\(n - 1\) 条边的苹果树,将树上的边称为 A 类边。

小 K 还往这棵树上加上了 \(m\) 条边,称加上的边为 B 类边。

作为小 K 的好朋友,你想要砍掉小 K 的苹果树,但是你发现砍掉一条边不一定能使苹果树不连通,于是你需要求出:有多少选取 恰好一条 A 类边和恰好一条 B 类边 的方案,使得这两条边删去之后,原图不连通。

两种方案不同当且仅当一条边在第一种方案中被删除了但在第二种方案中没有被删除。

解题思路

我们可以使用类似于树形 DP 的思想考虑这个题。我们只使用 A 类边构成一颗树,显然我们可以发现,对于一条 A 类边 \(u \rightarrow v\),对于以 \(v\) 为根的子树,该颗子树内有 \(k\) 条通往子树外的 B 类边,可分一下情况讨论:

  • \(k = 0\):显然删除 \(u \rightarrow v\) 与任意 B 类边均满足情况,答案增加 \(m\)
  • \(k = 1\):显然删除 \(u \rightarrow v\) 与通往子树外的那一条 B 类边满足情况,答案增加 \(1\)
  • \(k > 1\):无解。

于是我们可以用树上差分,对点统计向外连边的个数,然后 DFS 一遍即可。

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <cstdio>
#include <iostream>
#include <vector>

const int MAXN = 3e5;
const int LOG_MAXN = 19;

struct Node {
#ifdef DBG
int id;
#endif
std::vector<struct Edge> adj;
Node *f[LOG_MAXN + 1], *p;
int d, cnt;
};

struct Edge {
Node *s, *t;

Edge(Node *s, Node *t) : s(s), t(t) {}
};

inline void addEdge(Node *u, Node *v) {
u->adj.push_back(Edge(u, v));
v->adj.push_back(Edge(v, u));
}

void prepare(Node *u, Node *f = nullptr) {
u->f[0] = u->p = f;
u->d = (f ? f->d : 0) + 1;
for (int i = 1; i <= LOG_MAXN; i++) {
if (u->f[i - 1]) {
u->f[i] = u->f[i - 1]->f[i - 1];
}
}
for (Edge &e : u->adj) {
if (e.t == f) continue;
prepare(e.t, u);
}
}

inline Node *lca(Node *u, Node *v) {
if (u->d < v->d) std::swap(u, v);
if (u->d != v->d) {
for (int i = LOG_MAXN; i >= 0; i--) {
if (u->f[i] && u->f[i]->d >= v->d) {
u = u->f[i];
}
}
}
if (u != v) {
for (int i = LOG_MAXN; i >= 0; i--) {
if (u->f[i] != v->f[i]) {
u = u->f[i];
v = v->f[i];
}
}
return u->p;
}
return u;
}

long long dfs(Node *u, int m, Node *f = nullptr) {
long long ans = 0;

for (Edge &e : u->adj) {
if (e.t == f) continue;
ans += dfs(e.t, m, u);
if (e.t->cnt == 0) ans += m;
else if (e.t->cnt == 1) ans++;
}

if (u->p) u->p->cnt += u->cnt;

#ifdef DBG
printf("[%d]: %lld %d\n", u->id, ans, u->cnt);
#endif

return ans;
}

int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);

int n, m;
scanf("%d %d", &n, &m);

std::vector<Node> nodes(n + 1);

#ifdef DBG
for (int i = 1; i <= n; i++) nodes[i].id = i;
#endif
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d %d", &u, &v);
addEdge(&nodes[u], &nodes[v]);
}

prepare(&nodes[1]);

for (int i = 0; i < m; i++) {
int a, b;
scanf("%d %d", &a, &b);
Node *u = &nodes[a], *v = &nodes[b];
Node *f = lca(u, v);
u->cnt++, v->cnt++, f->cnt -= 2;
}

#ifdef DBG
for (int i = 1; i <= n; i++) printf("%d ", nodes[i].cnt);
putchar('\n');
#endif

printf("%lld\n", dfs(&nodes[1], m));

fclose(stdin);
fclose(stdout);

return 0;
}