$solution:$
考虑二元组 $(S,T)$ 对 $u$ 点的贡献。
若 $S$ 在 $u$ 子树上 ( $T$ 不在),且满足 $dep_u+w_u=dep_S$ 就可以对 $u$ 作贡献。
若 $T$ 在 $u$ 子树上 ( $S$ 不在) ,且满足 $w_u-dep_u=dep_S-2\times dep_{lca}$ 就可以对 $u$ 作贡献。
所以只要将 $(S,T)$ 拆成 $(S,lca),(lca,T)$ 即可。
对于计算直接线段树合并与简单差分即可。
时间复杂度 $O(m\log m)$
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> using namespace std; inline int read(){ int f=1,ans=0;char c=getchar(); while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();} return f*ans; } const int MAXN=300001; struct node{ int u,v,nex; }x[MAXN<<1]; int n,m,fa[MAXN][21],dep[MAXN],head[MAXN],son[MAXN],cnt,w[MAXN]; void add(int u,int v){ x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++; } void dfs(int u,int fath){ fa[u][0]=fath; dep[u]=dep[fath]+1; for(int i=1;(1<<i)<=dep[u];i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=head[u];i!=-1;i=x[i].nex){ if(x[i].v==fath) continue; dfs(x[i].v,u); }return; } int lca(int u,int v){ if(dep[u]<dep[v]) swap(u,v); for(int i=20;i>=0;i--) if(dep[u]-(1<<i)>=dep[v]) u=fa[u][i]; if(u==v) return u; for(int i=20;i>=0;i--){ if(fa[u][i]==fa[v][i]) continue; u=fa[u][i],v=fa[v][i]; }return fa[u][0]; } int Ans[MAXN*30],rt[MAXN],tot,ls[MAXN*30],rs[MAXN*30]; struct Segment{ void clear(){memset(Ans,0,sizeof(Ans)),memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(rt,0,sizeof(rt));tot=0;} void update(int &tr,int l,int r,int px,int w){ if(!tr) tr=++tot; Ans[tr]+=w; if(l==r) return; int mid=l+r>>1; if(px<=mid) update(ls[tr],l,mid,px,w); if(mid<px) update(rs[tr],mid+1,r,px,w); return; } int merge(int p,int q,int l,int r){ if(!p||!q) return p+q; if(l==r){ Ans[p]+=Ans[q]; return p; } int mid=l+r>>1; ls[p]=merge(ls[p],ls[q],l,mid); rs[p]=merge(rs[p],rs[q],mid+1,r); Ans[p]+=Ans[q];return p; } void add(int ps,int w,int opt){ if(!ps) return; update(rt[ps],1,5*n,3*n+w,opt);return; } void Merge(int p,int q){ rt[p]=merge(rt[p],rt[q],1,5*n);return; } int query(int k,int l,int r,int px){ if(!k) return 0; if(l==r) return Ans[k]; int mid=l+r>>1; if(px<=mid) return query(ls[k],l,mid,px); if(mid<px) return query(rs[k],mid+1,r,px); } int Query(int ps,int w){ return query(rt[ps],1,5*n,3*n+w); } }Segment; int tot1,tot2; struct Up{ int S,T,Lca; }G1[MAXN]; struct Down{ int S,T,Lca,s; }G2[MAXN]; struct spe{ int S,T; }G[MAXN]; int Ans1[MAXN],Ans2[MAXN],Ans3[MAXN]; void dfs1(int u,int fath){ for(int i=head[u];i!=-1;i=x[i].nex){ if(x[i].v==fath) continue; dfs1(x[i].v,u); Segment.Merge(u,x[i].v); }Ans1[u]+=Segment.Query(u,dep[u]+w[u]); return; } void dfs2(int u,int fath){ for(int i=head[u];i!=-1;i=x[i].nex){ if(x[i].v==fath) continue; dfs2(x[i].v,u); Segment.Merge(u,x[i].v); }Ans2[u]+=Segment.Query(u,w[u]-dep[u]); } int Qdis(int u,int v){return dep[u]+dep[v]-2*dep[lca(u,v)];} int main(){ // freopen("make.in","r",stdin); memset(head,-1,sizeof(head)); n=read(),m=read(); for(int i=1;i<n;i++){ int u=read(),v=read(); add(u,v),add(v,u); } for(int i=1;i<=n;i++) w[i]=read(); dfs(1,0); for(int i=1;i<=m;i++){ int S=read(),T=read(); G[i].S=S,G[i].T=T; int LCA=lca(S,T); G1[++tot1].S=S,G1[tot1].T=LCA,G1[tot1].Lca=LCA; G2[++tot2].S=LCA,G2[tot2].T=T,G2[tot2].Lca=LCA,G2[tot2].s=S; if(Qdis(S,LCA)==w[LCA]) Ans1[LCA]--; } for(int i=1;i<=tot1;i++){ Segment.add(G1[i].S,dep[G1[i].S],1); Segment.add(fa[G1[i].T][0],dep[G1[i].S],-1); } dfs1(1,0); Segment.clear(); for(int i=1;i<=tot2;i++){ Segment.add(G2[i].T,dep[G2[i].s]-2*dep[G2[i].Lca],1); Segment.add(fa[G2[i].S][0],dep[G2[i].s]-2*dep[G2[i].Lca],-1); } dfs2(1,0); for(int i=1;i<=n;i++) printf("%d ",Ans1[i]+Ans2[i]);printf("\n"); return 0; }/* 3 1 2 1 3 1 1 1 3 2 3 1 3 */