前几天讲了虚树,今天就来切一道虚树的题喽。
题目概述:
给你一棵树,有\(q\)个询问,每个询问给出\(k\)个点,两两连边,边的长度为其在树上的距离,求出这\(k\)个点连边总长度、最短的一条边以及最长的一条边。
其中\(\sum_{i=1}^qk<=5^4\)
大体思路:
显然是道虚树的题,那就先构建棵虚树。
对于第二三两个询问,我们都可以通过一个非常简单的DP来实现,每个询问分别用两个数组求出以它为根,第一、第二(大/小)的路径的值,但是要注意这里的两个值不能都是从同一个子树中转移过来的。
关于第一个询问,我们考虑每一条边对于答案的贡献,不难发现其被加的次数为\(siz[u]\times siz[v]\),那么我们把每一条边的贡献加起来即可。
具体实现 (坑爹细节) :
对于非询问点(即通过\(LCA\)关系而加入虚树的点),我们不妨称之为关系户。
关系户不能进入\(siz\)的统计(显然)。
关系户只能作为连接点,即只能以第一、第二(大/小)的路径的值相加来更新答案。(显然)
虚树的根为所有点加入完毕后,剩下的栈中的那个栈顶。(显然)
代码:
#include <bits/stdc++.h> using namespace std; const int N=1e6+5,inf=1e9; int f[N][21],fir[N],sec[N],Fir[N],Sec[N],nxt[N<<1],vet[N<<1],head[N],dep[N],dfn[N],siz[N],stk[N],a[N]; int n,m,ans1,ans2,Q,u,v,tot,tim,top,rt; long long ans0; bool flag[N]; struct Edge{ int v,d; }; vector <Edge> E[N]; void add(int u,int v){ nxt[++tot]=head[u]; vet[tot]=v; head[u]=tot; } void dfs0(int u,int fa){ f[u][0]=fa,dep[u]=dep[fa]+1,dfn[u]=++tim; for (int i=1;i<=20;i++) f[u][i]=f[f[u][i-1]][i-1]; for (int i=head[u];i;i=nxt[i]){ int v=vet[i]; if (v==fa) continue; dfs0(v,u); } } int lca(int a,int b){ if (dep[a]<dep[b]) swap(a,b); if (b==0) return 0; int res=inf; for (int i=20;i>=0;i--) if (dep[f[a][i]]>=dep[b]) a=f[a][i]; if (a==b) return a; for (int i=20;i>=0;i--) if (f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i]; return f[a][0]; } int dis(int a,int b){return dep[a]+dep[b]-dep[lca(a,b)]*2;} void ins(int x){ int l=lca(stk[top],x); while (top>1&&dep[stk[top-1]]>dep[l]){ E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])}); E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])}); top--; } if (dep[l]<dep[stk[top]]){ E[l].push_back((Edge){stk[top],dis(stk[top],l)}); E[stk[top]].push_back((Edge){l,dis(stk[top],l)}); top--; } if (stk[top]!=l) stk[++top]=l; stk[++top]=x; } void dfs(int u,int fa){ Fir[u]=Sec[u]=inf,fir[u]=sec[u]=siz[u]=0; if (flag[u]) siz[u]=1,Fir[u]=0; for (int i=0;i<E[u].size();i++){ int v=E[u][i].v; if (v==fa) continue; dfs(v,u); ans0+=1ll*E[u][i].d*siz[v]*(m-siz[v]); siz[u]+=siz[v]; int x=Fir[v]+E[u][i].d; if (x<=Fir[u]) Sec[u]=Fir[u],Fir[u]=x; else if (x<Sec[u]) Sec[u]=x; x=fir[v]+E[u][i].d; if (x>=fir[u]) sec[u]=fir[u],fir[u]=x; else if (x>sec[u]) sec[u]=x; } ans1=min(ans1,Fir[u]+Sec[u]); ans2=max(ans2,fir[u]+sec[u]); E[u].clear(); } bool cmp(int a,int b){return dfn[a]<dfn[b];} int main(){ scanf("%d",&n); for (int i=1;i<n;i++){ scanf("%d%d",&u,&v); add(u,v),add(v,u); } dfs0(1,0); scanf("%d",&Q); while (Q--){ ans0=ans2=0,ans1=inf; scanf("%d",&m); if (m==1){ puts("0 0 0"); continue; } top=0; for (int i=1;i<=m;i++) scanf("%d",&a[i]),flag[a[i]]=1; rt=a[1]; sort(a+1,a+1+m,cmp); for (int i=1;i<=m;i++) ins(a[i]); while (top>1){ E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])}); E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])}); top--; } dfs(stk[top],0); for (int i=1;i<=m;i++) flag[a[i]]=0; printf("%lld %d %d\n",ans0,ans1,ans2); } return 0; }