0%

疫情还没结束,但是学校教学工作已经陆陆续续在线展开了。于是我们网络编程老师还没上课就先扔了一道题,要我们写个带Lexer和Parser的算术解释器,实现变量和赋值的功能。还没学过编译原理的我就这样被迫营业了。

花了四天的时间,看了很多集哈工大的编译原理的MOOC(讲得还不错哦),加上写代码调BUG,基本完成了。效果大概是这样的:

写篇博文来记录一下被迫营业的心路历程。

项目代码 github: zhaoyw1999/arithmetic-interpreter

首先我们要知道写一个解释器的大致思路是对输入的代码做词法分析和语法分析,对应的是Lexer模块和Parser模块。

在Lexer部分,我们需要把一串长长的字符串分成一个个Token,作为下面Parser分析的最小单位。这个解释器语法所需要的Token一共有9种,分别是加减乘除、赋值、变量名、数字和左右括号,可以用枚举类型TokenType表示(在这里我还加了一个结束符END_FLAG,但是后来没有用到),这里是代码。

1
2
3
4
5
6
7
8
9
10
11
12
enum TokenType {
IDENTIFIER,
NUMBER,
L_BRACKET,
R_BRACKET,
ASSIGNMENT,
ADD,
SUB,
MUL,
DIV,
END_FLAG
};

对于每一个Token,我们需要记录下这个Token的类型和它的字面量,其实最重要的是数字和变量名的字面量,因为我们在后面计算的时候,会后序遍历语法树,就需要用到叶子结点的字面量。在设计Token类型的时候,设计一个to_string方法,后续输出Token。

1
2
3
4
5
6
7
8
9
10
struct Token {
TokenType token_type;
TokenValue token_value;

Token();
Token(TokenType token_type_, TokenValue token_value_);
~Token();

string to_string();
};

下面我们需要设计一个分词器,输入一个字符串,输出一个Token流。在设计Token流的时候,我们可以使用vector或者deque。这一点是我在整个解释器实现的差不多的时候才想到的,一开始设计的时候用了vector,为什么后来想到deque呢——因为在后续写Parser的过程中,是从左到右推导的,正好deque可以模拟出每次从流中取一个元素,不用自己设计一个下标计数器。

在分词器的设计上,有两种方法。一种是直接分词,适用于Token类型比较少、形式比较简单的情况;一种是设计状态机分词,适用于Token种类多、形式复杂的情况。这里我们使用前一种方法。

分词器的设计如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
using VectorStream = vector <Token>;
using DequeStream = deque <Token>;

class TokenStream {
private:
string origin_stream;
VectorStream stream_vector;
DequeStream stream_deque;

void append_token(Token to_add);
bool is_alpha_underline(char c);
bool is_alpha(char c);
bool is_digit(char c);
bool is_alpha_underline_digit(char c);

public:
TokenStream();
TokenStream(string origin_stream_);
~TokenStream();

void set_origin_stream(string origin_stream_);
void tokenize();
VectorStream get_tokenized_vector();
DequeStream get_tokenized_deque();
};

因为我原本是基于vector设计的,在使用待分词的字符串初始化TokenStream之后,调用它的tokenize方法实现分词,装入一个VectorStream中,再调用用get_tokenized_vector就可以获得这个分词之后的结果。

下面到了Parser的部分。我们得到了Token的序列之后,要把Token的序列解析成一棵语法树。我们首先要列出表达式的巴斯科范式:

\[ \begin{aligned} e &: {\rm expression} \\\ t &: {\rm term} \\\ f &: {\rm factor} \\\ e & \rightarrow e + t | e - t | t \\\ t & \rightarrow t * f | t / f | f \\\ f & \rightarrow x | (e) \\\ x & \rightarrow {\rm number} | {\rm id} \\\ \end{aligned} \]

消除左递归之后:

\[ \begin{aligned} e & \rightarrow t | t e' \\\ e' & \rightarrow + t e' | - t e' | \epsilon \\\ t & \rightarrow f | f t' \\\ t' & \rightarrow * f t' | / f t' | \epsilon \\\ f & \rightarrow x | (e) \\\ x & \rightarrow {\rm number} | {\rm id} \\\ \end{aligned} \]

对于变量赋值的功能,我们可以用\({\rm id} = e\)来表示,在解析的时候,先取掉\({\rm id}=\),解析右边的\(e\),再把变量名到值的映射存在一个map里即可。这样可以直接将赋值语句的解析转到前面我们推导的表达式解析。

我们需要建立一棵语法树,我们刚刚的所有运算均为二元运算,所以语法树就表现为一个二叉树,括号不建入语法树内,那么我们建立的语法树一定只有出度为0和出度为2的节点。这是节点的设计。

1
2
3
4
5
6
7
8
9
10
11
12
13
struct SyntaxTreeNode {
Token token;
SyntaxTreeNode *l_son;
SyntaxTreeNode *r_son;

SyntaxTreeNode();
SyntaxTreeNode(
Token token_,
SyntaxTreeNode *l_son_,
SyntaxTreeNode *r_son_
);
~SyntaxTreeNode();
};

在计算表达式的值的时候我后序遍历这颗二叉树即可。这里是Parser类的设计,其中中序遍历和先序遍历是用来调试使用的,其实应该标记为private。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
using std::map;
using Integer = long long;
using VariableMap = map <string, Integer>;

class Parser {
private:
SyntaxTreeNode *syntax_tree_root;
VectorStream token_stream;
VariableMap variable_map;
int current_position;

SyntaxTreeNode *parse_expression();
// SyntaxTreeNode *parse_as_expression();
SyntaxTreeNode *parse_term();
// SyntaxTreeNode *parse_md_term();
SyntaxTreeNode *parse_factor();
SyntaxTreeNode *parse_number();
SyntaxTreeNode *parse_identifier();

bool is_assignment();
void dfs_destroy(SyntaxTreeNode *root);
Integer dfs_calculate(SyntaxTreeNode *root);

public:
Parser();
Parser(VectorStream token_stream_);
~Parser();

void parse();

void get_preorder_traversal(
VectorStream &tar,
SyntaxTreeNode *root
);
void get_preorder_traversal(VectorStream &tar);

void get_inorder_traversal(
VectorStream &tar,
SyntaxTreeNode *root
);

void get_inorder_traversal(VectorStream &tar);

std::pair <bool, Integer> calculate(VectorStream token_stream_);
};

要注意的是我们每次运行这个程序都只建立了一个Parser的实例,每次执行calculate的时候都需要先晴空万里语法树、再执行parse,但是variable_map是不需要每次清空的,因为后续的表达式可能会查询变量的值。

除了这些主要的模块之外,还需要设计一些异常处理,用来处理例如括号不匹配、使用了未定义的变量之类的异常状况。

回过头来看这四天写的代码,我觉得在设计字面值到Integer转换的时候还可以改进,因为一个好的设计,可以吧目前的Integer换成BigInteger之后依然不需要修改原来的模块。正好最近在学习Design Pattern,讲到了设计软件要应对变化,找到变与不变的分界点,这一点需要在实践中慢慢感悟。当这个语言系统变得更加复杂之后,在做词法分析的时候我们就需要用到状态机,语法分析的时候我们要考虑到底使用什么样的分析方法,希望在这学期学完编译原理之后能写出一个更复杂的解释器。

题目链接:LeetCode 6 - Z 字形变换

我们可以把 numRows + numRows - 2 个字符分为一组,记这个组的大小为 groupSize。然后通过分类讨论计算出当前这个元素变换后的行号和列号,然后做双关键字排序即可。

渐进时间复杂度 \(O(n \log n)\),其中 \(n\)s 的长度。

注意特别判断numRows = 1的情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
struct Coordinate {
char c;
int x;
int y;

bool operator < (const Coordinate &rhs) const {
return (x == rhs.x) ? (y < rhs.y) : (x < rhs.x);
}
};

class Solution {
private:
vector <Coordinate> v;

int getRId(int x, int groupSize, int numRows) {
x %= groupSize;
if (x < numRows) {
return x;
} else {
return (numRows - 1) - (x - numRows + 1);
}
}

int getCId(int x, int groupSize, int numRows) {
int ans = x / groupSize * (numRows - 1);
x %= groupSize;
if (x < numRows) {
return ans;
} else {
return ans + x - numRows + 1;
}
}

public:
string convert(string s, int numRows) {
if (numRows == 1) return s;
int groupSize = numRows * 2 - 2;
for (int i = 0; i < s.size(); ++i) {
int rId = getRId(i, groupSize, numRows);
int cId = getCId(i, groupSize, numRows);
v.push_back({s[i], rId, cId});
}
sort(v.begin(), v.end());
string ret = "";
for (auto e: v) {
ret.push_back(e.c);
}
return ret;
}
};

题目链接:PAT - Advanced Level - 1100

给定一个 \(r\) 进制的数字 \(x\),和一个未知进制的数字 \(y\),求是否存在一个 \(t\),使得 \(y\)\(t\) 进制下等于 \(r\) 进制的 \(x\)

题解:二分答案,渐进时间复杂度 \(O( \max \{ |s_x|, |s_y| \} \log n)\)

坑点:原本我以为 \(t_{\max} = 36\),其实可能更大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
w = {}
def dictInit():
for i in range(0, 10):
w[str(i)] = i
for i in range(10, 36):
w[chr(ord('a') + i - 10)] = i
return w

def cal(ori, h):
r = 0
for c in ori:
r = r * h + w[c]
return r

def binarySearch(l, r, ori, tar):
while l < r:
mid = (l + r) >> 1
tmp = cal(ori, mid)
if tmp < tar:
l = mid + 1
else:
r = mid
return l if cal(ori, l) == tar else 'Impossible'

def main():
x, y, tag, r = map(str, input().split())
r = int(r)
if tag != '1':
x, y = y, x
dictInit()

value = 0
for c in x:
value = value * r + w[c]

minAns = 0
for c in y:
minAns = max(minAns, w[c] + 1)

print(binarySearch(minAns, int(1E18), y, value))

if __name__ == "__main__":
main()

题目链接:PAT – Advanced Level – 1147

直接模拟即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAX_N = 1000 + 5;
int m, n;
int a[MAX_N];
bool maxHeap, minHeap;
inline int ls(int o) {
return o << 1;
}
inline int rs(int o) {
return o << 1 | 1;
}
void judgeMaxHeap(int o) {
if (ls(o) <= n && a[ls(o)] > a[o]) return (maxHeap = false), void();
if (rs(o) <= n && a[rs(o)] > a[o]) return (maxHeap = false), void();
if (ls(o) <= n) judgeMaxHeap(ls(o));
if (rs(o) <= n) judgeMaxHeap(rs(o));
}
void judgeMinHeap(int o) {
if (ls(o) <= n && a[ls(o)] < a[o]) return (minHeap = false), void();
if (rs(o) <= n && a[rs(o)] < a[o]) return (minHeap = false), void();
if (ls(o) <= n) judgeMinHeap(ls(o));
if (rs(o) <= n) judgeMinHeap(rs(o));
}
void dfs(int o, vector <int> &tar) {
if (ls(o) <= n) dfs(ls(o), tar);
if (rs(o) <= n) dfs(rs(o), tar);
tar.push_back(a[o]);
}
int main() {
scanf("%d%d", &m, &n);
for (int cs = 1; cs <= m; ++cs) {
for (int i = 1; i <= n; ++i) {
scanf("%d", a + i);
}
maxHeap = minHeap = true;
judgeMaxHeap(1);
judgeMinHeap(1);
if (maxHeap) puts("Max Heap");
else if (minHeap) puts("Min Heap");
else puts("Not Heap");
vector <int> v;
dfs(1, v);
for (auto i = 0; i < v.size(); ++i) {
printf("%d%c", v[i], i == v.size() - 1 ? '\n' : ' ');
}
}
return 0;
}

题目链接:PAT – Advanced Level – 1151

通过先序和中序序列建树,然后倍增求LCA,渐进时间复杂度 \(O((n+m) \log n)\)

坑点:要离散化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 10000 + 5;
int n, m;
vector <int> g[MAX_N];
int iOrd[MAX_N], pOrd[MAX_N];
int f[MAX_N][20], d[MAX_N];
int lg[MAX_N];
void build(int iL, int iR, int pL, int pR) {
int pos = -1, base = pOrd[pL];
for (int i = iL; i <= iR; ++i) {
if (iOrd[i] == base) {
pos = i;
break;
}
}
int iLeftL = iL, iLeftR = pos - 1, iRightL = pos + 1, iRightR = iR;
int
pLeftL = pL + 1,
pLeftR = pL + iLeftR - iLeftL + 1,
pRightL = pL + iLeftR - iLeftL + 2,
pRightR = pR;
if (iLeftL <= iLeftR) {
g[base].push_back(pOrd[pLeftL]);
build(iLeftL, iLeftR, pLeftL, pLeftR);
}
if (iRightL <= iRightR) {
g[base].push_back(pOrd[pRightL]);
build(iRightL, iRightR, pRightL, pRightR);
}
}
set <int> nodeSet;
inline bool testError(int u, int v) {
auto ok = [] (const int &x) -> bool { return nodeSet.find(x) != nodeSet.end(); };
if (!ok(u) && ok(v)) {
printf("ERROR: %d is not found.\n", u);
} else if (ok(u) && !ok(v)) {
printf("ERROR: %d is not found.\n", v);
} else if (!ok(u) && !ok(v)) {
printf("ERROR: %d and %d are not found.\n", u, v);
} else {
return false;
}
return true;
}
void dfs(int x, int fa) {
f[x][0] = fa; d[x] = d[fa] + 1;
for (int i = 1; i <= lg[d[x]]; ++i) {
f[x][i] = f[f[x][i - 1]][i - 1];
}
for (auto &e: g[x]) {
if (e != fa) dfs(e, x);
}
}
int getLCA(int u, int v) {
if (d[u] < d[v]) swap(u, v);
while (d[u] > d[v]) u = f[u][lg[d[u] - d[v]] - 1];
if (u == v) return u;
for (int k = lg[d[u]] - 1; k >= 0; --k) {
if (f[u][k] != f[v][k]) {
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
vector <int> disc;
inline int getId(int x) {
return lower_bound(disc.begin(), disc.end(), x) - disc.begin() + 1;
}
int main() {
scanf("%d%d", &m, &n);
for (int i = 1; i <= n; ++i) scanf("%d", iOrd + i), disc.push_back(iOrd[i]), nodeSet.insert(iOrd[i]);
for (int i = 1; i <= n; ++i) scanf("%d", pOrd + i);
sort(disc.begin(), disc.end());
for (int i = 1; i <= n; ++i) iOrd[i] = getId(iOrd[i]);
for (int i = 1; i <= n; ++i) pOrd[i] = getId(pOrd[i]);
build(1, n, 1, n);
for (int i = 1; i <= n; ++i) lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
dfs(pOrd[1], 0);
for (int i = 1; i <= m; ++i) {
int u, v;
scanf("%d%d", &u, &v);
int ru = getId(u), rv = getId(v);
if (testError(u, v)) continue;
int queryLCA = disc[getLCA(ru, rv) - 1];
if (queryLCA == u) printf("%d is an ancestor of %d.\n", u, v);
else if (queryLCA == v) printf("%d is an ancestor of %d.\n", v, u);
else printf("LCA of %d and %d is %d.\n", u, v, queryLCA);
}
return 0;
}

题目链接:PAT – Advanced Level – 1155

DFS。

吐槽:同样是30分的题差别怎么这么大啊?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAX_N = 1000 + 5;
int n, a[MAX_N];
vector < vector <int> > path;
vector <int> stk;
void dfs(int x) {
stk.push_back(a[x]);
if (x * 2 + 1 <= n) dfs(x * 2 + 1);
if (x * 2 <= n) dfs(x * 2);
else path.push_back(stk);
stk.pop_back();
}
bool minHeap = true, maxHeap = true;
void check() {
for (auto &v: path) {
for (unsigned i = 1; i < v.size(); ++i) {
if (v[i] > v[i - 1]) maxHeap = false;
if (v[i] < v[i - 1]) minHeap = false;
}
for (int i = 0; i < v.size(); ++i) {
printf("%d%c", v[i], i == v.size() - 1 ? '\n' : ' ');
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
dfs(1);
check();
if (!minHeap && !maxHeap) {
puts("Not Heap");
} else if (minHeap) {
puts("Min Heap");
} else if (maxHeap) {
puts("Max Heap");
}
return 0;
}

题目链接:Gym 102361F

考虑每个简单环内至少取一条边,如果一个简单环有 \(x\) 条边,那么这个简单环的方案数为 \(\sum_{i=1}^{x} {x \choose i} = 2^x - 1\),对于非简单环上的边(即剩余的边)如果有 \(r\) 条,那么方案数为 \(\sum_{i=1}^{r} {r \choose i} = 2^r\),最后分步计数即可。

需要注意的是,这里没有保证整张图是联通的,只保证每个联通块是一个仙人掌,所以不能只DFS一个点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int P = 998244353;
inline LL getPow(LL x, LL y) {
LL ret = 1;
while (y) {
if (y & 1) ret = ret * x % P;
x = x * x % P;
y >>= 1;
}
return ret;
}
const int MAX_N = 300000 + 5;
int n, m;
vector <int> g[MAX_N];
vector <int> cir;
stack <int> stk;
int pos[MAX_N];
int ins[MAX_N];
int vis[MAX_N];
int eCnt;
void dfs(int x, int fa) {
// printf("call dfs(x = %d, fa = %d)\n", x, fa);
stk.push(x);
ins[x] = 1;
pos[x] = stk.size();
for (auto const &e: g[x]) {
if (e == fa) {
continue;
} else if (ins[e]) {
cir.push_back(pos[x] - pos[e] + 1);
eCnt += (pos[x] - pos[e] + 1);
} else if (!vis[e]) {
dfs(e, x);
}
}
stk.pop();
ins[x] = 0;
vis[x] = 1;
}
int main() {
scanf("%d%d", &n, &m);
// build graph
int u, v;
for (int i = 1; i <= m; ++i) {
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
// solve
// dfs(1, -1);
for (int i = 1; i <= n; ++i) if (!vis[i]) dfs(i, -1);
LL ans = 1;
for (unsigned i = 0; i < cir.size(); ++i) {
LL con = (((getPow(2, cir[i]) - 1) % P) + P) % P;
ans = ans * con % P;
}
ans = ans * getPow(2, m - eCnt) % P;
printf("%lld\n", ans);
return 0;
}

上次写实验报告的时候从网上剽来一份代码高亮的配置,其中有一句 escapeinside=**,当时不知道是什么意思。这次写 Linux 实验的时候同样使用了这份高亮配置,结果吞掉了一些符号,记录一下这个坑。

escapeinside 的意思是添加注释暂时离开 listings 的环境,也就是说编译之后不会在代码块中出现,例如 escapeinside={(*}{*)} 就表示这个「暂时离开」的部分以 (* 开头,以 *) 结尾:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
\begin{lstlisting}[language={[x86masm]Assembler}]
(* 这里是\LaTeX *)
; function: to compare information
CMP_INFO PROC
MOV AL, U_BUF+1
CMP AL, U_LEN
JNE S_NEQ
MOV AL, P_CNT
CMP AL, P_LEN
JNE S_NEQ
MOV SI, OFFSET U_BUF+2
MOV DI, OFFSET U_ORI
CLD
MOV CL, U_LEN
MOV CH, 0
REPE CMPSB
JNE S_NEQ
MOV SI, OFFSET P_BUF
MOV DI, OFFSET P_ORI
CLD
MOV CL, P_LEN
MOV CH, 0
REPE CMPSB
JNE S_NEQ
MOV DX, OFFSET SUC
CALL PRT_STR
RET
S_NEQ:
MOV DX, OFFSET REJ
CALL PRT_STR
RET
CMP_INFO ENDP
\end{lstlisting}

这是输出结果是这样的:

可以看到被 (**) 包裹的内容是以 LaTeX 的形式编译的。

要求:输入用户名(回显)和密码(不回显),判断与预设的用户名和密码是否相同,如果相同则登陆成功,否则登陆失败。

踩坑:程序初始化的时候需要将 ES 的地址赋值给 ES 寄存器,才能正确使用 CMPSB 指令。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
DATAS SEGMENT
U_ORI DB 'root'
U_LEN DB $-U_ORI
P_ORI DB '123456'
P_LEN DB $-P_ORI
U_BUF DB 250 DUP(0)
P_BUF DB 250 DUP(0)
P_CNT DB 0
U_TIP DB 'username: $'
P_TIP DB 'password: $'
SUC DB 'Login successfully, welcome!', 0AH, '$'
REJ DB 'Infomation does not match, you are rejected!', 0AH, '$'
STAR DB '*$'
D_LOG DB 'DEBUG LOG', 0AH, '$'
DATAS ENDS
CODES SEGMENT
ASSUME CS: CODES, DS: DATAS, ES: DATAS
START:
MOV AX, DATAS
MOV DS, AX
MOV ES, AX
MOV DX, OFFSET U_TIP
CALL PRT_STR
CALL U_INP
CALL PRT_CR
MOV DX, OFFSET P_TIP
CALL PRT_STR
CALL P_INP
CALL PRT_CR
CALL CMP_INFO
MOV AH, 4CH
INT 21H
DEBUG_LOG:
MOV AH, 09H
MOV DX, OFFSET D_LOG
INT 21H
MOV AH, 4CH
INT 21H
; function: to print a string
PRT_STR PROC
MOV AH, 09H
INT 21H
RET
PRT_STR ENDP
; function: to print a new line
PRT_CR PROC
MOV AH, 02H
MOV DL, 0AH
INT 21H
RET
PRT_CR ENDP
; function: to input username
U_INP PROC
MOV AL, 200
MOV U_BUF, AL
MOV AH, 0AH
MOV DX, OFFSET U_BUF
INT 21H
RET
U_INP ENDP
; function: to input password
P_INP PROC
MOV BX, OFFSET P_BUF
CIRC:
MOV AH, 07H
INT 21H
CMP AL, 0DH
JZ NEXT
MOV [BX], AL
INC BX
INC P_CNT
MOV DX, OFFSET STAR
CALL PRT_STR
JMP CIRC
NEXT:
RET
P_INP ENDP
; function: to compare information
CMP_INFO PROC
MOV AL, U_BUF+1
CMP AL, U_LEN
JNE S_NEQ
MOV AL, P_CNT
CMP AL, P_LEN
JNE S_NEQ
MOV SI, OFFSET U_BUF+2
MOV DI, OFFSET U_ORI
CLD
MOV CL, U_LEN
MOV CH, 0
REPE CMPSB
JNE S_NEQ
MOV SI, OFFSET P_BUF
MOV DI, OFFSET P_ORI
CLD
MOV CL, P_LEN
MOV CH, 0
REPE CMPSB
JNE S_NEQ
MOV DX, OFFSET SUC
CALL PRT_STR
RET
S_NEQ:
MOV DX, OFFSET REJ
CALL PRT_STR
RET
CMP_INFO ENDP
CODES ENDS
END START

题目链接:Codeforces 1238D

反向考虑,要求好子串的数量,就用总字串的数量 \(n \choose 2\) 里减去坏子串的数量。\(n \choose 2\) 的原因是长度为 \(1\) 的可以不用考虑,它一定是不满足条件的,所以总量是长度大于等于2的所有子串的数量,要减去的是长度大于等于 \(2\) 的所有坏子串的数量。 考虑什么样的串是坏的,对于一个子串 \(t[1\cdots k]\),只要存在一个 \(t[i]\) 不包含于任何一个长度大于二的回文中,那么 \(t[1\cdots k]\) 就是坏的。而实际上,\(t[i]=t[i-1]\)\(t[i]=t[i+1]\) 的时候,显然是存在的,因为至少有一个长度为2的回文,如果 \(t[i] \neq t[i-1]\) 并且 \(t[i] \neq t[i+1]\),那么可以形成长度为 \(3\) 的回文。所以如果存在这样一个 \(t[i]\) 使得 \(t[1\cdots k]\) 是坏的,那么要么 \(i=1\),要么 \(i=k\)

如果我们从 \(t[2]\) 往后能找到一个(最前面一个)\(t[i]=t[1]\),那么说明 \(t[1]\) 一定包含在回文串里,对于 \(t[k]\) 类似。所以坏的串具有这样的形式:

  • ABB...B
  • BAA...A
  • A...AAB
  • B...BBA

所以 \(O(n)\) 统计这些串的个数就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAX_N = 300000 + 5;
LL ans = 0;
int n;
char s[MAX_N];
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
for (int i = 1; i < n; ++i) {
if (s[i] == s[i + 1]) continue;
int last = i;
for (int j = i + 1; j <= n && s[j] != s[i]; ++j) {
--ans;
last = j - 1;
}
i = last;
}
reverse(s + 1, s + n + 1);
for (int i = 1; i < n; ++i) {
if (s[i] == s[i + 1]) continue;
int last = i;
for (int j = i + 2; j <= n && s[j] != s[i]; ++j) {
--ans;
last = j - 1;
}
i = last;
}
ans += 1LL * n * (n - 1) / 2;
printf("%lld\n", ans);
return 0;
}