问题描述
小 \(D\) 最近学习了最小生成树的有关知识。为了更好地学习求最小生成树的流程,他造了一张 \(n\) 个点的带权无向完全图(即任意两个不同的点之间均有且仅有一条无向边的图),并求出了这个图的最小生成树。
为了简单起见,小 \(D\) 造的无向图的边权为 \([1,\frac{n(n-1)}{2}]\) 之间的整数,且任意两条边的边权均不一样。
若干天后,小 \(D\) 突然发现自己不会求最小生成树了。于是他找出了当时求出的最小生成树,但是原图却怎么也找不到了。于是小 \(D\) 想要求出,有多少种可能的原图。但是小 \(D\) 连最小生成树都不会求了,自然也不会这个问题。请你帮帮他。
形式化地,你会得到 \(n-1\) 个递增的正整数 \(a_1,a_2,\cdots,a_{n-1}\),依次表示最小生成树上的边的边权。你要求出,有多少张 \(n\) 个点的带权无向完全图,满足:
- 每条边的边权为 \([1,\frac{n(n-1)}{2}]\) 之间的整数;
- 任意两条不同的边的边权也不同;
- 至少存在一种最小生成树,满足树上的边权按照从小到大的顺序排列即为 \(a_1,a_2,\cdots,a_{n-1}\)(事实上,可以证明任意一张无向图的任意两棵最小生成树上的边权集合相同)。
因为答案可能很大,所以你只要求出答案对 \(10^9+7=1,000,000,007\)(一个质数)取模的结果即可。
输入格式
第一行一个整数 \(n\)。
第二行 \(n-1\) 个空格隔开的整数 \(a_1,a_2,\cdots,a_{n-1}\),表示最小生成树上的边权。
输出格式
一行一个整数表示可能的无向图个数对 \(10^9+7\) 取模的结果。
样例输入
7
1 2 4 5 7 9
样例输出
616898266
数据范围
\(n \le 40\)
解析
考虑如果DP的话,我们面临的最大问题是如何设计状态来记录当前图的连通块状况。我们可以用一个数组记录当前每个连通块的大小,然后将这个数组哈希成一个整数。可以发现,这样一个数组实际上是一个关于n的整数划分。而n只有40,因此本质上不一样的状态数量(每个数组从小到大排序之后不一样)最多只有40000个。这个性质我们后面会用到。
再考虑如何转移。我们不妨从小到大加入题目所给的树边,设当前加入的边的边权是\(a_i\)。那么,在加入这条边之前,由题目给的条件,我们还需要加入\(num=a_i-a_{i-1}-1\)条非树边。 这些非树边加入时显然不能改变原图连通块的数量,那么我们可以选择任意两个在同一个连通块中的还没有连边的点,数量为:
\[
sum=(\sum_{i=1}^{n} num_i\times \frac{i(i-1)}{2} )-a_{i-1}
\]
即所有可能的边减去已经加入的\(a_{i-1}\)条边。其中\(num_i\)表示大小为\(i\)的连通块个数。如果\(sum\)小于需要加入的非树边的数量,说明当前状态不合法。否则,加入非树边的方案数为:
\[
cnt_1=C_{sum}^{num}\times num!=\frac{sum!}{(sum-num)!}
\]
接下来考虑加入树边。我们选择两个连通块,然后分别在一个块中选择一个点。方案数为:
\[
cnt_2=\sum_{i=1}^n [num_i>1](\frac{num_i(num_i-1)}{2}\times i^2)+\sum_{j=i+1}^nnum_i\times num_j\times i\times j
\]
即仍然以\(num_i\)为基础,特殊讨论选择两个大小相同的连通块的情况。有了上述基础,DP方程就很好想了。设\(f[i][j]\)表示当前加入第\(i\)条边,之后连通块情况的哈希值为\(j\)时的方案数。我们有:
\[
f[i][j]=\sum_{k}f[i-1][k]\times cnt_1\times cnt_2
\]
其中\(k\)满足状态\(k\)加入第\(i-1\)条边之后能够得到状态\(j\)。再利用最开始提到的性质,每次产生的本质不同的状态最多只有40000种,我们只要在每次转移时注意对拓展到的状态去重即可。此外,为了方便,对于每个状态我们可以把它转化为\(num_i\)的形式,即记录每个大小的连通块的数量,方便转移。不难证明,这样的转移不会改变原来状态本质不同的性质。
初始状态是所有连通块大小为1,有n个;最终状态为连通块大小为n,只有1个。
代码
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <map> #include <queue> #define int long long #define N 42 #define M 40002 using namespace std; const int mod=1000000007; struct node{ int cnt,num[N]; inline bool operator <(node B)const{ for(register int i=0;i<=40;++i) if(num[i]!=B.num[i]) return num[i]<B.num[i]; return 0; } inline bool operator ==(node B)const{ for(register int i=0;i<=40;++i) if(num[i]!=B.num[i]) return 0; return 1; } }tmp,p[M]; int n,i,j,k,l,a[N],cntp,x,f[2][M],g[M],fac[N*N],inv[N*N]; bool vis[M]; map<node,int> m; queue<int> q; int read() { char c=getchar(); int w=0; while(c<'0'||c>'9') c=getchar(); while(c<='9'&&c>='0'){ w=w*10+c-'0'; c=getchar(); } return w; } void dfs(int last,int n) { if(n==0){ p[++cntp]=tmp; return; } for(int i=last;i<=n;i++){ tmp.num[++tmp.cnt]=i; dfs(i,n-i); tmp.cnt--; } } int poww(int a,int b) { int ans=1,base=a; while(b){ if(b&1) ans=ans*base%mod; base=base*base%mod; b>>=1; } return ans; } signed main() { int tot=0; n=read(); for(i=1;i<n;i++) a[i]=read(); sort(a+1,a+n); dfs(1,n); fac[0]=1; for(i=1;i<=n*n;i++) fac[i]=fac[i-1]*i%mod; inv[n*n]=poww(fac[n*n],mod-2); for(i=n*n-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod; for(i=1;i<=cntp;i++){ int tmp[N]; memset(tmp,0,sizeof(tmp)); for(j=1;j<=p[i].cnt;j++) tmp[p[i].num[j]]++; for(j=1;j<=n;j++) p[i].num[j]=tmp[j]; p[i].cnt=0;m[p[i]]=i; } f[0][1]=1;q.push(1);vis[1]=1; for(i=1;i<n;i++){ x^=1; memset(f[x],0,sizeof(f[x])); int num=a[i]-a[i-1]-1; while(!q.empty()){ g[++l]=q.front(); q.pop(); vis[g[l]]=0; } while(l){ int now=g[l--],sum=1; if(num){ sum=0; for(j=1;j<=n;j++) sum+=(j*(j-1)/2)*p[now].num[j]; sum-=a[i-1]; if(sum>=num) sum=fac[sum]*inv[sum-num]%mod; else continue; } sum=sum*f[x^1][now]%mod; if(sum==0) continue; node P=p[now]; for(j=1;j<=n;j++){ if(!P.num[j]) continue; if(P.num[j]>=2){ int tmp=(P.num[j]*(P.num[j]-1)/2)*j*j; P.num[j]-=2;P.num[j*2]++; int id=m[P]; if(!vis[id]){ q.push(id); vis[id]=1; } f[x][id]=(f[x][id]+tmp*sum%mod)%mod; P.num[j]+=2;P.num[j*2]--; } for(k=j+1;k<=n;k++){ if(!P.num[k]) continue; int tmp=P.num[j]*P.num[k]*j*k; P.num[j]--;P.num[k]--; P.num[j+k]++; int id=m[P]; if(!vis[id]){ q.push(id); vis[id]=1; } f[x][id]=(f[x][id]+tmp*sum%mod)%mod; P.num[j]++;P.num[k]++; P.num[j+k]--; } } } } int ans=f[x][cntp]*fac[n*(n-1)/2-a[n-1]]%mod; printf("%lld\n",ans); return 0; }
来源:https://www.cnblogs.com/LSlzf/p/12315245.html