「Codeforces 1746D」Paths on the Tree - 贪心

题意描述

Codeforces 链接

给定一个 \(n\) 个节点的树,其节点被标记为 \(1\)\(n\),而且该树的根为 \(1\),另外也给定一个积分序列 \(s\)

如果下列两个条件都满足,则我们称路径集合k可用:

  • 该集合内所有路径从 \(1\) 开始

  • \(c_i\) 为覆盖节点 \(i\) 的路径数量,对于每对拥有同个父节点的节点 \((u,v)\),要求\(|c_u-c_v|\) 小于等于1

对于每个路径集合,其权值被定义为 \(\sum\limits_{i=1}^n{c_i s_i}\)

显而易见,每组数据至少有一个可用的路径集合,找出所有可用路径集合中的最大权值

解题思路

显然为了让答案最大,每条路径都会走到叶子节点。于是我们可以考虑每条路径对每一层的影响。显然对于度为 \(a\) 的节点 \(u\) 被覆盖 \(k\) 次,每个节点要么被覆盖 \(\lfloor \frac{k}{a} \rfloor\) 次,要么被覆盖 \(\lfloor \frac{k}{a} \rfloor + 1\) 次,且被被覆盖 \(\lfloor \frac{k}{a} \rfloor + 1\) 次的节点有 \(k \bmod a\) 个。于是我们可以先将每个点覆盖 \(\lfloor \frac{k}{a} \rfloor\) 次。接下来再处理被覆盖 \(\lfloor \frac{k}{a} \rfloor + 1\) 次的情况。

处理被覆盖 \(\lfloor \frac{k}{a} \rfloor + 1\) 次的情况时,我们可以从从叶子到根的方向考虑。我们直接进行 DFS 寻找剩余路径。显然对于边 \(u \rightarrow v\),到达 \(v\) 的路径中只有一条对 \(u\) 有效,为使答案最大,我们取权值最大的一条路径,其余的路径到 \(u\) 就截止了,记录进答案即可。

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cstdio>
#include <vector>
#include <queue>

struct Node {
std::vector<Node *> adj;
int deg, remain;
long long s;
std::priority_queue<long long> q;
};

long long ans = 0;

inline void addEdge(Node *u, Node *v) {
u->adj.push_back(v);
u->deg++;
}

void dfs1(Node *u, long long k) {
if (!u->adj.empty()) {
int size = u->adj.size();
ans += k / size * size * u->s;
for (Node *v : u->adj) dfs1(v, k / size);
u->remain = k % size;
} else ans += k * u->s;
}

void dfs2(Node *u) {
if (u->deg == 0) u->q.push(u->s);
for (Node *v : u->adj) {
dfs2(v);
if (!v->q.empty()) u->q.push(v->q.top() + u->s);
while (!v->q.empty()) v->q.pop();
}
while (u->remain--) {
ans += u->q.top();
u->q.pop();
}
}

void solve() {
int n, k;
scanf("%d %d", &n, &k);
std::vector<Node> nodes(n);
for (int i = 1; i < n; i++) {
int p;
scanf("%d", &p);
addEdge(&nodes[--p], &nodes[i]);
}
for (int i = 0; i < n; i++) scanf("%lld", &nodes[i].s);

ans = 0;
dfs1(&nodes[0], k);
dfs2(&nodes[0]);

printf("%lld\n", ans);
}

int main() {
int t;

scanf("%d", &t);

while (t--) solve();

return 0;
}