PAT甲级真题1151 - LCA in a Binary Tree

题目链接:PAT – Advanced Level – 1151

通过先序和中序序列建树,然后倍增求LCA,渐进时间复杂度 \(O((n+m) \log n)\)

坑点:要离散化。结点键值不保证落在 1..N 范围内(可能很大甚至为负),不能直接当数组下标使用;需先把键值排序并映射为 1..N 的编号,输出时再映射回原键值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 10000 + 5;
// m = number of query pairs, n = number of keys (both read in main()).
int n, m;
// Child lists of the rebuilt tree, indexed by (discretized) key.
vector <int> g[MAX_N];
// Inorder / preorder sequences, 1-based.
int iOrd[MAX_N], pOrd[MAX_N];
// Binary-lifting table: f[x][k] is the 2^k-th ancestor of x (0 above the root);
// d[x] is the depth of x, with the root at depth 1 (d[0] == 0).
int f[MAX_N][20], d[MAX_N];
// lg[i] = floor(log2(i)) + 1 (the bit length of i) for i >= 1, lg[0] = 0; filled in main().
int lg[MAX_N];
// inPos[key] = index of key inside iOrd. build() fills it once so every subtree
// root is located in O(1) instead of by a linear scan (which is O(n^2) on skewed trees).
int inPos[MAX_N];

// Links the subtree whose inorder sequence is iOrd[iL..iR] and preorder sequence is
// pOrd[pL..pR] into g. Requires inPos to be filled for the keys in iOrd[iL..iR].
static void buildSubtree(int iL, int iR, int pL, int pR) {
    const int root = pOrd[pL];
    const int pos = inPos[root];
    // Malformed input: the preorder root does not occur in this inorder range.
    if (pos < iL || pos > iR || iOrd[pos] != root) return;
    const int leftSize = pos - iL;
    if (leftSize > 0) {
        // Left subtree: inorder [iL, pos-1], preorder [pL+1, pL+leftSize].
        g[root].push_back(pOrd[pL + 1]);
        buildSubtree(iL, pos - 1, pL + 1, pL + leftSize);
    }
    if (pos < iR) {
        // Right subtree: inorder [pos+1, iR], preorder [pL+leftSize+1, pR].
        g[root].push_back(pOrd[pL + leftSize + 1]);
        buildSubtree(pos + 1, iR, pL + leftSize + 1, pR);
    }
}

// Rebuilds the binary tree from iOrd[iL..iR] (inorder) and pOrd[pL..pR] (preorder),
// appending each node's children (left first, then right) to g[node].
// Keys are used directly as array indices, so they must lie in [0, MAX_N).
// Runs in O(iR - iL + 1) time.
void build(int iL, int iR, int pL, int pR) {
    if (iL > iR) return;
    for (int i = iL; i <= iR; ++i) inPos[iOrd[i]] = i;
    buildSubtree(iL, iR, pL, pR);
}
// All original (non-discretized) keys present in the tree.
set<int> nodeSet;

// Prints the PAT error line when u and/or v is not a key of the tree.
// Returns true if an error line was printed (the query should be skipped),
// false when both keys exist.
inline bool testError(int u, int v) {
    const bool hasU = nodeSet.count(u) > 0;
    const bool hasV = nodeSet.count(v) > 0;
    if (hasU && hasV) return false;
    if (!hasU && !hasV) {
        printf("ERROR: %d and %d are not found.\n", u, v);
    } else {
        // Exactly one key is missing: report that one.
        printf("ERROR: %d is not found.\n", hasU ? v : u);
    }
    return true;
}
// Visits the subtree rooted at x (whose parent is fa, or 0 for the root).
// It records x's depth and fills x's binary-lifting row f[x][0..lg[d[x]]]
// before recursing into x's children, so every ancestor's row is already
// complete when it is read.
void dfs(int x, int fa) {
    d[x] = d[fa] + 1;
    f[x][0] = fa;
    const int top = lg[d[x]];
    for (int k = 1; k <= top; ++k) {
        // The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
        const int half = f[x][k - 1];
        f[x][k] = f[half][k - 1];
    }
    for (int child : g[x]) {
        if (child == fa) continue;
        dfs(child, x);
    }
}
// Returns the lowest common ancestor of u and v using the binary-lifting table
// filled by dfs(). Note that lg[x] - 1 == floor(log2(x)).
int getLCA(int u, int v) {
    // `deep` is the endpoint that is not shallower (on ties it stays u).
    int deep = u, shallow = v;
    if (d[deep] < d[shallow]) {
        deep = v;
        shallow = u;
    }
    // Phase 1: climb by the largest power of two not exceeding the depth gap.
    for (int gap = d[deep] - d[shallow]; gap > 0; gap = d[deep] - d[shallow]) {
        deep = f[deep][lg[gap] - 1];
    }
    if (deep == shallow) return deep;
    // Phase 2: lift both nodes together while their ancestors still differ.
    for (int k = lg[d[deep]] - 1; k >= 0; --k) {
        const int upDeep = f[deep][k];
        const int upShallow = f[shallow][k];
        if (upDeep == upShallow) continue;
        deep = upDeep;
        shallow = upShallow;
    }
    return f[deep][0];
}
// Sorted original keys; the discretized id of a key is its 1-based position here.
vector<int> disc;

// Returns 1 + the index of the first element of disc that is >= x
// (disc must be sorted). For a key present in disc this is its discretized id.
inline int getId(int x) {
    const auto firstNotLess = lower_bound(disc.begin(), disc.end(), x);
    return static_cast<int>(firstNotLess - disc.begin()) + 1;
}
// Reads M and N, the inorder and preorder sequences, and then M query pairs.
// For each pair it prints an error line, an ancestor line, or the LCA line.
int main() {
    scanf("%d%d", &m, &n);
    // Inorder keys: keep the raw values for discretization and membership checks.
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &iOrd[i]);
        disc.push_back(iOrd[i]);
        nodeSet.insert(iOrd[i]);
    }
    for (int i = 1; i <= n; ++i) scanf("%d", &pOrd[i]);
    // Replace every key by its 1-based rank so keys can index the fixed-size arrays.
    sort(disc.begin(), disc.end());
    for (int i = 1; i <= n; ++i) {
        iOrd[i] = getId(iOrd[i]);
        pOrd[i] = getId(pOrd[i]);
    }
    build(1, n, 1, n);
    // lg[i] becomes the bit length of i: it grows by one exactly at powers of two.
    for (int i = 1; i <= n; ++i) lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
    dfs(pOrd[1], 0);
    for (int q = 0; q < m; ++q) {
        int u, v;
        scanf("%d%d", &u, &v);
        if (testError(u, v)) continue;
        // Map the ids back to the original key of the ancestor.
        const int lca = disc[getLCA(getId(u), getId(v)) - 1];
        if (lca == u) {
            printf("%d is an ancestor of %d.\n", u, v);
        } else if (lca == v) {
            printf("%d is an ancestor of %d.\n", v, u);
        } else {
            printf("LCA of %d and %d is %d.\n", u, v, lca);
        }
    }
    return 0;
}