【算法】点分治初探

爱⌒轻易说出口 提交于 2019-12-27 07:23:17

参考:http://blog.csdn.net/nixinyis/article/details/65445466

【简介】

  点分治是一类用于处理树上路径点权统计问题的算法,其利用重心的性质降低复杂度。

【什么是重心】

  某个其所有的子树中最大的子树节点数最少的点被称为重心,删去重心后,生成的多棵树尽可能平衡。

【重心的性质】

  ①重心其所有子树的大小都不超过$\frac{n}{2}$。

  ②树中所有点到某个点的距离和中,到树的重心的距离和是最小的,如果有两个重心,那么到它们的距离和相同。

  ③把两棵树通过两个点相连得到一棵新的树,新的树的重心必定在连接两棵树的重心的路径上。

  ④一棵树添加或删除一个节点,树的重心最多会移动一条边的位置。

  点分治的复杂度基于重心的第一个性质。

【点分治】

  点分治是对于每一棵子树,都求出它的重心,并且以重心为根跑一遍这棵子树并统计经过重心的路径,因为我们知道重心所有子树的大小都只有原树的一半,也就是我们这么做最多只会递归$logn$层,若一层的复杂度$O(f(n))$,则总的时间复杂度为$O(f(n)logn)$。

  接下来以bzoj1468: Tree为例题讲一下点分治。

  对于这道题,显然用上方说的方法,对于每一个子树求出dep,排序后两端指针往中间靠拢统计即可。但是可能会统计重复,如下图,如果红色路径加上父亲边的两倍还小于K,那么也会被根节点统计,所以我们还得删去加上这条父亲边之后的贡献。

【点分治的流程】(以bzoj1468为标准,可能部分内容因题目而异)

  点分治由getroot(求重心),calc(计算贡献),getdep(求某个子树以重心为根的所有点的深度),dfs构成。

  首先是getroot部分,求重心就不多提了,只要找到一个点满足最大的子树最小即可。

  getroot的一些注意事项:getroot之前把root清零,把sum改成子树大小,getroot的时候不访问当前点的父亲或者是dfs已经访问过的点,记得搜到每一个点的时候都要先把mx清零。

  getdep部分,calc的时候需要先对这棵子树先getdep,用d数组求出以子树根为根的子树的所有点的深度,并且再用一个dep数组把子树里所有点的深度都塞进去,注意dep只存这棵子树的点的深度,而d是以节点编号为下标的。同样不访问当前点的父亲或者是dfs已经访问过的点。

  calc部分,形参有子树根和init,init表示子树根的初始深度,用于上面所说的去重。将子树根初始深度设好,然后getdep,接着就可以处理这个子树的信息了。

  最后是dfs部分,每次先对当前点calc(x, 0)计算贡献,然后对每一个son,先把答案去掉calc(son, e[i].dis)的贡献,然后对子树getroot求出重心,接着搜子树的重心即可。

代码如下:

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=40010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, K, x, y, z, sum, ans, tot, root, cnt, tott;
int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn];
bool v[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-'&&(f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;   
} 
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
    size[x]=1; mx[x]=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too] && too!=fa)
    {
        getroot(too, x);
        size[x]+=size[too];
        mx[x]=max(mx[x], size[too]);
    }
    mx[x]=max(mx[x], sum-size[x]);
    if(mx[x]<mx[root]) root=x;
}
void getdep(int x, int fa)
{
    dep[++cnt]=d[x];
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too] && too!=fa)
    {
        d[too]=d[x]+e[i].dis;
        getdep(too, x);
    }
}
int calc(int x, int init)
{
    d[x]=init; cnt=0;
    getdep(x, 0);
    sort(dep+1, dep+cnt+1);
    int l=1, r=cnt, sum=0;
    while(l<r)
    if(dep[l]+dep[r]<=K) sum+=r-l, l++;
    else r--;   
    return sum;
}
void dfs(int x)
{
    ans+=calc(x, 0); v[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too])
    {
        ans-=calc(too, e[i].dis);
        sum=size[too]; root=0; 
        getroot(too, 0); 
        dfs(root);
    }
}
int main()
{
    read(n);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), add(x, y, z), add(y, x, z);
    read(K); sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
    printf("%d\n", ans);
}
View Code

  还有一种写法是对每个重心的子树单独搜,对每个子树统计与搜过子树的方案数,这样子统计得到的路径必定经过重心。

【例题】

例1:bzoj2152: 聪聪可可

  这题一眼做法是$O(n)$的,但这里只说点分治的做法。

  和上方的例题差不多,只是改成每次求子树里深度%3的个数而已,最后答案就是$sum[0]*sum[0]+sum[1]*sum[2]$。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#define MOD(x) ((x)>=3?(x)-3:(x))
using namespace std;
const int maxn=500010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, ans;
int last[maxn], mx[maxn], size[maxn], cnt[3], d[maxn];
bool v[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
inline int gcd(int a, int b){return b?gcd(b, a%b):a;}
void getroot(int x, int fa)
{
    mx[x]=0; size[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        getroot(too, x);
        size[x]+=size[too];
        mx[x]=max(mx[x], size[too]);
    }
    mx[x]=max(sum-size[x], mx[x]);
    if(mx[x]<mx[root]) root=x;
}
void getdep(int x, int fa)    
{
    cnt[d[x]]++;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        d[too]=MOD(d[x]+e[i].dis);
        getdep(too, x);
    }
}
int calc(int x, int init)
{
    d[x]=init; 
    cnt[0]=cnt[1]=cnt[2]=0; 
    getdep(x, 0);
    return cnt[0]*cnt[0]+(cnt[1]*cnt[2]<<1);
}
void dfs(int x)
{
    ans+=calc(x, 0); v[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too])
    {
        ans-=calc(too, e[i].dis);
        sum=size[too]; root=0;
        getroot(too, 0);
        dfs(root);
    }
}
int main()
{
    read(n);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), z%=3, add(x, y, z), add(y, x, z);
    sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
    printf("%d/%d\n", ans/gcd(n*n, ans), n*n/gcd(n*n, ans));
}
View Code

例2:bzoj2599: [IOI2011]Race

  这题对于点分治的初学者来说很容易出错(比如我T T)

  因为对于每一个子树我们只能统计经过其重心的路径,而这题是求最小边长,按上面两题的做法做就不可行了。对于每一个子树查询其答案的时候,重心的每一个子树先进行更新答案,再记录其深度的值,这样求得的路径就必定经过重心了,然后再统计重心为端点的答案。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
using namespace std;
const int maxn=1000010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, ans, cnt, K;
int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], ecnt[maxn], mn[maxn];
bool v[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
    mx[x]=0; size[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        getroot(too, x);
        size[x]+=size[too];
        mx[x]=max(mx[x], size[too]);
    }
    mx[x]=max(sum-size[x], mx[x]);
    if(mx[x]<mx[root]) root=x;
}
void getans(int x, int fa)    
{
    dep[++cnt]=d[x]; 
    if(d[x]<=K) ans=min(ans, mn[K-d[x]]+ecnt[x]);
    if(d[x]==K) ans=min(ans, ecnt[x]);
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        d[too]=d[x]+e[i].dis;
        ecnt[too]=ecnt[x]+1;
        getans(too, x);
    }
}
void update(int x, int fa)
{
    if(d[x]<=K) mn[d[x]]=min(mn[d[x]], ecnt[x]);
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too]) update(too, x);
}
void calc(int x)
{
    d[x]=ecnt[x]=cnt=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too]) d[too]=e[i].dis, ecnt[too]=1, getans(too, x), update(too, x);
    for(int i=1;i<=cnt;i++)
    if(K>=dep[i]) mn[dep[i]]=inf;
}
void dfs(int x)
{
    calc(x); v[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too])
    {
        sum=size[too]; root=0;
        getroot(too, 0);
        dfs(root);
    }
}
int main()
{
    read(n); read(K);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), x++, y++, add(x, y, z), add(y, x, z);
    memset(mn, 32, sizeof(mn)); ans=inf;
    sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
    printf("%d\n", ans>n?-1:ans);
    return 0;
}
View Code

例3:bzoj1316: 树上的询问

  这题就是在点分的时候扫一遍所有询问,查询是否有这个len即可,有个坑是len==0的时候答案为Yes。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
using namespace std;
const int maxn=1000010, inf=1e9, maxl=1000010;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, cnt, Q;
int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], len[maxn], ans[maxn];
bool v[maxn], vis[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
    mx[x]=0; size[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        getroot(too, x);
        size[x]+=size[too];
        mx[x]=max(mx[x], size[too]);
    }
    mx[x]=max(sum-size[x], mx[x]);
    if(mx[x]<mx[root]) root=x;
}
void getans(int x, int fa)    
{
    dep[++cnt]=d[x]; 
    for(int i=1;i<=Q;i++)
    if(len[i]>=d[x])
    if(vis[len[i]-d[x]]) ans[i]=1;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    d[too]=d[x]+e[i].dis, getans(too, x);
}
void update(int x, int fa)
{
    if(d[x]<=maxl) vis[d[x]]=1;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too]) update(too, x);
}
void calc(int x)
{
    d[x]=cnt=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too]) d[too]=e[i].dis, getans(too, x), update(too, x);
    for(int i=1;i<=cnt;i++)
    if(maxl>=dep[i]) vis[dep[i]]=0;
}
void dfs(int x)
{
    calc(x); v[x]=1; 
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too])
    {
        sum=size[too]; root=0;
        getroot(too, 0);
        dfs(root);
    }
}
int main()
{
    read(n); read(Q);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), add(x, y, z), add(y, x, z);
    for(int i=1;i<=Q;i++) read(len[i]);
    sum=n; mx[0]=inf; vis[0]=1; getroot(1, 0); dfs(root);
    for(int i=1;i<=Q;i++) printf("%s\n", (ans[i] || !len[i])?"Yes":"No");
    return 0;
}
View Code

例4:bzoj3697: 采药人的路径

  设$f[i][1]$为当前子树长度为i的路径有休息点的方案数,$f[i][0]$为当前子树长度为i路径有休息点的方案数,$g[i][0/1]$表示搜过的子树,其他同理。

  统计子树间的答案就是$g[j][0]*f[-j][1]+g[j][1]*f[-j][0]+g[j][1]*f[-j][1]$,统计重心为端点的答案就是$g[0][0]*f[0][0]+f[0][1]$。

  调了好久,不熟悉啊T T

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#define ll long long
using namespace std;
const int maxn=200010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, mxdep, root;
int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn], cnt[maxn];
ll ans, f[maxn][2], g[maxn][2];
bool v[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
    size[x]=1; mx[x]=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        getroot(too, x);
        size[x]+=size[too];
        mx[x]=max(mx[x], size[too]);
    }
    mx[x]=max(mx[x], sum-size[x]);
    if(mx[x]<mx[root]) root=x;
}
void update(int x, int fa)
{
    if(cnt[d[x]+n]) f[d[x]+n][1]++;
    else f[d[x]+n][0]++;
    cnt[d[x]+n]++; mxdep=max(mxdep, dep[x]);
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa && !v[too])
    {
        d[too]=d[x]+e[i].dis;
        dep[too]=dep[x]+1;
        update(too, x);
    }
    cnt[d[x]+n]--;
}
void calc(int x)
{
    int mxd=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too]) 
    {
        d[too]=e[i].dis; dep[too]=1; mxdep=1; update(too, 0);
        mxd=max(mxd, mxdep); ans+=g[n][0]*f[n][0]+f[n][1];
        for(int j=-mxdep;j<=mxdep;j++)
        ans+=g[j+n][0]*f[n-j][1]+g[j+n][1]*f[n-j][0]+g[j+n][1]*f[n-j][1];
        for(int j=-mxdep;j<=mxdep;j++) 
        g[j+n][1]+=f[j+n][1], g[j+n][0]+=f[j+n][0], f[j+n][1]=f[j+n][0]=0;
    }
    for(int i=n-mxd;i<=n+mxd;i++) g[i][0]=g[i][1]=0;
}
void dfs(int x)
{
    v[x]=1; calc(x); 
    for(int i=last[x], too;i;i=e[i].pre)
    if(!v[too=e[i].too])
    {
        root=0; sum=size[too];
        getroot(too, 0);
        dfs(root);
    }
}
int main()
{
    read(n);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), add(x, y, z?1:-1), add(y, x, z?1:-1);
    mx[0]=inf; sum=n; getroot(1, 0); dfs(root);
    printf("%lld\n", ans);
    return 0;
}
View Code
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!