一个自然的想法是在一个点集里选出一个特定的点,在该点处计入点集贡献.由于点集中所有两点间路径的并是个连通块
一个想法就是枚举连通块中深度最浅的点,然后认为在它子树内的距离\(\le x\)的点都可以在点集内.不过这是错的,因为你很轻松就可以找到这个点两棵不同子树内到他距离为\(x\)的点,而这两个点距离为\(2x\)
所以现在就是要选择一个点集中的特定点,满足所有到它距离\(\le x\)的点满足两两距离\(\le x\).反过来考虑(),我们枚举点集中深度最深的点,深度相同就按照编号排序(其实就是找bfs序最大的点),这时候,所有满足bfs序更小的,到这个点距离\(\le x\)的点都是满足两两距离限制的.假设当前枚举的bfs序最大的点为\(a\),现在考虑\(p,q\)两点,路径\([a,p]\)和路径\([a,q]\)的分叉点为\(b\).可以发现\(p,q\)之中最多有一个在分叉点上方
-
如果两个点都在\(b\)下方,由于\(a\)为当前深度最深的点,那么一定有\(\max(dis(b,p),dis(b,q))\le dis(a,b)\),所以\(dis(b,p)+dis(b,q)=\max(dis(b,p),dis(b,q))+\min(dis(b,p),dis(b,q))\le dis(a,b)+\min(dis(b,p),dis(b,q))=\min(dis(a,p),dis(a,q))\le x\)
-
如果有一个点都在\(b\)上方(假设为\(q\)),因为\(dis(b,q)\le dis(a,b)\),所以\(dis(b,p)+dis(b,q)\le dis(a,b)+dis(b,q)=dis(a,q)\le x\)
所以对于每个点\(a\),如果统计出bfs序比它小的,到它距离\(\le x\)的点个数\(cn_a\),那对于\(ans_i\)有\(\binom{cn_a-1}{i-1}\)的贡献,这个可以把组合数拆开后ntt计算卷积的值
至于\(cn_a\)的计算可以一个log或两个log,如果是一个log,那么可以先算出\(f_i\)表示以某个点(或一条边上的中点)\(i\)为中点,半径为\(\lfloor\frac{x}{2}\rfloor\)的连通块内点数,然后按照bfs序的逆序枚举点\(u\),到\(u\)距离\(\le x\)且深度不大于\(u\)的连通块点数就是\(u\)往上跳\(\lfloor\frac{x}{2}\rfloor\)距离到的点\(v\)的\(f_v\)的值,再考虑bfs序要\(\le u\)的bfs序的话,就每找到一个\(v\)就给\(f_v\)减掉1即可,这样在后面就不会统计到bfs序更大的点了
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N=6e5+10,M=(1<<20)+10,mod=998244353;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
return x*w;
}
void ad(int &x,int y){x+=y,x-=x>=mod?mod:0;}
int fpow(int a,int b){int an=1;while(b){if(b&1) an=1ll*an*a%mod;a=1ll*a*a%mod,b>>=1;}return an;}
int ginv(int a){return fpow(a,mod-2);}
int to[N<<1],nt[N<<1],hd[N],tot=1;
void adde(int x,int y)
{
++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int n,m,lm,sz[N],f[N],g[N],mx,nsz,rt;
bool ban[N];
void fdrt(int x,int ffa)
{
sz[x]=1;
int nx=0;
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]||y==ffa) continue;
fdrt(y,x),sz[x]+=sz[y],nx=max(nx,sz[y]);
}
nx=max(nx,nsz-sz[x]);
if(mx>nx) mx=nx,rt=x;
}
void d1(int x,int ffa,int de)
{
m=max(m,de),g[de]+=x<=n;
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]||y==ffa) continue;
d1(y,x,de+1);
}
}
void d2(int x,int ffa,int de)
{
if(de>lm) return;
f[x]+=g[min(m,lm-de)];
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]||y==ffa) continue;
d2(y,x,de+1);
}
}
void wk(int x)
{
mx=nsz+1,fdrt(x,0);
x=rt,ban[x]=1,fdrt(x,0);
d1(x,0,0);
for(int i=1;i<=m;++i) g[i]+=g[i-1];
f[x]+=g[min(lm,m)];
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]) continue;
d2(y,x,1);
}
memset(g,0,sizeof(int)*(m+1));
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]) continue;
m=0,d1(y,x,1);
for(int i=1;i<=m;++i) g[i]+=g[i-1];
for(int i=0;i<=m;++i) g[i]=-g[i];
d2(y,x,1),memset(g,0,sizeof(int)*(m+1));
}
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(ban[y]) continue;
nsz=sz[y],wk(y);
}
}
int st[N],tp,dp[N],ff[N],sq[N];
void d3(int x,int ffa)
{
st[++tp]=x,ff[x]=st[max(1,tp-lm)];
dp[x]=tp;
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==ffa) continue;
d3(y,x);
}
--tp;
}
int fac[N],iac[N],W[21],iW[21],rdr[M],aa[M],bb[M];
void ntt(int *a,int n,bool op)
{
int l=0,y;
while((1<<l)<n) ++l;
for(int i=0;i<n;++i)
{
rdr[i]=(rdr[i>>1]>>1)|((i&1)<<(l-1));
if(i<rdr[i]) swap(a[i],a[rdr[i]]);
}
for(int i=1,p=0;i<n;i<<=1,++p)
{
int ww=op?W[p]:iW[p];
for(int j=0;j<n;j+=i<<1)
for(int k=0,w=1;k<i;++k,w=1ll*w*ww%mod)
{
y=1ll*a[j+k+i]*w%mod;
a[j+k+i]=(a[j+k]-y+mod)%mod;
a[j+k]=(a[j+k]+y)%mod;
}
}
if(!op) for(int i=0,w=ginv(n);i<n;++i) a[i]=1ll*a[i]*w%mod;
}
int main()
{
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
for(int i=1,p=0;p<=20;i<<=1,++p)
W[p]=fpow(3,(mod-1)/(i<<1)),iW[p]=ginv(W[p]);
fac[0]=1;
for(int i=1;i<=N-5;++i) fac[i]=1ll*fac[i-1]*i%mod;
iac[N-5]=ginv(fac[N-5]);
for(int i=N-5;i;--i) iac[i-1]=1ll*iac[i]*i%mod;
n=rd(),lm=rd();
for(int i=1;i<n;++i)
{
int x=rd(),y=rd();
adde(x,i+n),adde(y,i+n);
}
nsz=n+n-1,wk(1);
d3(1,0);
for(int i=1;i<=n;++i) sq[i]=i;
sort(sq+1,sq+n+1,[&](int aa,int bb){return dp[aa]>dp[bb];});
for(int i=1;i<=n;++i)
{
int x=sq[i];
--f[ff[x]],++aa[f[ff[x]]];
}
for(int i=0;i<=n;++i) aa[i]=1ll*aa[i]*fac[i]%mod;
for(int i=0;i<=n;++i) bb[i]=iac[n-i];
int len=1;
while(len<=n+n+2) len<<=1;
ntt(aa,len,1),ntt(bb,len,1);
for(int i=0;i<len;++i) aa[i]=1ll*aa[i]*bb[i]%mod;
ntt(aa,len,0);
for(int i=0;i<n;++i) printf("%d ",(int)(1ll*aa[n+i]*iac[i]%mod));
return 0;
}
来源:oschina
链接:https://my.oschina.net/u/4256877/blog/3238343