[noi.ac] #31 最小生成树

你。 提交于 2020-02-15 23:51:45

问题描述

\(D\) 最近学习了最小生成树的有关知识。为了更好地学习求最小生成树的流程,他造了一张 \(n\) 个点的带权无向完全图(即任意两个不同的点之间均有且仅有一条无向边的图),并求出了这个图的最小生成树。

为了简单起见,小 \(D\) 造的无向图的边权为 \([1,\frac{n(n-1)}{2}]\) 之间的整数,且任意两条边的边权均不一样。

若干天后,小 \(D\) 突然发现自己不会求最小生成树了。于是他找出了当时求出的最小生成树,但是原图却怎么也找不到了。于是小 \(D\) 想要求出,有多少种可能的原图。但是小 \(D\) 连最小生成树都不会求了,自然也不会这个问题。请你帮帮他。

形式化地,你会得到 \(n-1\) 个递增的正整数 \(a_1,a_2,\cdots,a_{n-1}\),依次表示最小生成树上的边的边权。你要求出,有多少张 \(n\) 个点的带权无向完全图,满足:

  • 每条边的边权为 \([1,\frac{n(n-1)}{2}]\) 之间的整数;
  • 任意两条不同的边的边权也不同;
  • 至少存在一种最小生成树,满足树上的边权按照从小到大的顺序排列即为 \(a_1,a_2,\cdots,a_{n-1}\)(事实上,可以证明任意一张无向图的任意两棵最小生成树上的边权集合相同)。

因为答案可能很大,所以你只要求出答案对 \(10^9+7=1,000,000,007\)(一个质数)取模的结果即可。

输入格式

第一行一个整数 \(n\)

第二行 \(n-1\) 个空格隔开的整数 \(a_1,a_2,\cdots,a_{n-1}\),表示最小生成树上的边权。

输出格式

一行一个整数表示可能的无向图个数对 \(10^9+7\) 取模的结果。

样例输入

7
1 2 4 5 7 9

样例输出

616898266

数据范围

\(n \le 40\)

解析

考虑如果DP的话,我们面临的最大问题是如何设计状态来记录当前图的连通块状况。我们可以用一个数组记录当前每个连通块的大小,然后将这个数组哈希成一个整数。可以发现,这样一个数组实际上是一个关于n的整数划分。而n只有40,因此本质上不一样的状态数量(每个数组从小到大排序之后不一样)最多只有40000个。这个性质我们后面会用到。

再考虑如何转移。我们不妨从小到大加入题目所给的树边,设当前加入的边的边权是\(a_i\)。那么,在加入这条边之前,由题目给的条件,我们还需要加入\(num=a_i-a_{i-1}-1\)条非树边。 这些非树边加入时显然不能改变原图连通块的数量,那么我们可以选择任意两个在同一个连通块中的还没有连边的点,数量为:
\[ sum=(\sum_{i=1}^{n} num_i\times \frac{i(i-1)}{2} )-a_{i-1} \]
即所有可能的边减去已经加入的\(a_{i-1}\)条边。其中\(num_i\)表示大小为\(i\)的连通块个数。如果\(sum\)小于需要加入的非树边的数量,说明当前状态不合法。否则,加入非树边的方案数为:
\[ cnt_1=C_{sum}^{num}\times num!=\frac{sum!}{(sum-num)!} \]
接下来考虑加入树边。我们选择两个连通块,然后分别在一个块中选择一个点。方案数为:
\[ cnt_2=\sum_{i=1}^n [num_i>1](\frac{num_i(num_i-1)}{2}\times i^2)+\sum_{j=i+1}^nnum_i\times num_j\times i\times j \]
即仍然以\(num_i\)为基础,特殊讨论选择两个大小相同的连通块的情况。有了上述基础,DP方程就很好想了。设\(f[i][j]\)表示当前加入第\(i\)条边,之后连通块情况的哈希值为\(j\)时的方案数。我们有:
\[ f[i][j]=\sum_{k}f[i-1][k]\times cnt_1\times cnt_2 \]
其中\(k\)满足状态\(k\)加入第\(i-1\)条边之后能够得到状态\(j\)。再利用最开始提到的性质,每次产生的本质不同的状态最多只有40000种,我们只要在每次转移时注意对拓展到的状态去重即可。此外,为了方便,对于每个状态我们可以把它转化为\(num_i\)的形式,即记录每个大小的连通块的数量,方便转移。不难证明,这样的转移不会改变原来状态本质不同的性质。

初始状态是所有连通块大小为1,有n个;最终状态为连通块大小为n,只有1个。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <queue>
#define int long long
#define N 42
#define M 40002
using namespace std;
const int mod=1000000007;
struct node{
    int cnt,num[N];
    inline bool operator <(node B)const{
        for(register int i=0;i<=40;++i) if(num[i]!=B.num[i]) return num[i]<B.num[i];
        return 0;
    }
    inline bool operator ==(node B)const{
        for(register int i=0;i<=40;++i) if(num[i]!=B.num[i]) return 0;
        return 1;
    }
}tmp,p[M];
int n,i,j,k,l,a[N],cntp,x,f[2][M],g[M],fac[N*N],inv[N*N];
bool vis[M];
map<node,int> m;
queue<int> q;
int read()
{
    char c=getchar();
    int w=0;
    while(c<'0'||c>'9') c=getchar();
    while(c<='9'&&c>='0'){
        w=w*10+c-'0';
        c=getchar();
    }
    return w;
}
void dfs(int last,int n)
{
    if(n==0){
        p[++cntp]=tmp;
        return;
    }
    for(int i=last;i<=n;i++){
        tmp.num[++tmp.cnt]=i;
        dfs(i,n-i);
        tmp.cnt--;
    }
}
int poww(int a,int b)
{
    int ans=1,base=a;
    while(b){
        if(b&1) ans=ans*base%mod;
        base=base*base%mod;
        b>>=1;
    }
    return ans;
}
signed main()
{
    int tot=0;
    n=read();
    for(i=1;i<n;i++) a[i]=read();
    sort(a+1,a+n);
    dfs(1,n);
    fac[0]=1;
    for(i=1;i<=n*n;i++) fac[i]=fac[i-1]*i%mod;
    inv[n*n]=poww(fac[n*n],mod-2);
    for(i=n*n-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
    for(i=1;i<=cntp;i++){
        int tmp[N];
        memset(tmp,0,sizeof(tmp));
        for(j=1;j<=p[i].cnt;j++) tmp[p[i].num[j]]++;
        for(j=1;j<=n;j++) p[i].num[j]=tmp[j];
        p[i].cnt=0;m[p[i]]=i;
    }
    f[0][1]=1;q.push(1);vis[1]=1;
    for(i=1;i<n;i++){
        x^=1;
        memset(f[x],0,sizeof(f[x]));
        int num=a[i]-a[i-1]-1;
        while(!q.empty()){
            g[++l]=q.front();
            q.pop();
            vis[g[l]]=0;
        }
        while(l){
            int now=g[l--],sum=1;
            if(num){
                sum=0;
                for(j=1;j<=n;j++) sum+=(j*(j-1)/2)*p[now].num[j];
                sum-=a[i-1];
                if(sum>=num) sum=fac[sum]*inv[sum-num]%mod;
                else continue;
            }
            sum=sum*f[x^1][now]%mod;
            if(sum==0) continue;
            node P=p[now];
            for(j=1;j<=n;j++){
                if(!P.num[j]) continue;
                if(P.num[j]>=2){
                    int tmp=(P.num[j]*(P.num[j]-1)/2)*j*j;
                    P.num[j]-=2;P.num[j*2]++;
                    int id=m[P];
                    if(!vis[id]){
                        q.push(id);
                        vis[id]=1;
                    }
                    f[x][id]=(f[x][id]+tmp*sum%mod)%mod;
                    P.num[j]+=2;P.num[j*2]--;
                }
                for(k=j+1;k<=n;k++){
                    if(!P.num[k]) continue;
                    int tmp=P.num[j]*P.num[k]*j*k;
                    P.num[j]--;P.num[k]--;
                    P.num[j+k]++;
                    int id=m[P];
                    if(!vis[id]){
                        q.push(id);
                        vis[id]=1;
                    }
                    f[x][id]=(f[x][id]+tmp*sum%mod)%mod;
                    P.num[j]++;P.num[k]++;
                    P.num[j+k]--;
                }
            }
        }
    }
    int ans=f[x][cntp]*fac[n*(n-1)/2-a[n-1]]%mod;
    printf("%lld\n",ans);
    return 0;
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!