题目链接
(https://www.luogu.org/problem/P1600)
sol:
这是一个比较好的题,值得一做。首先这个不是一个普通的链覆盖问题,因为在某一条路径上每走一步,权值(即时间)就会变化\(1\)。那么现在就考虑要设定一个标准权,使得每一条路径拥有的那个权值固定,每一个点的权也固定。显然对于向下的路径,这个标准权就是深度;向上的即把深度倒过来。把时间和它作差即可。这样问题就转化为对于一个点\(u\),设它的做差后的值为\(val_u\),现在要统计所有路径权值为\(val_u\),且经过\(u\)的路径条数。考虑处理出\(dfs\)序,用树状数组维护做差分即可。
这个做法挺难写的。首先拆路径时一定要仔细讨论各种情况,其次处理\(lca\)向下的第一个节点也极容易出错。
#include <bits/stdc++.h> #define mp make_pair #define pb push_back #define X first #define Y second using namespace std; inline int read(){ int x=0;char c=getchar(); while(c<'0'||c>'9') c=getchar(); while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-'0',c=getchar(); return x; } const int N=300005; int n,m,p[N],dfn[N],cnt,d1[N],d2[N],dep; int ver[N<<1],nxt[N<<1],head[N],tot; vector<pair<int,int> > a1[N<<1],a2[N<<1]; vector<int> tmp1[N<<1],tmp2[N<<1]; int to[N][25],st[N],out[N],ans[N]; inline void add(int x,int y){ver[++tot]=y;nxt[tot]=head[x];head[x]=tot;} void dfs(int x,int la){ dfn[x]=++cnt;d2[x]=d2[la]+1;dep=max(dep,d2[x]);to[x][0]=la; for(int i=1;(1<<i)<d2[x];i++) to[x][i]=to[to[x][i-1]][i-1]; for(int i=head[x];i;i=nxt[i]){ int y=ver[i]; if(y==la) continue; dfs(y,x); } out[x]=cnt; } inline int getlca(int x,int y){ bool im=0; if(d2[x]<d2[y]) swap(x,y),im=1; //交换时不要丢掉路径的方向 for(int w=23;w>=0;w--){ if(d2[x]-(1<<w)<=d2[y]) continue; x=to[x][w]; } if(to[x][0]==y) return x; if(d2[x]>d2[y]) x=to[x][0]; //特判一开始x与y深度相等的情况 for(int w=19;w>=0;w--){ if(d2[x]<=(1<<w)||to[x][w]==to[y][w]) continue; x=to[x][w],y=to[y][w]; } return (im?x:y); } inline int lowbit(int x){return x&(-x);} inline void inst(int x,int y){while(x<=n) st[x]+=y,x+=lowbit(x);} inline int query(int x){int sum=0;while(x>0) sum+=st[x],x-=lowbit(x);return sum;} void work(vector<pair<int,int> > *a,vector<int> *tmp){ for(int i=1;i<(N<<1);i++){ int len=tmp[i].size(); if(len==0) continue; int num=a[i].size(); for(int j=0;j<num;j++){ int x=a[i][j].X,y=a[i][j].Y; inst(dfn[x],1); if(dfn[to[y][0]]!=0) inst(dfn[to[y][0]],-1); //小心等于0 } for(int j=0;j<len;j++){ int x=tmp[i][j]; ans[x]=ans[x]+query(out[x])-query(dfn[x]-1); } for(int j=0;j<num;j++){ int x=a[i][j].X,y=a[i][j].Y; inst(dfn[x],-1); if(dfn[to[y][0]]!=0) inst(dfn[to[y][0]],1); } } } int main(){ n=read();m=read(); for(int i=1;i<n;i++){ int x=read(),y=read(); add(x,y);add(y,x); } for(int i=1;i<=n;i++) p[i]=read(); dfs(1,0); for(int i=1;i<=n;i++){ d1[i]=dep+1-d2[i]; tmp1[p[i]-d1[i]+N].pb(i); tmp2[p[i]-d2[i]+N].pb(i); } for(int i=1;i<=m;i++){ int x=read(),y=read(); if(x==y){a1[-d1[x]+N].pb(mp(x,y));continue;} //这个情况要特判 int lca=getlca(x,y); if(to[lca][0]==x) a2[-d2[x]+N].pb(mp(y,x)); else if(to[lca][0]==y) a1[-d1[x]+N].pb(mp(x,y)); else{ a1[-d1[x]+N].pb(mp(x,to[lca][0])); a2[d2[x]-d2[to[lca][0]]+1-d2[lca]+N].pb(mp(y,lca)); } } work(a1,tmp1);work(a2,tmp2); for(int i=1;i<=n;i++) printf("%d ",ans[i]); return 0; }