CodeForces - 1721E Prefix Function Queries(KMP优化)

cf中碰到了kmp的一种优化方式,特此记录。

思路

我们知道,kmp是一种在线算法。即便我们不断在字符串尾部添加字符,我们依然能保持O(n)的均摊复杂度。然而,该题在给予字符串s后,要求对q个字符串ti分别求next数组。这种时候我们的均摊分析就失效了。因为对于每一个新的t,我们可能都需要跳n步来求fail数组。这是不能接受的。为此引出了一个新的概念前缀函数自动机.
首先我们来看kmp的代码

1
2
3
4
5
6
7
8
9
10
11
12
// C++ Version By OI wiki
vector<int> prefix_function(string s) {
int n = (int)s.length();
vector<int> pi(n);
for (int i = 1; i < n; i++) {
int j = pi[i - 1];
while (j > 0 && s[i] != s[j]) j = pi[j - 1];//可以看到,如果我们之前进行了类似while回朔的话,我们在未来的某个位置i',可能还会跳到当前的位置i,然后再走一遍这个while循环。这是非常浪费时间的。所以这里是优化的核心,可以考虑通过dp的方式优化。
if (s[i] == s[j]) j++;
pi[i] = j;
}
return pi;
}

一个显然的想法,那么我们能直接使用一个一维的jump数组,记录i位置我们最终跳到了哪里。这样未来while回朔到i位置的时候,如果发现jump数组已经访问过就直接更新到jump(i),这样可以达到真正的O(1)更新。这个思路的优化方向是正确的,但是我们考虑这样一种情形:假定fail指针指向的是某一个a,子串形式为…ab…。当我们fail指针指向a,且子串形式也为…ab…时,我们能沿用此处的jump函数。但是如果是…ac…呢?显然此时jump函数就失效了。因此,jump数组的形式应该是jump(i, j),代表跳到位置i,待匹配字符为j时,最终移动的位置。

对于OI Wiki的O(n^2)优化到O(n)我觉得是比较好理解的。简而言之就是使用dp的思想,避免重复跳跃。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// OI Wiki O(n^2)
void compute_automaton(string s, vector<vector<int>>& aut) {
s += '#';
int n = s.size();
vector<int> pi = prefix_function(s);
aut.assign(n, vector<int>(26));
for (int i = 0; i < n; i++) {
for (int c = 0; c < 26; c++) {
int j = i;
while (j > 0 && 'a' + c != s[j]) j = pi[j - 1];
if ('a' + c == s[j]) j++;
aut[i][c] = j;
}
}
}

// OI Wiki O(n)
void compute_automaton(string s, vector<vector<int>>& aut) {
s += '#';
int n = s.size();
vector<int> pi = prefix_function(s);
aut.assign(n, vector<int>(26));
for (int i = 0; i < n; i++) {
for (int c = 0; c < 26; c++) {
if (i > 0 && 'a' + c != s[i])
aut[i][c] = aut[pi[i - 1]][c];
else
aut[i][c] = i + ('a' + c == s[i]);
}
}
}

这里在参考jiangly代码的时候,我发现他跳过了prefix_function这一步,直接在同一个循环中求jump和fail数组。核心代码在于
1
2
3
4
if(i&&str[i] == 'a'+j) {
jump[i][j] = i+1;
fail[i] = jump[fail[i-1]][j];
}

原理是fail(i-1)一定是一个最长公共前后缀,这样我们只要看fail(i-1)落到的位置(注意这里不同kmp的fail数组可能含义不一样。比如abcda,我实现的fail最后一个a的值是1,所以你能发现jump[fail[i-1]][j]正好就是在b上做比较。)这样就能方便地在一个循环中算出fail了。

最后,我们以一个表格作收尾,方便你更好地理解jump和fail的计算过程。

a b a c a b
jumpa 1 1 3 1 5 1
jumpb 0 2 0 2 0 6
jumpc 0 0 0 4 0 0
fail 0 0 1 0 1 2

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

const int MAXN = 1e6+15;
int fail[MAXN] = {0}, jump[MAXN][26] = {{0}};

void solve() {
string str;
cin>>str;
int n = str.length();
fail[0] = 0;
jump[0][str[0]-'a'] = 1;
for(int i = 1; i < n; i++) {
for(int j = 0; j < 26; j++) {
if(i&&str[i] == 'a'+j) {
jump[i][j] = i+1;
fail[i] = jump[fail[i-1]][j];
} else {
jump[i][j] = jump[fail[i-1]][j];
}
}
}
int t;
cin>>t;
while(t--) {
string q;
cin>>q;
int m = q.length();
for(int i = n; i < n+m; i++) {
for(int j = 0; j < 26; j++) {
if(q[i-n]=='a'+j) {
jump[i][j] = i+1;
fail[i] = jump[fail[i-1]][j];
} else {
jump[i][j] = jump[fail[i-1]][j];
}
}
cout<<fail[i]<<" \n"[i==n+m-1];
}
}
}

int main(){
#ifndef ONLINE_JUDGE
freopen("./input.txt","r",stdin);
freopen("./output.txt","w",stdout);
#endif // ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
// int T;
// cin>>T;
// while(T--) {
solve();
// }
return 0;
}