「SCOI2005」互不侵犯 - 状压 DP

题意描述

洛谷链接

LibreOJ 链接

\(N \times N\) 的棋盘里面放 \(K\) 个国王,使他们互不攻击,共有多少种摆放方案。国王能攻击到它上下左右,以及左上左下右上右下八个方向上附近的各一个格子,共 \(8\) 个格子。

解题思路

很明显这是一道状压 DP 的题目。我们可以定义数组 \(f_{i, j, l}\),状态表示中 \(i\) 表示行数,\(j\) 表示已经放下的国王数,\(l\) 表示当前这行的国王放置状态。其中 \(l\) 使用二进制表示子集存储,\(1\) 表示放置国王。整个 \(f\) 表示当前状态的可能数。

显然我们可以推出下列状态转移方程(\(\text{cnt}_x\) 表示 \(x\) 状态下当前行的国王数,\(\text{reach}\) 代表添加一行可达到 \(l\) 的情况的集合):

\[ f_{i, j + \text{cnt}_m, l} = \sum_{\text{reach}} f_{i - 1, j, m} \quad j \in [\text{cnt}_l, k - \text{cnt}_j] \]

对于 \(f\) 的初始化,我们仅需要做的是将 \(i = 1\) 的所有合法的序列赋值为 \(1\) 即可。

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <cstring>
#include <iostream>

const int MAXN = 9;

int n, k;
long long f[MAXN + 1][MAXN * MAXN + 1][1 << MAXN];

inline int counts(int x) {
int ans = 0;
for (int i = 0; (1 << i) <= x; i++) {
if ((x >> i) & 1) {
ans++;
}
}
return ans;
}

inline void init() {
memset(f, 0, sizeof(f));

for (int i = 0; i < (1 << n); i++) {
bool con = true;
for (int each = 1; each < n - 1; each++) {
if (((i >> each) & 1) && (((i >> (each - 1)) & 1) || ((i >> (each + 1)) & 1))) {
con = false;
break;
}
}

int cnt = counts(i);
if (con && cnt <= k) f[1][cnt][i] = 1;
}
}

inline void dp() {
init();

for (int i = 2; i <= n; i++) {
for (int j = 0; j < (1 << n); j++) {
bool con = true;
for (int each = 1; each < n - 1; each++) {
if (((j >> each) & 1) & (((j >> (each - 1)) & 1) || ((j >> (each + 1)) & 1))) {
con = false;
break;
}
}

int cntJ = counts(j);
if (con && cntJ <= k) {
for (int l = 0; l < (1 << n); l++) {
con = true;
if ((j & 1) && ((l & 1) || ((l >> 1) & 1))) continue;
if (((j >> (n - 1)) & 1) && (((l >> (n - 1)) & 1) || ((l >> (n - 2)) & 1))) continue;
for (int each = 1; each < n - 1; each++) {
if (((j >> each) & 1) && (((l >> each) & 1) || ((l >> (each - 1)) & 1) || ((l >> (each + 1)) & 1))) {
con = false;
break;
}
}

int cntL = counts(l);
if (con && cntJ + cntL <= k) {
for (int m = cntL; m <= k - cntJ; m++) {
f[i][m + cntJ][j] += f[i - 1][m][l];
}
}
}
}
}
}
}

inline long long getAns() {
long long ans = 0;

for (int i = 0; i < (1 << n); i++) {
ans += f[n][k][i];
}

return ans;
}

int main() {

std::cin >> n >> k;

dp();

std::cout << getAns() << std::endl;
return 0;
}