AT_agc038_e 题解
容斥好题,考察了 min-max 容斥,期望的线性性,dp状态的设计。
不妨假设第 个元素在第 时刻出现了 次,设 ,则我们想要求的即为 的期望。
可以套用 min-max 在期望意义下的容斥公式,即
那么我们的任务就变为要如何求出 。
考虑对于某个样本 ,若假设无论如何一定选择在 内,在到达 的过程当中有若干中间过程 。如 ,则 为一种合法的中间状态。不妨设中间状态为 ,则我们有一个结论可以解决这个问题:
即我们所求期望即为中间状态的概率和,这要怎么证明呢?
📝 NOTE 首先,我们有
注意此时为选择前。
不难发现其等价于 。
那么给等式两边同时加上期望,即
接下来我们思考如何求出 。假如此时 。将 视为一个多重集,元素 有 个,则到达该状态的路径数为多重集的排列数,即
而通过任意一条路径的概率即为路径序列上各点的概率乘积,由乘法交换律,知其为
那么到达状态的概率即为路径条数乘以每条路径的概率,即
则在我们的假设条件下(所有选择均选择在 内),有
当然,我们的所有选择并不会全在 内,单次选择在 内的概率为 ,设 ,则选择到 的期望次数为 。
所以在没有限制条件的情况下
带入 min-max 容斥公式,提出无关变量,有
对于这个式子而言,我们发现其做乘法的部分较多,若只是单纯的乘法,便也许可以使用 dp 进行转移。但是由于其中存在着多处求和部分,其中 为定值,在剩下的求和中, 和 在不断变化。所以我们或许可以考虑将这两个变化的量设计进入 dp 状态中。
故不妨设 dp[i][j][k]dp[i][j][k] 表示当前枚举到第 ii 位,
,
,则可以设
📝 NOTE 为什么要这样设计状态?
不难发现我们的状态内的值都不直接与 有关,这是因为如果有关,那么在 dp 时要对状态中的分母等信息进行修改,这较我们上面所设计的状态而言更加难以完成。
转移时考虑
是否在我们选出的子集
内,若不在,则显然 dp[i][j][k]+=dp[i-1][j][k]dp[i][j][k]+=dp[i-1][j][k]。若在,则枚举
,dp[i][j][k]+= -(dp[i-1][j-A[i]][k-c[i]]*(pow(A[i],c[i])/fact(c[i])))dp[i][j][k]+= -(dp[i-1][j-A[i]][k-c[i]]*(pow(A[i],c[i])/fact(c[i])))(每次集合中多增加一个数,(
) 的符号取反,故要加一个负号)。注意 dp[0][0][0]=-1dp[0][0][0]=-1(因为此时集合中无元素,参考上述状态知
)。
📝 NOTE 时间复杂度说明
注意到若 在集合里时,转移最多进行 次,而不在时,时间复杂度显然为 。故总转移复杂度为 。不过若幂部分的实现不好,最终结果可能会多一个 。
求答案时参考答案式子,枚举
,则最终答案 ans+=dp[n][j][k]*s/j*fact(k)*pow((1/j),k);ans+=dp[n][j][k]*s/j*fact(k)*pow((1/j),k);。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const MOD=998244353;
int const MAX=405;
int dp[MAX][MAX][MAX];
int a[MAX],b[MAX];
int fact[MAX],inv[MAX],finv[MAX];
int qpow(int t1,int t2){
int res=1;
while(t2){
if(t2&1){
res*=t1;
res%=MOD;
}
t1=t1*t1%MOD;
t2>>=1;
}
return res;
}
signed main(){
fact[0]=1;
for(int i=1;i<MAX;i++)fact[i]=fact[i-1]*i%MOD;
finv[MAX-1]=qpow(fact[MAX-1],MOD-2);
for(int i=MAX-2;i>=0;i--){
finv[i]=finv[i+1]*(i+1)%MOD;
if(i)inv[i]=finv[i]*fact[i-1]%MOD;
}
int n;
cin>>n;
int sa=0,sb=0;
for(int i=1;i<=n;i++)cin>>a[i]>>b[i],sa+=a[i],sb+=b[i];
dp[0][0][0]=-1;
for(int i=1;i<=n;i++){
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
dp[i][j][k]=dp[i-1][j][k];
for(int l=0;l<b[i];l++){
if(j>=a[i]&&k>=l)dp[i][j][k]-=(dp[i-1][j-a[i]][k-l]*qpow(a[i],l)%MOD*finv[l]%MOD)%MOD;
dp[i][j][k]=(dp[i][j][k]+MOD)%MOD;
}
}
}
}
int ans=0;
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
ans+=(dp[n][j][k]*sa%MOD*inv[j]%MOD*fact[k]%MOD*qpow(inv[j],k))%MOD;
ans%=MOD;
}
}
cout<<ans;
return 0;
}
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const MOD=998244353;
int const MAX=405;
int dp[MAX][MAX][MAX];
int a[MAX],b[MAX];
int fact[MAX],inv[MAX],finv[MAX];
int qpow(int t1,int t2){
int res=1;
while(t2){
if(t2&1){
res*=t1;
res%=MOD;
}
t1=t1*t1%MOD;
t2>>=1;
}
return res;
}
signed main(){
fact[0]=1;
for(int i=1;i<MAX;i++)fact[i]=fact[i-1]*i%MOD;
finv[MAX-1]=qpow(fact[MAX-1],MOD-2);
for(int i=MAX-2;i>=0;i--){
finv[i]=finv[i+1]*(i+1)%MOD;
if(i)inv[i]=finv[i]*fact[i-1]%MOD;
}
int n;
cin>>n;
int sa=0,sb=0;
for(int i=1;i<=n;i++)cin>>a[i]>>b[i],sa+=a[i],sb+=b[i];
dp[0][0][0]=-1;
for(int i=1;i<=n;i++){
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
dp[i][j][k]=dp[i-1][j][k];
for(int l=0;l<b[i];l++){
if(j>=a[i]&&k>=l)dp[i][j][k]-=(dp[i-1][j-a[i]][k-l]*qpow(a[i],l)%MOD*finv[l]%MOD)%MOD;
dp[i][j][k]=(dp[i][j][k]+MOD)%MOD;
}
}
}
}
int ans=0;
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
ans+=(dp[n][j][k]*sa%MOD*inv[j]%MOD*fact[k]%MOD*qpow(inv[j],k))%MOD;
ans%=MOD;
}
}
cout<<ans;
return 0;
}
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const MOD=998244353;
int const MAX=405;
int dp[MAX][MAX][MAX];
int a[MAX],b[MAX];
int fact[MAX],inv[MAX],finv[MAX];
int qpow(int t1,int t2){
int res=1;
while(t2){
if(t2&1){
res*=t1;
res%=MOD;
}
t1=t1*t1%MOD;
t2>>=1;
}
return res;
}
signed main(){
fact[0]=1;
for(int i=1;i<MAX;i++)fact[i]=fact[i-1]*i%MOD;
finv[MAX-1]=qpow(fact[MAX-1],MOD-2);
for(int i=MAX-2;i>=0;i--){
finv[i]=finv[i+1]*(i+1)%MOD;
if(i)inv[i]=finv[i]*fact[i-1]%MOD;
}
int n;
cin>>n;
int sa=0,sb=0;
for(int i=1;i<=n;i++)cin>>a[i]>>b[i],sa+=a[i],sb+=b[i];
dp[0][0][0]=-1;
for(int i=1;i<=n;i++){
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
dp[i][j][k]=dp[i-1][j][k];
for(int l=0;l<b[i];l++){
if(j>=a[i]&&k>=l)dp[i][j][k]-=(dp[i-1][j-a[i]][k-l]*qpow(a[i],l)%MOD*finv[l]%MOD)%MOD;
dp[i][j][k]=(dp[i][j][k]+MOD)%MOD;
}
}
}
}
int ans=0;
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
ans+=(dp[n][j][k]*sa%MOD*inv[j]%MOD*fact[k]%MOD*qpow(inv[j],k))%MOD;
ans%=MOD;
}
}
cout<<ans;
return 0;
}
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const MOD=998244353;
int const MAX=405;
int dp[MAX][MAX][MAX];
int a[MAX],b[MAX];
int fact[MAX],inv[MAX],finv[MAX];
int qpow(int t1,int t2){
int res=1;
while(t2){
if(t2&1){
res*=t1;
res%=MOD;
}
t1=t1*t1%MOD;
t2>>=1;
}
return res;
}
signed main(){
fact[0]=1;
for(int i=1;i<MAX;i++)fact[i]=fact[i-1]*i%MOD;
finv[MAX-1]=qpow(fact[MAX-1],MOD-2);
for(int i=MAX-2;i>=0;i--){
finv[i]=finv[i+1]*(i+1)%MOD;
if(i)inv[i]=finv[i]*fact[i-1]%MOD;
}
int n;
cin>>n;
int sa=0,sb=0;
for(int i=1;i<=n;i++)cin>>a[i]>>b[i],sa+=a[i],sb+=b[i];
dp[0][0][0]=-1;
for(int i=1;i<=n;i++){
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
dp[i][j][k]=dp[i-1][j][k];
for(int l=0;l<b[i];l++){
if(j>=a[i]&&k>=l)dp[i][j][k]-=(dp[i-1][j-a[i]][k-l]*qpow(a[i],l)%MOD*finv[l]%MOD)%MOD;
dp[i][j][k]=(dp[i][j][k]+MOD)%MOD;
}
}
}
}
int ans=0;
for(int j=0;j<=sa;j++){
for(int k=0;k<=sb;k++){
ans+=(dp[n][j][k]*sa%MOD*inv[j]%MOD*fact[k]%MOD*qpow(inv[j],k))%MOD;
ans%=MOD;
}
}
cout<<ans;
return 0;
}