「洛谷 P4180」严格次小生成树 - 最小生成树 + 最近公共祖先

题意描述

洛谷链接

小 C 最近学了很多最小生成树的算法,Prim 算法、Kruskal 算法、消圈算法等等。正当小 C 洋洋得意之时,小 P 又来泼小 C 冷水了。小 P 说,让小 C 求出一个无向图的次小生成树,而且这个次小生成树还得是严格次小的,也就是说:如果最小生成树选择的边集是 \(E_M\),严格次小生成树选择的边集是 \(E_S\),那么需要满足:(\(value(e)\) 表示边 \(e\) 的权值) \(\sum_{e \in E_M}value(e)<\sum_{e \in E_S}value(e)\)

这下小 C 蒙了,他找到了你,希望你帮他解决这个问题。

对于 \(100\%\) 的数据, \(N\le 10^5\)\(M\le 3\times10^5\),边权 \(\in [0,10^9]\),数据保证必定存在严格次小生成树。

解题思路

首先我们求出最小生成树。然后对于每个未选中的边按权值从小到达遍历。对于每个未选中的边,设我们找到了边 \(u \leftrightarrow v\),我们可以在最小生成树中找到 \(u \rightarrow v\) 的路径,再将路径中的权值最大的边替换成该边。这样可以同时满足是一颗生成树且总权值相差最小。于是我们就得到了次小生成树的可能情况。最后在这些可能情况中取最小值即可。

但这样的次小生成树是非严格的,因为我们找到的边和将要替换的边权值可能相等。此时我们只需要用路径上的次大边替换就行了。

对于路径上的最大边和次大边,我们可以在预处理 LCA 的时候同时预处理 LCA 对应的路径的最大边和次大边即可。

整个问题的时间复杂度为 \(O(m \log m)\)

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <cstdio>
#include <climits>
#include <vector>
#include <algorithm>

const int MAXN = 1e5;
const int MAXM = 3e5;
const int LOG_MAXN = 17;

struct Node {
std::vector<struct Edge> e;
Node *p, *f[LOG_MAXN + 1];
long long maxe[LOG_MAXN + 1][2];
int d, id;
} N[MAXN + 1];

struct Edge {
Node *s, *t;
int sid, tid;
long long w;
bool v;

Edge() {}
Edge(Node *s, Node *t, long long w) : s(s), t(t), w(w) {}

bool operator<(const Edge &other) const {
return w < other.w;
}
} E[MAXM + 1];

struct UnionFindSet {
int fa[MAXN + 1];

void init(int n) {
for (int i = 1; i <= n; i++) fa[i] = i;
}

int find(int x) {
return x == fa[x] ? x : fa[x] = find(fa[x]);
}

void merge(int x, int y) {
fa[find(x)] = find(y);
}
} ufs;

int n, m;

inline void addEdge(int s, int t, long long w) {
N[s].e.push_back(Edge(&N[s], &N[t], w));
N[t].e.push_back(Edge(&N[t], &N[s], w));
}

inline long long kruskal() {
ufs.init(n);
std::sort(E + 1, E + m + 1);

long long ans = 0;
int counts = 0;

for (int i = 1; i <= m; i++) {
Edge &e = E[i];
if (ufs.find(e.sid) == ufs.find(e.tid)) continue;
e.v = true;
addEdge(e.sid, e.tid, e.w);
ufs.merge(e.sid, e.tid);
ans += e.w;
if (++counts == n - 1) break;
}

return ans;
}

void prepare(Node *v, Node *f = NULL) {
v->f[0] = v->p = f;
v->d = (f ? f->d : 0) + 1;
v->maxe[0][1] = LLONG_MIN;
for (int i = 1; i <= LOG_MAXN; i++) {
if (v->f[i - 1]) {
v->f[i] = v->f[i - 1]->f[i - 1];
long long choice[4] = { v->maxe[i - 1][0], v->maxe[i - 1][1],
v->f[i - 1]->maxe[i - 1][0], v->f[i - 1]->maxe[i - 1][1] };
std::sort(choice, choice + 4);
v->maxe[i][0] = choice[3];
int p = 2;
while (p >= 0 && choice[p] == choice[3]) p--;
v->maxe[i][1] = (p == -1 ? LLONG_MIN : choice[p]);
}
}
for (Edge *e = &v->e.front(); e && e <= &v->e.back(); e++) {
if (e->t == f) continue;
e->t->maxe[0][0] = e->w;
prepare(e->t, v);
}
}

inline Node *lca(Node *u, Node *v) {
if (u->d < v->d) std::swap(u, v);
if (u->d != v->d) {
for (int i = LOG_MAXN; i >= 0; i--) {
if (u->f[i] && u->f[i]->d >= v->d) {
u = u->f[i];
}
}
}
if (u != v) {
for (int i = LOG_MAXN; i >= 0; i--) {
if (u->f[i] != v->f[i]) {
u = u->f[i];
v = v->f[i];
}
}
return u->p;
}
return u;
}

inline long long query(Node *v, Node *f, long long w) {
long long res = LLONG_MIN;
for (int i = LOG_MAXN; i >= 0; i--) {
if (v->f[i] && v->f[i]->d >= f->d) {
if (w != v->maxe[i][0]) res = std::max(res, v->maxe[i][0]);
else res = std::max(res, v->maxe[i][1]);
v = v->f[i];
}
}
return res;
}

int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) N[i].id = i;
for (int i = 1; i <= m; i++) {
scanf("%d %d %lld", &E[i].sid, &E[i].tid, &E[i].w);
E[i].s = &N[E[i].sid], E[i].t = &N[E[i].tid];
}

long long ans = LLONG_MAX;
long long sum = kruskal();
prepare(&N[1]);

for (int i = 1; i <= m; i++) {
if (!E[i].v) {
Edge *e = &E[i];
Node *f = lca(e->s, e->t);
long long sw = query(e->s, f, e->w);
long long tw = query(e->t, f, e->w);
if (std::max(sw, tw) > LLONG_MIN) ans = std::min(ans, sum - std::max(sw, tw) + e->w);
}
}

printf("%lld\n", ans);

return 0;
}