「洛谷 P6218」Round Numbers - 数位 DP

题意描述

洛谷链接

如果一个正整数的二进制表示中,\(0\) 的数目不小于 \(1\) 的数目,那么它就被称为「圆数」。

例如,\(9\) 的二进制表示为 \(1001\),其中有 \(2\)\(0\)\(2\)\(1\)。因此,\(9\) 是一个「圆数」。

请你计算,区间 \([l,r]\) 中有多少个「圆数」。

数据范围:\(1\le l,r\le 2\times 10^9\)

解题思路

本题显然数位 DP。我们设 \(a_n\) 为小于等于 \(n\) 中的圆数的个数,则答案为 \(a_r - a_{l - 1}\)。其中 \(a_n\) 可用数位 DP 解决。

对于 \(a_i\),我们首先将 \(n\) 使用二进制表示。设 \(n\)\(\text{len}\) 位,于是圆数至少要有 \(\text{need} = \lfloor \frac{\text{len} + 1}{2} \rfloor\)\(0\)。设 \(f_{i, \text{last}, \text{num}}\) 为剩下 \(i\) 位时,前 \(\text{len} - \text{i} - 1\) 位有 \(\text{last}\)\(0\),第 \(\text{len} - \text{i}\) 位为 \(\text{num}\),且小于剩余 \(i\) 位均为 \(0\) 的数时的圆数个数。显然我们可得到下列状态转移方程:

\[ f_{i, \text{last}, \text{num}} = \begin{cases} \begin{align} & \sum_{j = 1}^{\text{len - 1}}\sum_{k = \lfloor \frac{i + 1}{2} \rfloor}^{j - 1}{j - 1 \choose k} & i = \text{len} - 1 \wedge \text{num} = 1 \\ & \sum_{j = \text{need} - \text{last} - 1}^{i}{i \choose j} & i \ne \text{len} - 1 \wedge \text{num} = 1 \\ & 0 & \text{num} = 0 \end{align} \end{cases} \]

  • 公式 \((1)\) 表示遍历第 \(1\) 位的情况,由于这种情况不可含前导 \(0\) 所以要特殊处理;
  • 公式 \((2)\) 表示遍历到后面位数时该位为 \(1\) 的情况,此时 \(f_{i, \text{last}, \text{num}}\) 则计算当小于当前数的情况,即该位为 \(0\)、后面位数任意时圆数的个数;
  • 公式 \((3)\) 表示遍历到后面位数时该位为 \(0\) 的情况,由于不存在比 \(0\) 小的自然数,故 \(f_{i, \text{last}, \text{num}} = 0\)

\(n\) 的前 \(\text{len} - \text{i}\) 位的 \(0\) 的个数为 \(\text{last}_i\)、第 \(\text{len} - \text{i}\) 位为 \(\text{num}_i\)。于是我们可以预处理组合数,计算出答案 \(a_n = (\sum_{i = 1}^{\text{len}}\sum_{j = 0}^{\text{num}_i}{f_{i, \text{last}_i, j}}) + [n \in \text{圆数}]\)

代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <cstdio>
#include <vector>

const int MAXN = 31;

int C[MAXN + 1][MAXN + 1];

inline void prepare() {
for (int i = 0; i <= MAXN; i++) {
for (int j = 0; j <= i; j++) {
if (j == 0) C[i][j] = 1;
else C[i][j] = C[i - 1][j] + C[i- 1][j - 1];
}
}
}

int solve(int n) {
if (n == 0) return 1;

std::vector<int> nums;
for (; n; n >>= 1) nums.push_back(n & 1);

int len = nums.size(), need = (len + 1) / 2;
int res = 0;
int last = 0;

for (int i = len - 1; i >= 0; i--) {
int x = nums[i];

if (x) {
if (i == len - 1) {
res++;
for (int j = 1; j <= i; j++) {
for (int k = (j + 1) / 2; k <= j - 1; k++) {
res += C[j - 1][k];
}
}
} else {
for (int j = need - last - 1; j <= i; j++) {
res += C[i][j];
}
}
} else last++;

if (i == 0 && last >= need) res++;
}

return res;
}

int main() {
prepare();

int l, r;

scanf("%d %d", &l, &r);

printf("%d\n", solve(r) - solve(l - 1));

return 0;
}