参考:http://blog.csdn.net/nixinyis/article/details/65445466
【简介】
点分治是一类用于处理树上路径点权统计问题的算法,其利用重心的性质降低复杂度。
【什么是重心】
某个其所有的子树中最大的子树节点数最少的点被称为重心,删去重心后,生成的多棵树尽可能平衡。
【重心的性质】
①重心其所有子树的大小都不超过$\frac{n}{2}$。
②树中所有点到某个点的距离和中,到树的重心的距离和是最小的,如果有两个重心,那么到它们的距离和相同。
③把两棵树通过两个点相连得到一棵新的树,新的树的重心必定在连接两棵树的重心的路径上。
④一棵树添加或删除一个节点,树的重心最多会移动一条边的位置。
点分治的复杂度基于重心的第一个性质。
【点分治】
点分治是对于每一棵子树,都求出它的重心,并且以重心为根跑一遍这棵子树并统计经过重心的路径,因为我们知道重心所有子树的大小都只有原树的一半,也就是我们这么做最多只会递归$logn$层,若一层的复杂度$O(f(n))$,则总的时间复杂度为$O(f(n)logn)$。
接下来以bzoj1468: Tree为例题讲一下点分治。
对于这道题,显然用上方说的方法,对于每一个子树求出dep,排序后两端指针往中间靠拢统计即可。但是可能会统计重复,如下图,如果红色路径加上父亲边的两倍还小于K,那么也会被根节点统计,所以我们还得删去加上这条父亲边之后的贡献。
【点分治的流程】(以bzoj1468为标准,可能部分内容因题目而异)
点分治由getroot(求重心),calc(计算贡献),getdep(求某个子树以重心为根的所有点的深度),dfs构成。
首先是getroot部分,求重心就不多提了,只要找到一个点满足最大的子树最小即可。
getroot的一些注意事项:getroot之前把root清零,把sum改成子树大小,getroot的时候不访问当前点的父亲或者是dfs已经访问过的点,记得搜到每一个点的时候都要先把mx清零。
getdep部分,calc的时候需要先对这棵子树先getdep,用d数组求出以子树根为根的子树的所有点的深度,并且再用一个dep数组把子树里所有点的深度都塞进去,注意dep只存这棵子树的点的深度,而d是以节点编号为下标的。同样不访问当前点的父亲或者是dfs已经访问过的点。
calc部分,形参有子树根和init,init表示子树根的初始深度,用于上面所说的去重。将子树根初始深度设好,然后getdep,接着就可以处理这个子树的信息了。
最后是dfs部分,每次先对当前点calc(x, 0)计算贡献,然后对每一个son,先把答案去掉calc(son, e[i].dis)的贡献,然后对子树getroot求出重心,接着搜子树的重心即可。
代码如下:
#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<algorithm> #define ll long long using namespace std; const int maxn=40010, inf=1e9; struct poi{int too, dis, pre;}e[maxn<<1]; int n, K, x, y, z, sum, ans, tot, root, cnt, tott; int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn]; bool v[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-'&&(f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;} void getroot(int x, int fa) { size[x]=1; mx[x]=0; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too] && too!=fa) { getroot(too, x); size[x]+=size[too]; mx[x]=max(mx[x], size[too]); } mx[x]=max(mx[x], sum-size[x]); if(mx[x]<mx[root]) root=x; } void getdep(int x, int fa) { dep[++cnt]=d[x]; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too] && too!=fa) { d[too]=d[x]+e[i].dis; getdep(too, x); } } int calc(int x, int init) { d[x]=init; cnt=0; getdep(x, 0); sort(dep+1, dep+cnt+1); int l=1, r=cnt, sum=0; while(l<r) if(dep[l]+dep[r]<=K) sum+=r-l, l++; else r--; return sum; } void dfs(int x) { ans+=calc(x, 0); v[x]=1; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { ans-=calc(too, e[i].dis); sum=size[too]; root=0; getroot(too, 0); dfs(root); } } int main() { read(n); for(int i=1;i<n;i++) read(x), read(y), read(z), add(x, y, z), add(y, x, z); read(K); sum=n; mx[0]=inf; getroot(1, 0); dfs(root); printf("%d\n", ans); }
还有一种写法是对每个重心的子树单独搜,对每个子树统计与搜过子树的方案数,这样子统计得到的路径必定经过重心。
【例题】
例1:bzoj2152: 聪聪可可
这题一眼做法是$O(n)$的,但这里只说点分治的做法。
和上方的例题差不多,只是改成每次求子树里深度%3的个数而已,最后答案就是$sum[0]*sum[0]+sum[1]*sum[2]$。
#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #define MOD(x) ((x)>=3?(x)-3:(x)) using namespace std; const int maxn=500010, inf=1e9; struct poi{int too, dis, pre;}e[maxn<<1]; int n, x, y, z, tot, sum, root, ans; int last[maxn], mx[maxn], size[maxn], cnt[3], d[maxn]; bool v[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;} inline int gcd(int a, int b){return b?gcd(b, a%b):a;} void getroot(int x, int fa) { mx[x]=0; size[x]=1; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { getroot(too, x); size[x]+=size[too]; mx[x]=max(mx[x], size[too]); } mx[x]=max(sum-size[x], mx[x]); if(mx[x]<mx[root]) root=x; } void getdep(int x, int fa) { cnt[d[x]]++; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { d[too]=MOD(d[x]+e[i].dis); getdep(too, x); } } int calc(int x, int init) { d[x]=init; cnt[0]=cnt[1]=cnt[2]=0; getdep(x, 0); return cnt[0]*cnt[0]+(cnt[1]*cnt[2]<<1); } void dfs(int x) { ans+=calc(x, 0); v[x]=1; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { ans-=calc(too, e[i].dis); sum=size[too]; root=0; getroot(too, 0); dfs(root); } } int main() { read(n); for(int i=1;i<n;i++) read(x), read(y), read(z), z%=3, add(x, y, z), add(y, x, z); sum=n; mx[0]=inf; getroot(1, 0); dfs(root); printf("%d/%d\n", ans/gcd(n*n, ans), n*n/gcd(n*n, ans)); }
例2:bzoj2599: [IOI2011]Race
这题对于点分治的初学者来说很容易出错(比如我T T)
因为对于每一个子树我们只能统计经过其重心的路径,而这题是求最小边长,按上面两题的做法做就不可行了。对于每一个子树查询其答案的时候,重心的每一个子树先进行更新答案,再记录其深度的值,这样求得的路径就必定经过重心了,然后再统计重心为端点的答案。
#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> using namespace std; const int maxn=1000010, inf=1e9; struct poi{int too, dis, pre;}e[maxn<<1]; int n, x, y, z, tot, sum, root, ans, cnt, K; int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], ecnt[maxn], mn[maxn]; bool v[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;} void getroot(int x, int fa) { mx[x]=0; size[x]=1; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { getroot(too, x); size[x]+=size[too]; mx[x]=max(mx[x], size[too]); } mx[x]=max(sum-size[x], mx[x]); if(mx[x]<mx[root]) root=x; } void getans(int x, int fa) { dep[++cnt]=d[x]; if(d[x]<=K) ans=min(ans, mn[K-d[x]]+ecnt[x]); if(d[x]==K) ans=min(ans, ecnt[x]); for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { d[too]=d[x]+e[i].dis; ecnt[too]=ecnt[x]+1; getans(too, x); } } void update(int x, int fa) { if(d[x]<=K) mn[d[x]]=min(mn[d[x]], ecnt[x]); for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) update(too, x); } void calc(int x) { d[x]=ecnt[x]=cnt=0; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) d[too]=e[i].dis, ecnt[too]=1, getans(too, x), update(too, x); for(int i=1;i<=cnt;i++) if(K>=dep[i]) mn[dep[i]]=inf; } void dfs(int x) { calc(x); v[x]=1; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { sum=size[too]; root=0; getroot(too, 0); dfs(root); } } int main() { read(n); read(K); for(int i=1;i<n;i++) read(x), read(y), read(z), x++, y++, add(x, y, z), add(y, x, z); memset(mn, 32, sizeof(mn)); ans=inf; sum=n; mx[0]=inf; getroot(1, 0); dfs(root); printf("%d\n", ans>n?-1:ans); return 0; }
例3:bzoj1316: 树上的询问
这题就是在点分的时候扫一遍所有询问,查询是否有这个len即可,有个坑是len==0的时候答案为Yes。
#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> using namespace std; const int maxn=1000010, inf=1e9, maxl=1000010; struct poi{int too, dis, pre;}e[maxn<<1]; int n, x, y, z, tot, sum, root, cnt, Q; int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], len[maxn], ans[maxn]; bool v[maxn], vis[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;} void getroot(int x, int fa) { mx[x]=0; size[x]=1; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { getroot(too, x); size[x]+=size[too]; mx[x]=max(mx[x], size[too]); } mx[x]=max(sum-size[x], mx[x]); if(mx[x]<mx[root]) root=x; } void getans(int x, int fa) { dep[++cnt]=d[x]; for(int i=1;i<=Q;i++) if(len[i]>=d[x]) if(vis[len[i]-d[x]]) ans[i]=1; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) d[too]=d[x]+e[i].dis, getans(too, x); } void update(int x, int fa) { if(d[x]<=maxl) vis[d[x]]=1; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) update(too, x); } void calc(int x) { d[x]=cnt=0; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) d[too]=e[i].dis, getans(too, x), update(too, x); for(int i=1;i<=cnt;i++) if(maxl>=dep[i]) vis[dep[i]]=0; } void dfs(int x) { calc(x); v[x]=1; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { sum=size[too]; root=0; getroot(too, 0); dfs(root); } } int main() { read(n); read(Q); for(int i=1;i<n;i++) read(x), read(y), read(z), add(x, y, z), add(y, x, z); for(int i=1;i<=Q;i++) read(len[i]); sum=n; mx[0]=inf; vis[0]=1; getroot(1, 0); dfs(root); for(int i=1;i<=Q;i++) printf("%s\n", (ans[i] || !len[i])?"Yes":"No"); return 0; }
例4:bzoj3697: 采药人的路径
设$f[i][1]$为当前子树长度为i的路径有休息点的方案数,$f[i][0]$为当前子树长度为i路径有休息点的方案数,$g[i][0/1]$表示搜过的子树,其他同理。
统计子树间的答案就是$g[j][0]*f[-j][1]+g[j][1]*f[-j][0]+g[j][1]*f[-j][1]$,统计重心为端点的答案就是$g[0][0]*f[0][0]+f[0][1]$。
调了好久,不熟悉啊T T
#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #define ll long long using namespace std; const int maxn=200010, inf=1e9; struct poi{int too, dis, pre;}e[maxn<<1]; int n, x, y, z, tot, sum, mxdep, root; int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn], cnt[maxn]; ll ans, f[maxn][2], g[maxn][2]; bool v[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;} void getroot(int x, int fa) { size[x]=1; mx[x]=0; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { getroot(too, x); size[x]+=size[too]; mx[x]=max(mx[x], size[too]); } mx[x]=max(mx[x], sum-size[x]); if(mx[x]<mx[root]) root=x; } void update(int x, int fa) { if(cnt[d[x]+n]) f[d[x]+n][1]++; else f[d[x]+n][0]++; cnt[d[x]+n]++; mxdep=max(mxdep, dep[x]); for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa && !v[too]) { d[too]=d[x]+e[i].dis; dep[too]=dep[x]+1; update(too, x); } cnt[d[x]+n]--; } void calc(int x) { int mxd=0; for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { d[too]=e[i].dis; dep[too]=1; mxdep=1; update(too, 0); mxd=max(mxd, mxdep); ans+=g[n][0]*f[n][0]+f[n][1]; for(int j=-mxdep;j<=mxdep;j++) ans+=g[j+n][0]*f[n-j][1]+g[j+n][1]*f[n-j][0]+g[j+n][1]*f[n-j][1]; for(int j=-mxdep;j<=mxdep;j++) g[j+n][1]+=f[j+n][1], g[j+n][0]+=f[j+n][0], f[j+n][1]=f[j+n][0]=0; } for(int i=n-mxd;i<=n+mxd;i++) g[i][0]=g[i][1]=0; } void dfs(int x) { v[x]=1; calc(x); for(int i=last[x], too;i;i=e[i].pre) if(!v[too=e[i].too]) { root=0; sum=size[too]; getroot(too, 0); dfs(root); } } int main() { read(n); for(int i=1;i<n;i++) read(x), read(y), read(z), add(x, y, z?1:-1), add(y, x, z?1:-1); mx[0]=inf; sum=n; getroot(1, 0); dfs(root); printf("%lld\n", ans); return 0; }
来源:https://www.cnblogs.com/Sakits/p/8328707.html