「JSOI2007」文本生成器 - AC 自动机 + DP

题意描述

洛谷链接

LibreOJ 链接

JSOI 交给队员 ZYX 一个任务,编制一个称之为“文本生成器”的电脑软件:该软件的使用者是一些低幼人群,他们现在使用的是 GW 文本生成器 v6 版。

该软件可以随机生成一些文章——总是生成一篇长度固定且完全随机的文章。 也就是说,生成的文章中每个字符都是完全随机的。如果一篇文章中至少包含使用者们了解的一个单词,那么我们说这篇文章是可读的(我们称文章 \(s\) 包含单词 \(t\),当且仅当单词 \(t\) 是文章 \(s\) 的子串)。但是,即使按照这样的标准,使用者现在使用的 GW 文本生成器 v6 版所生成的文章也是几乎完全不可读的。ZYX 需要指出 GW 文本生成器 v6 生成的所有文本中,可读文本的数量,以便能够成功获得 v7 更新版。你能帮助他吗?

答案对 \(10^4 + 7\) 取模。

解题思路

该题问题不好求解,我们可转化为统计不含所有要求字符串的文本数量,用总数量减去该数即为答案。

于是我们可以考虑建立 AC 自动机,并在 AC 自动机上进行 DP,碰到 end 标记就不转移即可。

在建立 AC 自动机的时候,我们可以编写 end[trie[u][i]] += end[trie[fail[u]][i]],以使 fail 节点继承原本节点的字符串节点标记,即可使 DP 遇到所有要求字符串均不统计。

\(f_{i, j}\) 表示目前在文本第 \(i\) 位,且在 AC 自动机匹配时在节点 \(j\) 时符合条件的总情况。则有下列转移方程(\(\text{trie}_{j, k}\) 为在 AC 自动机节点 \(j\) 上且文本下一位为 \(k\) 时匹配的节点):

\[ f_{i + 1, \text{trie}_{j, k}} = \begin{cases} 0 & k \in [\text{A}, \text{Z}] \wedge \text{end}_k > 0 \\ f_{i + 1, \text{trie}_{j, k}} + f_{i, j} & k \in [\text{A}, \text{Z}] \wedge \text{end}_k = 0 \end{cases} \]

最后统计 \(26^n - \sum_{i \in \text{AC}} f_{n, i}\)\(\text{AC}\) 为 AC 自动机所有节点的集合,\(n\) 为文本长度)即为答案。

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <cstdio>
#include <cstring>
#include <queue>

const int MAXN = 60;
const int MAXM = 100;
const int MAXS = 100;
const int MOD = 1e4 + 7;

int trie[MAXN * MAXS + 1][26], tot = 0;
int end[MAXN * MAXS + 1], fail[MAXN * MAXS + 1];
int f[MAXM + 1][MAXN * MAXS + 1];

void insert(char *s) {
int len = strlen(s), p = 0;
for (int i = 0; i < len; i++) {
if (!trie[p][s[i] - 'A']) trie[p][s[i] - 'A'] = ++tot;
p = trie[p][s[i] - 'A'];
}
end[p]++;
}

void build() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (trie[0][i]) {
q.push(trie[0][i]);
}
}
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = 0; i < 26; i++) {
if (trie[u][i]) {
fail[trie[u][i]] = trie[fail[u]][i];
end[trie[u][i]] += end[trie[fail[u]][i]];
q.push(trie[u][i]);
} else trie[u][i] = trie[fail[u]][i];
}
}
}

int pow(int a, int b) {
int ans = 1;
for (; b; b >>= 1, a = a * a % MOD) if (b & 1) ans = ans * a % MOD;
return ans;
}

int main() {
int n, m;
scanf("%d %d", &n, &m);

while (n--) {
static char s[MAXS + 1];
scanf("%s", s);
insert(s);
}
build();

f[0][0] = 1;
for (int i = 0; i < m; i++) {
for (int j = 0; j <= tot; j++) {
for (int k = 0; k < 26; k++) {
if (!end[trie[j][k]]) {
f[i + 1][trie[j][k]] = (f[i + 1][trie[j][k]] + f[i][j]) % MOD;
}
}
}
}

int ans = pow(26, m);
for (int i = 0; i <= tot; i++) ans = (ans - f[m][i] + MOD) % MOD;
printf("%d\n", ans);

return 0;
}