Description
计算 \(1...n\) 的全排列中不同的前缀最大值的个数等于 \(a\) 且不同的后缀最大值的个数等于 \(b\) 的排列的数目对 \(998244353\) 取模的结果
\(1\le n\le 10^5,0\le a,b\le n\)
Solution
画一画合法的排列可以分析得出以下信息
无论是前缀最大值还是后缀最大值,新出现的最大值 \(k\) 总是会以 \(\{k,i_1,i_2,...,i_c\}\) 的形式出现,其中 \(i_1,i_2,...,i_c\) 为任意 \(c\) 个小于 \(k\) 的数,\(c\) 可以为 \(0\)
整个排列的最大值,即 \(n\),会将这个序列分为前后两个排列,而这两个排列是独立的
由于我们只关注排列的相对大小关系,所以被处在下标 \(i\) 位置的 \(n\) 划分出的两个小排列可以直接看作 \(\{1,2,...,i-1\}\) 与 \(\{1,2,...,n-i\}\) 两组元素的排列
所以我们只需要求 \(1...i\) 的全排列中不同的前缀最大值的个数等于 \(s\) 的排列个数即可,\(i\in [0,n],s\in[1,n]\)
从第一条性质我们可以看出,如果我们将 \(n\) 个元素划分为 \(j\) 个集合,并默认其中的最大值为 \(k\) ,并按每个集合的 \(k\) 将这 \(j\) 个集合从小到大排序,那么这一定是一种符合条件的方案,并且这种构造方法可以包括所有的合法方案
但是如果这样用第二类斯特林数做的话,对于每一个大小为 \(s\) 的集合,我们还需乘上 \((s-1)!\),这样就不太好做了
实际上第一类斯特林数可以很好地解决这个问题,因为它本质上是枚举将长度为 \(n\) 的序列分解为 \(i\) 个循环的方案数,符合我们的要求
所以我们最终要求的是 \(\begin{bmatrix}i\\a-1\end{bmatrix}\) 与 \(\begin{bmatrix}i\\b-1\end{bmatrix}\),其中 \(i\in [0,n-1]\)
答案就是 \(\text{ans}=\sum\limits_{i=1}^{n}\begin{bmatrix}i-1\\a-1\end{bmatrix}\begin{bmatrix}n-i\\b-1\end{bmatrix}\)
那么怎么快速求第一类斯特林数呢?
我们知道 \(\begin{bmatrix}n\\i\end{bmatrix}\) 的生成函数是 \(x^{\overline{n}}\),但这只能快速求一行的第一类斯特林数,对于求一列的话,复杂度就退化成 \(O(n^2\log n)\) 了
然后发现没有快速求一列第一类斯特林数的方法。。。
从另一个角度思考问题,不如忽略掉第三条性质,直接生成 \(a+b-2\) 个循环再分配到两边,这样也是对的
那么答案就是 \(\text{ans}=\begin{bmatrix}n-1\\a+b-2\end{bmatrix}\dbinom{a+b-2}{a-1}\)
组合数很好求,下面具体说说如何求第一类斯特林数
因为 \(\begin{bmatrix}n\\i\end{bmatrix}=[x^i](x(x+1)(x+2)...(x+n-1))\) ,所以关键是如何求这 \(n\) 个多项式的卷积
考虑倍增
假设现在已经求出了 \(x(x+1)(x+2)...(x+k-1)\) ,需要求 \(x(x+1)(x+2)...(x+k-1)(x+k)(x+k+1)...(x+2k-1)\)
显然后一半的式子可以通过将 \(x+k\) 带入前一半的式子得到,然后再把两个式子 \(\text{NTT}\) 一下就可以得到我们想要的式子了
怎么带入呢?考虑二项式定理,那么对于 \(p\) 次项,它的系数为
\[
\sum\limits_{i=p}^{k}a_i\dbinom{i}{p}k^{i-p}\tag{1}
\]
其中 \(a_i\) 为带入前第 \(i\) 次项的系数
展开这个式子,得到
\[
\frac{1}{p!}\sum\limits_{i=p}^{k}a_ii!\frac{k^{i-p}}{(i-p)!}\tag{2}
\]
定义 \(f_i=a_ii!,g_i=\frac{k^i}{i!}\),那么带入后第 \(p\) 次项系数 \(a'_p\) 为
\[
a'_p=\frac{1}{p!}\sum\limits_{i=p}^{k}f_ig_{i-p}\tag{3}
\]
将 \(g_i\) 中的值全部翻转,得 \(g'_i\),那么
\[
a'_p=\frac{1}{p!}\sum\limits_{i=0}^{k-p}f_{k-i}g'_{p+i}\tag{4}
\]
也可以写成
\[
a'_p=\frac{1}{p!}\sum\limits_{i=0}^{k+p}f_ig'_{k+p-i}\tag{5}
\]
然后就能 \(\text{NTT}\) 求出系数了
复杂度 \(O(n\log n)\)
代码如下:
#include<cstdio> #include<iostream> #include<algorithm> using namespace std; const int N=1e5+10; const int mod=998244353; const int G=3; const int invG=332748118; int n,A,B,fac[N<<1],inv[N<<1],f[N<<2],g[N<<2],a[N<<2],b[N<<2],k,now,INV; inline void Preprocess(){ fac[0]=1;for(register int i=1;i<=(n<<1);i++)fac[i]=1ll*fac[i-1]*i%mod; inv[0]=inv[1]=1;for(register int i=2;i<=(n<<1);i++)inv[i]=(-1ll*mod/i*inv[mod%i]%mod+mod)%mod; for(register int i=2;i<=(n<<1);i++)inv[i]=1ll*inv[i-1]*inv[i]%mod; } inline int C(int n,int m){if(n<0||m<0||n<m)return 0;return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;} inline int fas(int x,int p){int res=1;while(p){if(p&1)res=1ll*res*x%mod;p>>=1;x=1ll*x*x%mod;}return res;} inline int MOD(int x){x-=x>=mod? mod:0;return x;} inline void NTT(int *a,int f){ for(register int i=0,j=0;i<k;i++){ if(i>j)swap(a[i],a[j]); for(register int l=k>>1;(j^=l)<l;l>>=1);} for(register int i=1;i<k;i<<=1){ int w=fas(~f? G:invG,(mod-1)/(i<<1)); for(register int j=0;j<k;j+=(i<<1)){ int e=1; for(register int p=0;p<i;p++,e=1ll*e*w%mod){ int x=a[j+p],y=1ll*a[j+p+i]*e%mod; a[j+p]=MOD(x+y);a[j+p+i]=MOD(x-y+mod); } } } } inline void Solve(int m){ if(m==1){f[1]=1;return;} int M=m>>1;Solve(M); for(register int i=0;i<=M;i++)a[i]=1ll*f[i]*fac[i]%mod; now=1; for(register int i=0;i<=M;i++) b[i]=1ll*now*inv[i]%mod,now=1ll*now*M%mod; reverse(b,b+M+1); k=1;while(k<=M+M)k<<=1;INV=fas(k,mod-2); for(register int i=M+1;i<k;i++)a[i]=b[i]=0; NTT(a,1);NTT(b,1); for(register int i=0;i<k;i++)a[i]=1ll*a[i]*b[i]%mod; NTT(a,-1); for(register int i=0;i<k;i++)a[i]=1ll*a[i]*INV%mod; for(register int i=0;i<=M;i++)g[i]=1ll*inv[i]*a[M+i]%mod; for(register int i=M+1;i<k;i++)g[i]=0; NTT(f,1);NTT(g,1); for(register int i=0;i<k;i++)f[i]=1ll*f[i]*g[i]%mod; NTT(f,-1); for(register int i=0;i<=(M<<1);i++)f[i]=1ll*f[i]*INV%mod; if((M<<1)!=m){ for(register int i=m;i;i--) f[i]=MOD(f[i-1]+1ll*(m-1)*f[i]%mod); f[0]=1ll*f[0]*(m-1)%mod; } } int main(){ scanf("%d%d%d",&n,&A,&B); if(A+B-2<0||A-1<0||B-1<0||A+B-2>n-1){puts("0");return 0;} Preprocess(); if(n==1){if(A+B-2==0)printf("%d\n",C(A+B-2,A-1));else puts("0");return 0;} Solve(n-1); printf("%lld\n",1ll*f[A+B-2]*C(A+B-2,A-1)%mod); return 0; }