题意
给定一颗边权为1的树,点权为v,给定m条从s到t的路径,对于每个点,求\(ans_i=\Sigma_{j=1}^m [dist(s_j,i)==v[i]]\)
思路
有两种可能的做法,一种是把路径全加进去,再每一个点求\(ans\),另一种是一条一条加路径,每次求贡献。而这道题用的是第一种
对于每一条路径,将\(s->t\)的路径拆分为\(s->lca->t\),\(s->lca\)为左路径,\(lca->t\)为右路径,对于左路径上的点\(i\),左路径对它有贡献当且仅当\(dep[s]-dep[i]==v[i]\),将\(i\)项移动到一边,就有\(dep[s]==dep[i]+v[i]\),这样就可以把路径全部加进去再求\(i\)点的贡献。右路径同理有\(dep[s]-2*dep[lca]==v[i]-dep[i]\)
\(dfs\)整颗树,从下向上回溯时求\(ans\),对于回溯时遇到的点\(i\),加入以它为\(s/t\)的左右路径,求一次\(ans\),退出时减去以它为\(lca\)的左右路径即可,然后就发现过不了样例
(假设当前在以\(u\)为根的子树中\(dfs\))因为在\(u\)的某一珂子树\(v\)中的时候,\(u\)的其他子树里面的路径也被记录下来了,这时就会导致一条路径对不在它上面的点做出贡献。
解决方法:因为求的是满足上面两个约束的路径条数,那么在递归进入\(v\)子树加入里面的路径之前,先减去此时满足约束条件的路径条数,这样就保证了做出贡献的都是\(v\)子树里面出发的路径
另外,如果\(i\)是路径的\(lca\)的话可能被计算两次(左右路径各一次),所以要减1,即当满足\(v[lca]==dep[s]-dep[lca]\)的时候减1
Code:
#include<bits/stdc++.h> #define N 900005 #define M 300005 using namespace std; const int temp = 300000; int n,m; int dep[M],w[M],fa[M][18],ans[M]; int lsum[N],rsum[N];//桶 vector<int> L[M],R[M];//lca在上面,出 vector<int> l[M],r[M];//s,t在下面,入 struct Edge { int next,to; }edge[M<<1];int head[M],cnt=1; void add_edge(int from,int to) { edge[++cnt].next=head[from]; edge[cnt].to=to; head[from]=cnt; } template <class T> void read(T &x) { char c;int sign=1; while((c=getchar())>'9'||c<'0') if(c=='-') sign=-1; x=c-48; while((c=getchar())>='0'&&c<='9') x=x*10+c-48; x*=sign; } void dfs(int rt) { dep[rt]=dep[fa[rt][0]]+1; for(int i=head[rt];i;i=edge[i].next) { int v=edge[i].to; if(v==fa[rt][0]) continue; fa[v][0]=rt; for(int i=1;i<18;++i) fa[v][i]=fa[fa[v][i-1]][i-1]; dfs(v); } } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); for(int i=17;i>=0;--i) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i]; if(x==y) return x; for(int i=17;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } void DFS(int rt) { ans[rt]-=lsum[dep[rt]+w[rt]+temp]+rsum[w[rt]-dep[rt]+temp]; for(int i=head[rt];i;i=edge[i].next) { int v=edge[i].to; if(v==fa[rt][0]) continue; DFS(v); } //加入rt for(int i=0;i<(int)l[rt].size();++i) { int val=l[rt][i]+temp; ++lsum[val]; } for(int i=0;i<(int)r[rt].size();++i) { int val=r[rt][i]+temp; ++rsum[val]; } ans[rt]+=lsum[dep[rt]+w[rt]+temp]+rsum[w[rt]-dep[rt]+temp]; //删除以rt为lca的链 for(int i=0;i<(int)L[rt].size();++i) { int val=L[rt][i]+temp; --lsum[val]; } for(int i=0;i<(int)R[rt].size();++i) { int val=R[rt][i]+temp; --rsum[val]; } } int main() { read(n);read(m); for(int i=1;i<n;++i) { int x,y; read(x);read(y); add_edge(x,y); add_edge(y,x); } dfs(1); for(int i=1;i<=n;++i) read(w[i]); for(int i=1;i<=m;++i) { int S,T; read(S);read(T); int lc=lca(S,T); if(w[lc]==dep[S]-dep[lc]) ans[lc]--; L[lc].push_back(dep[S]); R[lc].push_back(dep[S]-2*dep[lc]); l[S].push_back(dep[S]); r[T].push_back(dep[S]-2*dep[lc]); } DFS(1); for(int i=1;i<=n;++i) printf("%d ",ans[i]); return 0; }