「洛谷 P2371」墨墨的等式 - 最短路

题意描述

洛谷链接

墨墨突然对等式很感兴趣,他正在研究 \(\sum_{i = 1}^n a_i x_i = b\) 存在非负整数解的条件,他要求你编写一个程序,给定 \(n, a_{1 \dots n}, l, r\) 求出有多少 \(b \in [l, r]\) 可以使等式存在非负整数解。

解题思路

本题看似可以用完全背包可解,但由于该数据范围,肯定 T 飞。

于是我们可以采用图论方法来解这道题。

解决本题需要了解 同余最短路 这个知识。具体同余最短路是什么,我们可以用本题举例。

首先我们要做的是确定一个数 \(x\),表示为加其他数等同于加 \(x\)\(x\) 的余数。在这里可以定义 \(x = \min \limits_{1 \leq j \leq n} a_j\),且对于 \(\forall i \in N \wedge i < x\),计算出满足该题非负整数解条件的 \(b\) 的最小值为 \(\text{dis}_i\)

具体怎么计算 \(\text{dis}_i\),我们可以建立一张图。对于 \(\forall i \in N \wedge i < n\),建立连向 \((i + a_j) \mod x\) 的权值为 \(a_j\) 的边。这条边表明我们可通过将当前 \(\mod x\) 的数加上 \(a_j\) 转化成 \(\mod x = (i + a_j) \mod x\) 的数。

接下来我们从 \(0\) 出发求最短路。求出的最短路 \(\text{dis}_i\) 即为满足该数 \(\mod n = i\) 且满足该题非负整数解条件的 \(b\) 的最小值。

最后对于 \(b \mod x = i\),若我们总共有 \(\lfloor \frac{h - \text{dis}_i}{x} \rfloor + 1\) 种满足小于等于 \(h\) 且符合条件的情况。最后将所有余数的情况相加即为答案。

对于本题,我们可以先分别算出小于等于 \(r\) 的情况总数和小于等于 \(l - 1\) 的情况总数,由于两者重复,我们仅需将两情况数相减即为 \(l \leq b \leq r\) 的情况总数。

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <cstring>
#include <iostream>
#include <vector>
#include <queue>

const int MAXN = 500000;

std::vector< std::pair<int, long long> > e[MAXN + 1];
long long dis[MAXN + 1];

inline void add(int from, int to, int ver) {
e[from].push_back( std::make_pair(to, ver) );
}

bool spfa(int s) {
static std::queue<int> q;
static bool vis[MAXN + 1];
memset(dis, 0x3f, sizeof(dis));
memset(vis, false, sizeof(vis));
dis[s] = 0, vis[s] = true;
q.push(s);

while (!q.empty()) {
int u = q.front();
q.pop(), vis[u] = false;
for (auto ed : e[u]) {
int v = ed.first, w = ed.second;
if (dis[v] > dis[u] + w) {
dis[v] = dis[u] + w;
if (!vis[v]) q.push(v), vis[v] = true;
}
}
}

return true;
}

int main() {
int n;
long long l, r;
static int a[13];

std::cin >> n >> l >> r;

int minA = MAXN + 1;
for (int i = 1; i <= n; i++) {
std::cin >> a[i];
minA = std::min(minA, a[i]);
}

for (int i = 0; i < minA; i++) {
for (int j = 1; j <= n; j++) {
if (a[j] != minA) {
add(i, (i + a[j]) % minA, a[j]);
}
}
}

spfa(0);

long long ans = 0;
for (int i = 0; i < minA; i++) {
if (r >= dis[i]) ans += (r - dis[i]) / minA + 1;
if (l - 1 >= dis[i]) ans -= (l - 1 - dis[i]) / minA + 1;
}

std::cout << ans << std::endl;

return 0;
}