Gym 101981G - Pyramid

题目链接:Gym - 101981G

这道题是去年南京区预赛的题,现场没有推出公式,到现在仍然不知道公式怎么推,但是最近从其他的推公式的题中得到了一些启发,总结了一些打表找规律的方法。

首先我们可以先建立一个平面直角坐标系,假设边长为\(n\)的三角形中有\(d\)个点,我们把\(d\)个点暴力枚举出来,然后\(O({\rm C}_{n}^{3})\)枚举多少个组合能构成等边三角形。

打表程序如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAX_N = 1000 + 5;
int n, tot;
double x[MAX_N], y[MAX_N];
inline double dist(int u, int v) {
double dx = x[u] - x[v], dy = y[u] - y[v];
return sqrt(dx * dx + dy * dy);
}
inline bool equal(double u, double v) {
return fabs(u - v) < 0.01;
}
int main() {
scanf("%d", &n);
double nx = 0, ny = 0;
for (int i = 1; i <= n + 1; ++i) {
for (int j = 0; j < n + 2 - i; ++j) {
double ux = nx + j * 2 * sqrt(3);
++tot;
x[tot] = ux;
y[tot] = ny;
}
ny += 3; nx += sqrt(3);
}
for (int i = 1; i <= tot; ++i) {
printf("x = %.2f, y = %.2f\n", x[i], y[i]);
}
int cnt = 0;
for (int i = 1; i <= tot; ++i) {
for (int j = i + 1; j <= tot; ++j) {
for (int k = j + 1; k <= tot; ++k) {
double a = dist(i, j);
double b = dist(i, k);
double c = dist(j, k);
if (equal(a, b) && equal(b, c) && equal(c, a)) ++cnt;
}
}
}
printf("%d\n", cnt);
return 0;
}

然后我们可以根据前几项的结果进行猜想,可以往三个方向考虑:

  • 多次差分之后为等差数列的,通项公式为多项式
  • 多次差分之后增长规律几乎不变并为指数级增长的,可能是齐次线性递推
  • 否则可能是非齐次线性递推

这里我们发现多次差分之后是一个等差数列,考虑待定系数求解多项式的系数,这里可以使用高斯消元的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class AugMatrix:
def __init__(self, row, col):
self.row, self.col = row, col
self.mtr = [[0] * (col + 1) for i in range(row)]
def fillAt(self, row, col, val):
self.mtr[row][col] = val
def fillAugAt(self, row, val):
self.mtr[row][self.col] = val
def fill(self, mtr, augMtr):
for i in range(self.row):
for j in range(self.col):
self.mtr[i][j] = mtr[i][j]
self.mtr[i][self.col] = augMtr[i]
def printMtr(self):
for i in range(self.row):
print(self.mtr[i])
@staticmethod
def gcd(a, b):
while b != 0:
a, b = b, a % b
return a
def rowMul(self, row, val):
for i in range(self.col + 1):
self.mtr[row][i] = self.mtr[row][i] * val
def rowSub(self, rowA, rowB):
for i in range(self.col + 1):
self.mtr[rowA][i] = self.mtr[rowA][i] - self.mtr[rowB][i]
def eliminate(self):
for i in range(self.col - 1):
for j in range(self.row - 1, i, -1):
now, pre = self.mtr[j][i], self.mtr[j - 1][i]
if now == 0:
continue
elif pre == 0:
self.mtr[j], self.mtr[j - 1] = self.mtr[j - 1], self.mtr[j]
continue
gcd = self.gcd(now, pre)
lcm = now * pre // gcd
self.rowMul(j, lcm // now)
self.rowMul(j - 1, lcm // pre)
self.rowSub(j, j - 1)
for i in range(self.col - 1, 0, -1):
for j in range(0, i):
now, pre = self.mtr[j][i], self.mtr[j + 1][i]
if now == 0 or pre == 0:
continue
gcd = self.gcd(now, pre)
lcm = now * pre // gcd
self.rowMul(j, lcm // now)
self.rowMul(j + 1, lcm // pre)
self.rowSub(j, j + 1)
def approximate(self):
for i in range(self.row):
g = self.gcd(self.mtr[i][self.col], self.mtr[i][i])
self.mtr[i][self.col] = self.mtr[i][self.col] // g
self.mtr[i][i] = self.mtr[i][i] // g
def simplify(self):
for i in range(self.row):
self.mtr[i][self.col] = self.mtr[i][self.col] / self.mtr[i][i]
self.mtr[i][i] = 1
def generate(self):
ret = list()
for i in range(self.row):
ret.append(self.mtr[i][self.col])
return ret
n = 7
x = [1, 2, 3, 4, 5, 6, 7]
y = [1, 5, 15, 35, 70, 126, 210]
m = AugMatrix(n, n)
for i in range(7):
for j in range(7):
m.fillAt(i, j, x[i] ** j)
m.fillAugAt(i, y[i])
m.eliminate()
m.approximate()
m.printMtr()

接着通过前七项解出了系数\((0, \frac{1}{4}, \frac{11}{24}, \frac{1}{4}, \frac{1}{24}, 0, 0, 0)\),得到: \[f(x) = \frac{1}{4} x + \frac{11}{24} x^2 + \frac{1}{4} x^3 + \frac{1}{24} x^4 \] 最后通过多算几项验证正确性后,得到了最终的程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL P = LL(1E9) + 7;
LL inv_4, inv_24;
inline LL getPow(LL x, LL y) {
LL ret = 1;
while (y) {
if (y & 1) ret = ret * x % P;
x = x * x % P;
y >>= 1;
}
return ret;
}
inline LL f(LL x) {
LL ret = 0, base = x;
ret = (ret + inv_4 * base) % P; base = base * x % P;
ret = (ret + 11 * inv_24 * base) % P; base = base * x % P;
ret = (ret + inv_4 * base) % P; base = base * x % P;
ret = (ret + inv_24 * base) % P;
return ret;
}
int main() {
inv_4 = getPow(4, P - 2);
inv_24 = getPow(24, P - 2);
int T, n;
scanf("%d", &T);
for (int cs = 1; cs <= T; ++cs) {
scanf("%d", &n);
printf("%lld\n", f(n));
}
return 0;
}