题目大意

玩一场游戏,每一轮在 [1,n][1,n] 中随机生成一个整数 xx,生成 xx 的概率为 pxp_x。若 xx 不小于之前生成的任何数,则游戏继续并进入下一轮,否则游戏结束。假如说最后游戏共进行了 lenlen 轮,则游戏的得分为 len2len^2。求游戏分数的期望值 E(len2)E(len^2),答案(mod998244353)\pmod{998244353} 输出。

输入为 nn 个数 w1,w2,,wnw_1,w_2,\cdots,w_npi=wij=1nwjp_i=\frac{w_i}{\sum\limits_{j=1}^n w_j}

思路

游戏生成的数字序列一定是类似 1,1,,1,2,2,3,31,1,\cdots,1,2,2\cdots,3,3\cdots,即 [1,n][1,n] 的数字由小到大依次产生,每个数字出现的次数为 [0,+)[0,+\infty)

构造母函数 f(x)=i=1nm=0pimxm (1)f(x)=\prod\limits_{i=1}^n\sum\limits_{m=0}^\infty p_i^mx^m\ (1),对于每一项 m=0pimxm\sum\limits_{m=0}^\infty p_i^mx^m,指数 mm 表示数字 ii 出现的次数,pimp_i^m 就是 ii 连续出现 mm 次的概率。每一个数的出现次数都对应多项式 m=0pimxm\sum\limits_{m=0}^\infty p_i^mx^m,将这 nn 个多项式进行卷积运算,即为 f(x)f(x)。卷积的结果为 f(x)=m=0amxmf(x)=\sum\limits_{m=0}^\infty a_mx^mmm 为生成序列的长度,系数 ama_m 自然是生成长度为 mm 的非递减序列的概率。仔细思考,ama_m 同时也是游戏进行轮次超过 mm 次的概率,即 f(x)=m=0P(len>m)xm (2)f(x)=\sum\limits_{m=0}^\infty P(len>m)x^m\ (2);因为游戏已经进行了 mm 轮,且可以继续进行,游戏要进行超过 mm 轮等价于生成长度为 mm 的非递减序列。

根据数学期望的定义,枚举最后游戏的轮次 mm

$$ \begin{aligned} E(len^2)&=\sum\limits_{m=1}^\infty [P(len>m-1)-P(len>m)]m^2\\&=P(len>0)+\sum\limits_{m=1}^\infty P(len>m)[(m+1)^2-m^2]\\&=1+\sum\limits_{m=1}^\infty P(len>m)(2i+1)\\&=\sum\limits_{m=0}^\infty P(len>m)(2i+1) \end{aligned} $$

根据 (2)(2) 可以发现,E(len2)=2f(1)+f(1)E(len^2)=2f'(1)+f(1)。我们需要用 (1)(1) 来求 f(1)f'(1)f(1)f(1)。首先将 f(x)f(x) 化为其封闭形式,即 f(x)=i=1n11pix (3)f(x)=\prod\limits_{i=1}^n\frac{1}{1-p_ix}\ (3)。将 x=1x=1 带入 (3)(3),容易得到 f(1)=i=1n11pif(1)=\prod\limits_{i=1}^n \frac{1}{1-p_i}。将 (3)(3) 两边取自然对数,lnf(x)=i=1nln11pix\ln f(x)=\sum\limits_{i=1}^n\ln\frac{1}{1-p_ix};两边对 xx 求导,f(x)f(x)=i=1npi1pix\frac{f'(x)}{f(x)}=\sum\limits_{i=1}^n\frac{p_i}{1-p_ix};将 x=1x=1 带入,得 f(1)=f(1)i=1npi1pif'(1)=f(1)\sum\limits_{i=1}^n\frac{p_i}{1-p_i}。将 pi=wij=1nwjp_i=\frac{w_i}{\sum\limits_{j=1}^n w_j} 代入,并记 sum=j=1nwjsum=\sum\limits_{j=1}^n w_j,有 f(1)=i=1nsumsumwif(1)=\prod\limits_{i=1}^n\frac{sum}{sum-w_i}f(1)=f(1)i=1nwisumwif'(1)=f(1)\sum\limits_{i=1}^n\frac{w_i}{sum-w_i}

f(1)f(1)f(1)f'(1),时间复杂度 O(n)O(n),最后 O(1)O(1) 代入求 E(len2)E(len^2)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include<iostream>
using namespace std;
typedef long long ll;
const int MOD = 998244353;
const int MAX_SIZE = 105;
int n, w[MAX_SIZE];
ll QuickPower(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1)
{
res = res * a % MOD;
}
a = a * a % MOD;
b >>= 1;
}
return res;
}
int main()
{
cin >> n;
int sum = 0;
for (int i = 1;i <= n;i++)
{
cin >> w[i];
sum += w[i];
}
// 求 f(1)
int a = 1;
for (int i = 1;i <= n;i++)
{
a = 1ll * a * sum % MOD * QuickPower(sum - w[i], MOD - 2) % MOD;
}
// 求 f'(1)
int b = 0;
for (int i = 1;i <= n;i++)
{
b = (b + 1ll * w[i] * QuickPower(sum - w[i], MOD - 2) % MOD) % MOD;
}
b = 1ll * b * a % MOD;
// 输出答案
cout << (2ll * b % MOD + a) % MOD << endl;
return 0;
}