BZOJ 1009 - 「HNOI2008」GT考试

题目链接:BZOJ 1009

设准考证号为目标串\(S\),不吉利的数字串为模式串\(T\)。 考虑对\(S\)\(T\)进行模式匹配,\(f(i,j)\)表示匹配到目标串的第\(i\)位和模式串的第\(j\)位,并且之前没有出现过完全匹配的子串,这种情况下目标串前\(i\)位的组成方案数量。那么显然答案为\(\sum_{i=0}^{m-1} f(n, i)\)。 这时候有两种情况:

  • \(s_i = t_j\)\(f(i, j) = f(i - 1, j - 1)\)
  • \(s_i \neq t_j\), \(f(i, j) = \sum f(i - 1,k)[{\rm fail}(k) = j - 1]\)

所以\(f(i,j)\)要么从\(f(i-1,j-1)\)转移,要么从能跳转到\(f(i-1,j-1)\)\(f(i-1,k)\)转移(基于\(j-1\)结尾的后缀和\(i-1\)为结尾的后缀已经匹配完了)。

考虑\(n \leq 10^9\),这个式子要继续化简。我们设一个新的函数\(g(u, v)\),表示在\(u\)位置的所有可能取值中,\(u\)能转移到\(v\)的方案数,显然\(g(u,v)\)由两部分构成:一个是\(v=u + 1\)并且\(v\)匹配成功,另一个是\({\rm fail}(v) = u - 1\)。那么我们可以得到:

\[f(i, j) = \sum_{k=0}^{m-1} f(i - 1, k) \times g(k, j)\]

\(k\)的范围\([0,m-1]\)表示合法串中不能由完整的模式串匹配,求和部分就是统计转移的贡献。

此时我们可以把上式转化为 \[ \begin{aligned} & \begin{bmatrix} f(i-1, 0) & f(i-1,1) & \cdots & f(i-1, m-1) \\ f(i-1, 0) & f(i-1,1) & \cdots & f(i-1, m-1) \\ \vdots & \vdots & \ddots & \vdots \\ f(i-1, 0) & f(i-1,1) & \cdots & f(i-1, m-1)    \end{bmatrix} \cdot g \\ = & \begin{bmatrix} f(i, 0) & f(i,1) & \cdots & f(i, m-1) \\ f(i, 0) & f(i,1) & \cdots & f(i, m-1) \\ \vdots & \vdots & \ddots & \vdots \\ f(i, 0) & f(i,1) & \cdots & f(i, m-1)    \end{bmatrix} \end{aligned}   \]

因此可以使用矩阵快速幂进行优化,渐进时间复杂度\(O(m^3 \log n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAX_N = 20 + 5;
int n, m, k, fail[MAX_N];
char s[MAX_N];
struct Matrix {
int _[MAX_N][MAX_N];
Matrix() {
memset(_, 0, sizeof _);
}
friend Matrix operator * (const Matrix &u, const Matrix &v) {
Matrix ret;
for (int i = 0; i < m; ++i) {
for (int j = 0; j < m; ++j) {
for (int t = 0; t < m; ++t) {
ret._[i][j] += u._[i][t] * v._[t][j];
ret._[i][j] %= k;
}
}
}
return ret;
}
} g, f;
Matrix getPow(Matrix x, int y) {
Matrix ret;
for (int i = 0; i <= m; ++i) ret._[i][i] = 1;
while (y) {
if (y & 1) ret = ret * x;
x = x * x;
y >>= 1;
}
return ret;
}
int main() {
scanf("%d%d%d", &n, &m, &k);
scanf("%s", s + 1);
for (int i = 2, j = 0; i <= m; ++i) {
while (j && s[i] != s[j + 1]) j = fail[j];
if (s[i] == s[j + 1]) ++j;
fail[i] = j;
}
// for (int i = 1; i <= m; ++i) printf("fail[%d] = %d\n", i, fail[i]);
for (int i = 0; i < m; ++i) {
for (char c = '0'; c <= '9'; ++c) {
int j = i;
while (j && c != s[j + 1]) j = fail[j];
if (s[j + 1] == c) ++j;
g._[i][j] = (g._[i][j] + 1) % k;
}
}
f._[0][0] = 1;
f = f * getPow(g, n);
int ans = 0;
for (int i = 0; i < m; ++i) ans = (ans + f._[0][i]) % k;
printf("%d\n", ans);
return 0;
}