[HEOI 2014] 大工程

末鹿安然 提交于 2019-12-03 09:57:56

前几天讲了虚树,今天就来切一道虚树的题喽。

题目概述:

给你一棵树,有\(q\)个询问,每个询问给出\(k\)个点,两两连边,边的长度为其在树上的距离,求出这\(k\)个点连边总长度、最短的一条边以及最长的一条边。

其中\(\sum_{i=1}^qk<=5^4\)

大体思路:

显然是道虚树的题,那就先构建棵虚树。

对于第二三两个询问,我们都可以通过一个非常简单的DP来实现,每个询问分别用两个数组求出以它为根,第一、第二(大/小)的路径的值,但是要注意这里的两个值不能都是从同一个子树中转移过来的。

关于第一个询问,我们考虑每一条边对于答案的贡献,不难发现其被加的次数为\(siz[u]\times siz[v]\),那么我们把每一条边的贡献加起来即可。

具体实现 (坑爹细节)

对于非询问点(即通过\(LCA\)关系而加入虚树的点),我们不妨称之为关系户。

关系户不能进入\(siz\)的统计(显然)。

关系户只能作为连接点,即只能以第一、第二(大/小)的路径的值相加来更新答案。(显然)

虚树的根为所有点加入完毕后,剩下的栈中的那个栈顶。(显然)

代码:

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5,inf=1e9;
int f[N][21],fir[N],sec[N],Fir[N],Sec[N],nxt[N<<1],vet[N<<1],head[N],dep[N],dfn[N],siz[N],stk[N],a[N];
int n,m,ans1,ans2,Q,u,v,tot,tim,top,rt;
long long ans0;
bool flag[N];
struct Edge{
    int v,d;
};
vector <Edge> E[N];
void add(int u,int v){
    nxt[++tot]=head[u];
    vet[tot]=v;
    head[u]=tot;
}
void dfs0(int u,int fa){
    f[u][0]=fa,dep[u]=dep[fa]+1,dfn[u]=++tim;
    for (int i=1;i<=20;i++)
        f[u][i]=f[f[u][i-1]][i-1];
    for (int i=head[u];i;i=nxt[i]){
        int v=vet[i];
        if (v==fa) continue;
        dfs0(v,u);
    }
}
int lca(int a,int b){
    if (dep[a]<dep[b]) swap(a,b);
    if (b==0) return 0;
    int res=inf;
    for (int i=20;i>=0;i--)
        if (dep[f[a][i]]>=dep[b]) a=f[a][i];
    if (a==b) return a;
    for (int i=20;i>=0;i--)
        if (f[a][i]!=f[b][i]) 
            a=f[a][i],b=f[b][i];
    return f[a][0];
}
int dis(int a,int b){return dep[a]+dep[b]-dep[lca(a,b)]*2;}
void ins(int x){
    int l=lca(stk[top],x);
    while (top>1&&dep[stk[top-1]]>dep[l]){
        E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])});
        E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])});
        top--;
    }
    if (dep[l]<dep[stk[top]]){
        E[l].push_back((Edge){stk[top],dis(stk[top],l)});
        E[stk[top]].push_back((Edge){l,dis(stk[top],l)});
        top--;
    }
    if (stk[top]!=l) stk[++top]=l;
    stk[++top]=x;
}
void dfs(int u,int fa){
    Fir[u]=Sec[u]=inf,fir[u]=sec[u]=siz[u]=0;
    if (flag[u]) siz[u]=1,Fir[u]=0;
    for (int i=0;i<E[u].size();i++){
        int v=E[u][i].v;
        if (v==fa) continue;
        dfs(v,u);
        ans0+=1ll*E[u][i].d*siz[v]*(m-siz[v]);
        siz[u]+=siz[v];
        int x=Fir[v]+E[u][i].d;
        if (x<=Fir[u])
            Sec[u]=Fir[u],Fir[u]=x;
        else
            if (x<Sec[u]) Sec[u]=x;
        x=fir[v]+E[u][i].d;
        if (x>=fir[u])
            sec[u]=fir[u],fir[u]=x;
        else
            if (x>sec[u]) sec[u]=x;
    }
    ans1=min(ans1,Fir[u]+Sec[u]);
    ans2=max(ans2,fir[u]+sec[u]);
    E[u].clear();
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int main(){
    scanf("%d",&n);
    for (int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        add(u,v),add(v,u);
    }
    dfs0(1,0);
    scanf("%d",&Q);
    while (Q--){
        ans0=ans2=0,ans1=inf;
        scanf("%d",&m);
        if (m==1){
            puts("0 0 0");
            continue;
        }
        top=0;
        for (int i=1;i<=m;i++)
            scanf("%d",&a[i]),flag[a[i]]=1;
        rt=a[1];
        sort(a+1,a+1+m,cmp);
        for (int i=1;i<=m;i++) ins(a[i]);
        while (top>1){
            E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])});
            E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])});
            top--;
        }
        dfs(stk[top],0);
        for (int i=1;i<=m;i++) flag[a[i]]=0;
        printf("%lld %d %d\n",ans0,ans1,ans2);
    }
    return 0;
}
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!