没错...我就是要讲点分治。这个东西原本学过的,当时学得不好...今天模拟赛又考这个东西结果写不出来。
于是博主专门又去学了学这个东西,这次绝对要搞懂了...【复赛倒计时:11天】
——正片开始——
点分是一个用来解决树上路径问题、距离问题的算法。说直接点其实就是分治思想在树上的体现和应用。
首先是分治的方法,既然是树上问题,自然就是对树进行分治。
我们对于一棵树,选定一个点作为它的根,将这个点与其子树拆开,然后对它的子树也进行相同的操作,然后利用子树统计答案。一般来说,点分更像一种思想,而不是一个算法,当然你也可以把它当算法来学。
关于怎么选根来分树,其他dalao的博客已经讲得非常清楚仔细了,每次选定一棵树的重心,然后分它,这样可以做到
O(nlogn)的优秀时间复杂度。
关于求重心,做法就是一个size统计。
这里还是介绍一下吧(很多博客都只贴一个代码...):
对于一个节点x,若我们把其删除,原来的树就会变成若干部分,我们设maxsubtree(x)表示删除x后剩下的最大子树的大小,若我们找到一个x,使得maxsubtree(x)最小,x就是这棵树的重心。
给出求重心的代码:
void getroot(int x,int fa){ siz[x]=1;son[x]=0; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==fa||vis[y])continue; getroot(y,x); siz[x]+=siz[y]; if(siz[y]>son[x])son[x]=siz[y]; } if(Siz-siz[x]>son[x])son[x]=Siz-siz[x]; if(son[x]<maxx)maxx=son[x],root=x; }
于是我们就知道怎么拆树了,后面的东西就不难了。
update: 19.11.5 对于上面的代码加一些解释,siz表示当前在求重心的这一棵树的大小,root用来记录重心
我们来讲如何统计信息
首先我们要统计与每一次找到的重心有关的路径,我们用dis[x]表示目标点(重心)到x的距离。
给出getdis的代码:
void getdis(int x,int fa,int d){//d表示x到目标点的距离 dis[++top]=d;//我们只需要知道每一个点到目标点的距离就行,不用知道这个点是哪个 for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==fa||vis[y])continue; getdis(y,x,d+val(i)); } }
有了dis数组后,我们就可以很轻松的获得路径的长度了。
比如,我们已知x到重心的距离是m,y到重心的距离是n,那x到y的距离就是m+n,可能细心的读者已经发现锅了。
如果x到y的路径都在一棵子树内,我们就会有一段距离被重复计算了,这样我们得到的路径就是不对的。
给一张图理解一下:
如图,蓝色的代表x到重心的路径,红色的代表y到重心的路径,我们可以得到dis[x]=3,dis[y]=2。
如果按照前面说的计算方式,x到y的路径长度应该是5了,但是并不是,我们的路径长度是3。
原因就是,绿色的那一段,我们根本不走,我们不经过它。
于是我们要解决这种不合法路径。知道所有路径,又知道不合法路径,利用容斥原理,我们可以得到:
合法路径=所有路径 - 不合法路径。
这样我们divide的代码就出来了:
void divide(int x){ vis[x]=1;//保证我们每次divide的x都是当前这棵树的重心,所以标记x已经divide过了 solve(x,0,1);//计算这棵树以x为重心的所有路径答案 int totsiz=Siz;//这棵树的大小 for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(vis[y])continue; solve(y,val(i),0);//求不合法路径 maxx=inf,root=0;//初始化 Siz=siz[y]>siz[x]?totsiz-siz[x]:siz[y];//更新siz getroot(y,0);//求出以y为根的子树 divide(root);//以y为根的子树分治下去 } }
我们看一看去除不合法路径的代码:
solve(y,val(i),-1);
思考发现:
所有不合法路径都是在同一棵子树中的路径,我们要减去它。
先往下看完solve再回来看这里。
我们进入到solve中,首先是getdis,以x为目标点获取dis。但是我们要获取的是距离为val(i)的dis。
这就使得dis[y]=val(i),所以以y为根的子树中的所有dis值就等于dis[y]+它们到y的距离,然后因为dis[y]是x到y的距离,
所以我们就求出了以y为根的子树中所有点到x的距离。
实际一点的理解就是,y到目标点的距离是dis[y]=val(i),y的子树中的点到x的距离就是它们到y的距离+dis[y],所以跑一遍可以求出y的子树中所有点到x的距离。
然后就是solve函数了:
我们这里的solve以模板题:【模板】点分治1 为例,不同题目我们solve的东西不同。
首先肯定可以想到一个O(n^2)的做法的,确实可以水过去这一题。
但是我毕竟是在写博客...所以,O(nlogn)做法奉上...
我们这么想,我们需要统计路径长为k的点对个数,那我们只要确定了一个点x,另一个点的dis[y]就应该是k-dis[x],我们只需要二分找k-dis[x]就行了。
void solve(int x,int d,int avl){//avl(available)表示这次计算得到的答案是否合法 top=0;//清空dis数组 getdis(x,0,d);//获取到当前这棵树到x的距离为d的所有dis int cnt=0; sort(dis+1,dis+top+1);//排好序准备二分 dis[0]=-1;//第一个dis设置为奇怪的数方便下面比较 for(int i=1;i<=top;i++){//把所有距离相同的点放进一个桶里面方便操作 if(dis[i]==dis[i-1]) bucket[cnt].amount++;//原来桶的个数+1 else bucket[++cnt].dis=dis[i],bucket[cnt].amount=1;//新开一个桶 } for(int i=1;i<=m;i++){ if(query[i]%2==0)//如果k是偶数的话,我们单独考虑一下距离为k/2那些点,它们可以互相配对形成长为k的路径 for(int j=1;j<=cnt;j++) if(bucket[j].dis==query[i]/2)//如果距离是k/2 ans[i]+=(bucket[j].amount-1)*bucket[j].amount/2*avl; //组合计数,假设我们有x个距离为k/2的点,就有(x-1)*x/2个点对距离为k,也就是我们可以配出这么多个不同点对 //其实就是C(x,2)->x!/((x-2)!*2)->(x-1)*x/2 for(int j=1;j<=cnt&&bucket[j].dis<query[i]/2;j++){ //接着枚举<k/2的距离,然后我们二分找>2的距离配对,避免重复(点对(u,v)和(v,u)是等价的),等于k/2的我们前面算过了,所以所有情况都考虑到了 int l=j+1,r=cnt; while(l<=r){ int mid=(l++r)>>1; if(bucket[j].dis+bucket[mid].dis==query[i]){ ans[i]+=bucket[j].amount*bucket[mid].amount*avl; //组合计数记录答案,假设我们有x个距离为m的点,y个距离为k-m的点,我们就有x*y个不同的点对(分类相乘) break;//这一轮二分完了,下一轮 } if(bucket[j].dis+bucket[mid].dis>query[i])r=mid-1;//大了,往小的二分 else l=mid+1;//小了,往大的二分 } } } }
这么详细都看不懂我就教不了了...
接下来就给出所有代码吧...(我知道你们只想看这个/doge)
#include<bits/stdc++.h> #define N 100010 #define lint long long #define inf 0x7fffffff using namespace std; int vis[N],son[N],Siz,maxx,siz[N]; int root,head[N],tot,n,m,dis[N],top,query[N],ans[N]; lint k; struct Bucket{ int dis,amount; }bucket[N]; struct Edge{ int nxt,to,val; #define nxt(x) e[x].nxt #define to(x) e[x].to #define val(x) e[x].val }e[N<<1]; inline int read(){ int data=0,w=1;char ch=0; while(ch!='-' && (ch<'0'||ch>'9'))ch=getchar(); if(ch=='-')w=-1,ch=getchar(); while(ch>='0' && ch<='9')data=data*10+ch-'0',ch=getchar(); return data*w; } inline void addedge(int f,int t,int val){ nxt(++tot)=head[f];to(tot)=t;val(tot)=val;head[f]=tot; } void getroot(int x,int fa){ siz[x]=1;son[x]=0; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==fa||vis[y])continue; getroot(y,x); siz[x]+=siz[y]; if(siz[y]>son[x])son[x]=siz[y]; } if(Siz-siz[x]>son[x])son[x]=Siz-siz[x]; if(son[x]<maxx)maxx=son[x],root=x; } void getdis(int x,int fa,int d){ dis[++top]=d; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==fa||vis[y])continue; getdis(y,x,d+val(i)); } } void solve(int rt,int d,int avl){//avl(available)表示这次计算得到的答案是否合法 top=0;//清空dis数组 getdis(rt,0,d);//获取到当前这棵树的rt的距离为d的所有dis int cnt=0; sort(dis+1,dis+top+1);//排好序准备二分 dis[0]=-1;//第一个dis设置为奇怪的数方便下面比较 for(int i=1;i<=top;i++){//把所有距离相同的点放进一个桶里面方便操作 if(dis[i]==dis[i-1]) bucket[cnt].amount++;//原来桶的个数+1 else bucket[++cnt].dis=dis[i],bucket[cnt].amount=1;//新开一个桶 } for(int i=1;i<=m;i++){ if(query[i]%2==0)//如果k是偶数的话,我们单独考虑一下距离为k/2那些点,它们可以互相配对形成长为k的路径 for(int j=1;j<=cnt;j++) if(bucket[j].dis==query[i]/2)//如果距离是k/2 ans[i]+=(bucket[j].amount-1)*bucket[j].amount/2*avl; //组合计数,假设我们有x个距离为k/2的点,就有(x-1)*x/2个点对距离为k,也就是我们可以配出这么多个不同点对 //其实就是C(x,2)->x!/((x-2)!*2)->(x-1)*x/2 for(int j=1;j<=cnt&&bucket[j].dis<query[i]/2;j++){//接着枚举<k/2的距离,然后我们二分找>2的距离配对 int l=j+1,r=cnt; while(l<=r){ int mid=(l+r)>>1; if(bucket[j].dis+bucket[mid].dis==query[i]){ ans[i]+=bucket[j].amount*bucket[mid].amount*avl; //组合计数记录答案,假设我们有x个距离为m的点,y个距离为k-m的点,我们就有x*y个不同的点对(分类相乘) break;//这一轮二分完了,下一轮 } if(bucket[j].dis+bucket[mid].dis>query[i])r=mid-1;//大了,往小的二分 else l=mid+1;//小了,往大的二分 } } } } void divide(int x){ vis[x]=1; solve(x,0,1);//合法的算进去 int totsiz=Siz; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(vis[y])continue; solve(y,val(i),-1);//不合法的算出来减掉 maxx=inf,root=0; Siz=siz[y]>siz[x]?totsiz-siz[x]:siz[y]; getroot(y,0); divide(root); } } int main(){ n=read();m=read(); for(int i=1;i<n;i++){ int x=read(),y=read(),z=read(); addedge(x,y,z);addedge(y,x,z); } for(int i=1;i<=m;i++)query[i]=read(); maxx=inf;root=0;Siz=n; getroot(1,0); divide(root); for(int i=1;i<=m;i++){ if(ans[i]>0)puts("AYE"); else puts("NAY"); } return 0; }
这份代码不仅求出了是否有路径长为k的答案存在,还求出了路径长为k的点对个数。
不需要求个数的话你就改一改solve,这样常数会小一点,跑得更快,但是这一题我更想给你们讲讲思路,学会举一反三...最后还是那句话,我个人并不认为点分是一种算法,它更多体现的是分治的思想在树上的应用。
其实你会发现,很多高端的数据结构、算法之类的,都是由基础的算法思想衍生出来的泛用性更强的东西。
懂了基础的算法思想,你不但可以轻松的学会各种高阶算法,甚至可以自己造出解题的算法。
比如图论里面的dijkstra最短路算法,不就是贪心吗?再比如线段树,不就是分治吗?
一些看起来很高级,听起来很难的东西,只要你弄明白了其中的本质,你在感叹发明者的智慧同时,自己也就收获到了其中蕴含的知识和基础思想的应用方法。学习不是死学...只有弄明白了它的工作原理和方式,你才算是掌握了它,仅仅是可以用它来做题,那题目变变形,你就一脸懵了。