题目
首先我们有这样一个暴力:
对于每个点,我们把覆盖到了它的链的两段拿出来作为关键点,那么这个点能够到达的点就是这些关键点作为极远点的生成树上的点。
我们把所有的关键点按dfs序排序,从小到大枚举,那么对于一个关键点\(d_i\),它会造成\(dep_{d_i}-dep_{lca(d_i,d_{i-1})}\)的贡献。(这里建议画图理解一下)
为了方便,我们强制把\(1\)作为关键点,那么最后减去\(dep_{lca(d_1,d_m)}\)(假设总共有\(m\)个关键点)即所有关键点的\(lca\)的深度即可。
我们考虑用线段树维护这个过程。
以dfs序为下标,线段树的节点记录当前区间点集的生成树大小,dfs序最小的和最大的点。我们在pushup的时候完成减去\(dep_{lca(u,v)}\)(\(u\)为线段树左孩子中dfs序最大的点,\(v\)为线段树右孩子中dfs序最小的点)的操作。查询的时候用根节点的答案减去所有点的\(lca\)的深度即可。
我们再考虑优化这个暴力。
对于一条路径\(s-t\),我们会把所有\(s-t\)上的节点都选\(s,t\)两点,这个可以用树上查分解决。
然后我们要实现把儿子的信息传给父亲,线段树合并解决。
因为总插入的信息是\(n\)级别的,所以线段树合并复杂度为\(O(n\log n)\)。
因为我们要求\(n\log n\)次\(lca\),所以用dfs序+ST表求即可做到\(O(n\log n)\)。
总的复杂度还是\(O(n\log n)\)。不过常数很大。
#include<bits/stdc++.h> #define lc ls[p] #define rc rs[p] #define ll long long #define pb push_back #define mid ((l+r)>>1) using namespace std; namespace IO { char ibuf[(1<<21)+1],*iS,*iT; char Get(){return (iS==iT? (iT=(iS=ibuf)+fread(ibuf,1,(1<<21)+1,stdin),(iS==iT? EOF:*iS++)):*iS++);} int read(){int x=0,c=Get();while(!isdigit(c))c=Get();while(isdigit(c))x=x*10+c-48,c=Get();return x;} } using namespace IO; void swap(int &a,int &b){a^=b^=a^=b;} const int N=100007,M=N<<6; int n,m,T,fa[N],dep[N],dfn[N],Log[N<<1],st[20][N<<1],cnt,root[N<<1],sum[M],ls[M],rs[M],f[M],s[M],t[M]; ll ans;vector<int>E[N],del[N]; void add(int u,int v){E[u].pb(v),E[v].pb(u);} void dfs(int u) { dep[u]=dep[fa[u]]+1,st[0][dfn[u]=++T]=u; for(int v:E[u]) if(v^fa[u]) fa[v]=u,dfs(v),st[0][++T]=u; } void init() { for(int i=2;i<=T;++i) Log[i]=Log[i>>1]+1; for(int i=1,j,u,v;i<=Log[T];++i) for(j=1;j+(1<<i)-1<=T;++j) st[i][j]=dep[u=st[i-1][j]]<dep[v=st[i-1][j+(1<<(i-1))]]? u:v; } int lca(int u,int v) { if((u=dfn[u])>(v=dfn[v])) swap(u,v); int d=Log[v-u+1]; return dep[u=st[d][u]]<dep[v=st[d][v-(1<<d)+1]]? u:v; } void pushup(int p) { f[p]=f[lc]+f[rc]-dep[lca(t[lc],s[rc])]; s[p]=s[lc]? s[lc]:s[rc]; t[p]=t[rc]? t[rc]:t[lc]; } int query(int p){return f[p]-dep[lca(s[p],t[p])];} void update(int &p,int l,int r,int x,int v) { if(!p) p=++cnt; if(l==r) return (void)(sum[p]+=v,(f[p]=sum[p]? dep[x]:0),(s[p]=t[p]=sum[p]? x:0)); (dfn[x]<=mid? update(lc,l,mid,x,v):update(rc,mid+1,r,x,v)),pushup(p); } void merge(int &u,int v,int l,int r) { if(!u||!v) return (void)(u|=v); if(l==r) return (void)(sum[u]+=sum[v],f[u]|=f[v],s[u]|=s[v],t[u]|=t[v]); merge(ls[u],ls[v],l,mid),merge(rs[u],rs[v],mid+1,r),pushup(u); } void in() { int u,v,l=lca(u=read(),v=read()); update(root[u],1,T,u,1),update(root[u],1,T,v,1); update(root[v],1,T,u,1),update(root[v],1,T,v,1); del[l].pb(u),del[l].pb(v),del[fa[l]].pb(u),del[fa[l]].pb(v); } void solve(int u) { for(int v:E[u]) if(v^fa[u]) solve(v); for(int x:del[u]) update(root[u],1,T,x,-1); ans+=query(root[u]),merge(root[fa[u]],root[u],1,T); } int main() { n=read(),m=read(); for(int i=1,u,v;i<n;++i) u=read(),v=read(),add(u,v); dfs(1),init(); while(m--) in(); solve(1); return !printf("%lld",ans>>1); }