最近树上公共祖先详细图解

旧城冷巷雨未停 提交于 2020-04-29 17:01:12

一、定义

LCA(Least Common Ancestors),树上最近公共祖先,顾名思义,也就是说,对于节点u,v,设x=lca(u,v),则,u和v均在x的子树中,并且x的深度最小。

图画得太丑了

如这幅图中:lca(5,6)=2 ,lca(6,3)=1 ,lca(3,9)=9。

二、解法

(1)dfs序

何为dfs序?即为深度优先搜索遍历完这棵树后所获得的节点访问先后顺序。如上面那副图的dfs序:

求dfs序代码:

void dfs(int x,int depth)//遍历一遍图求dfs序 { int i; len++; ola[len]=x;dep[len]=depth;//ola数组用来存dfs序 vis[x]=len;//标记x并存下x在ola数组中出现的位置 for(i=first[x];i;i=next[i]) { if(!vis[v[i]]) { dfs(v[i],depth+1); len++; ola[len]=x; dep[len]=depth; } } }

我们可以发现,任何一个节点出现两次之间,一定包含了他的子树。于是我们在将dfs序与节点深度结合起来看:

我们再来找5和3的lca,可以由下一幅图知道1。

不难发现,我们要找的点即为5和3之间深度最小的点。因此我们只需要将dfs序和每个点在dfs序上的位置预处理出来以后,按RMQ问题的套路求解即可。实际上我们只需要记录每个节点首次出现的位置即可,因为如图所示,我们查询2和3的lca

按理说应该查询这一段的最小值

然而我们查询这一段也是同样的结果

因为这一段中

所有的节点都是以2根的子树的节点,深度不可能小于其lca,对结果影响。

dfs序求lca的原理解决了,接下来我们只需要套RMQ问题的模板就行了!使用st表时间复杂度为O(2Nlog(2N)+N+M) 其中,N为节点数,M为询问个数。若不知道st表是什么的同学也没有关系, 当然我们用线段树,但时间复杂度为O(N+(2N+M)log(2N))常数也挺大,一般不用他。

模板题

然而,对于这道题,N和M都高达5×10^5,dfs序算法会超时3个点。下面贴上作者巨丑无比的代码(由于是很久之前写的),大家轻喷。 (对于求lca,dfs序一般少用,大家可以直接跳过看下一种算法,但dfs序的思想很重要,一定要理解)


#include<cstdio>
#include<algorithm> using namespace std; const int MAXN=5e5+5; int ola[MAXN*2],dep[MAXN*2],first[MAXN],next[MAXN*2],u[MAXN*2],v[MAXN*2]; int vis[MAXN],st[MAXN*2][25][2],log[2*MAXN]; int n,m,len,p; void dfs(int x,int depth)//遍历一遍图求dfs序 { int i; len++; ola[len]=x;dep[len]=depth;//ola数组用来存dfs序 vis[x]=len;//标记x并存下x在ola数组中出现的位置 for(i=first[x];i;i=next[i]) { if(!vis[v[i]]) { dfs(v[i],depth+1); len++; ola[len]=x; dep[len]=depth; } } } void build()//存图 { int i,j,k; for(i=1;i<n;i++) scanf("%d%d",&u[i],&v[i]), next[i]=first[u[i]],first[u[i]]=i; for(i=n;i<n+n-1;i++) u[i]=v[i-n+1],v[i]=u[i-n+1], next[i]=first[u[i]],first[u[i]]=i; } int main() { int i,j,k; scanf("%d%d%d",&n,&m,&p); build();//建图 dfs(p,1);//求dfs序 log[0]=-1; for(i=1;i<=len;i++) log[i]=log[i>>1]+1,//预处理log数组 st[i][0][0]=dep[i],st[i][0][1]=ola[i]; for(j=1;j<=log[len];j++)//预处理st表 for(i=1;i+(1<<j)-1<=len;i++) { k=i+(1<<j-1); if(st[i][j-1][0]<st[k][j-1][0]) st[i][j][0]=st[i][j-1][0],st[i][j][1]=st[i][j-1][1]; else st[i][j][0]=st[k][j-1][0],st[i][j][1]=st[k][j-1][1]; }//st[..][..][0]存深度,st[..][..][1] 存点的编号 while(m--) { scanf("%d%d",&i,&j); i=vis[i]; j=vis[j]; if(i>j) k=i,i=j,j=k; k=log[j-i+1]; int l=j-(1<<k)+1; if(st[i][k][0]<st[l][k][0]) printf("%d\n",st[i][k][1]);//愉快输出答案 else printf("%d\n",st[l][k][1]);//愉快输出答案 } }

(2) 倍增

如果让你考虑暴力的算法求解,我们该怎样做呢?

对于u,v,先将深度较大的点往上跳,跳到和另一个点深度相同。然后将两个点同时上跳,直到两个点重合,那么现在我们便找到了u和v的lca,代码如下:

int lca(int x,int y) { if(dep[x]<dep[y])//我们默认x深度较大 swap(x,y);//若不是,那就需要交换 while(dep[x]>dep[y]) x=fa[x];//将x提到与y深度相同的位置 while(x!=y) x=fa[x],y=fa[y];//暴力求解lca return x; } 

如何考虑优化呢?毫无疑问,上面那种做法费时间就费在一步一步往上跳,若我们能一次性往上跳很长一段距离,那无疑就很好了。

还记得快速幂吗?我们可以吧a^13拆成a×a^4×a^8,当然我们也可以利用这个思想将往上跳13次换成往上跳8格,4格,1格,自然,我们将往上跳n次优化成了往上跳logn次。那么若何具体实现求lca呢?看下图:

求x和y(默认x、y深度相等),设i=3,令x和y往上跳2^i格,也就是8格,哦豁!跳到了0号节点上,于是我们令i--,再往上跳4格。

i--,在往上跳2格,哦豁!又跳到外面去了,i再减1,i=0,再往上跳1格。

x和y又跳到了同一个点上,不行,不能跳。此时,算法流程结束,x和y的父亲即为他们的lca

这次模拟的最后一步看似冗杂,但却是必不可缺的,下面上伪代码:

int lca(int x,int y) { if(dep[x]<dep[y])//我们默认x深度较大 swap(x,y);//若不是,那就需要交换 while(dep[x]>dep[y]) x=fa[x];//将x提到与y深度相同的位置 for(int i=20;i>=0;i--) if(x往上提2^i与y往上提2^i的节点不同) 将x和y分别往上跳2^i格; return fa[x];//最后x的父亲即为答案 } 

那么如何,实现将一个节点往上提2^i呢?外面定义一个fa[x][i]表示x号节点往上提2^i格后对应的节点,那么每次,我们只需要,令x=fa[x][i]就行了,说了等于没说。那么如何和得到这个fa数组呢?我们只需要得到这个dp方程式

fa[x][i]=fa[fa[x][i-1]][i-1] 把往上跳2^i格分为跳两次2^(i-1)格。

边界条件:fa[x][0]=father[x];

代码

void dfs(int x,int father) { dep[x]=dep[father]+1;//处理深度 fa[x][0]=father;//初始化边界条件 for(int i=1;i<=20;i++) fa[x][i]=fa[fa[x][i-1]][i-1];//dp for(int i=frt[x];i;i=nxt[i]) if(v[i]!=father) dfs(v[i],x);//遍历这棵树 }

查询代码:

int lca(int x,int y) { if(dep[x]<dep[y])//我们默认x深度较大 swap(x,y);//若不是,那就需要交换 int l=dep[x]-dep[y]; int i=0; while(l) { if(l&1) x=fa[x][i]; i++; l>>=1; }//利用快速幂的思想将x提至与y同一深度 if(x==y) return x;//特判 for(int i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0];//最后x的父亲即为答案 } 

完整代码

#include<cstdio>
#include<iostream> using namespace std; const int N=5e5+10; int frt[N],nxt[2*N],v[2*N],dep[N],fa[N][21]; int h,p,n,m,tot; void add(int x,int y)//加边 { v[++tot]=y; nxt[tot]=frt[x];frt[x]=tot; } void dfs(int x,int father)//预处理fa { int i; dep[x]=dep[father]+1; fa[x][0]=father; for(i=1;i<=20;i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for(i=frt[x];i;i=nxt[i]) if(v[i]!=father) dfs(v[i],x); } int lca(int x,int y) { if(dep[x]<dep[y])//我们默认x深度较大 swap(x,y);//若不是,那就需要交换 int l=dep[x]-dep[y]; int i=0; while(l) { if(l&1) x=fa[x][i]; i++; l>>=1; }//利用快速幂的思想将x提至与y同一深度 if(x==y) return x;//特判 for(int i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0];//最后x的父亲即为答案 } int main() { int a,b,i; scanf("%d%d%d",&n,&m,&p); for(i=1;i<n;i++) scanf("%d%d",&a,&b),add(a,b),add(b,a); dfs(p,0); while(m--) { scanf("%d%d",&a,&b); printf("%d\n",lca(a,b)); } }

这是和《信息学奥赛一本通》上面的代码差不多。常数太大,还有很多很暴力的地方,最慢一个点跑了700ms。

如何优化?

先来看这一句

for(i=1;i<=20;i++)
fa[x][i]=fa[fa[x][i-1]][i-1];

显然,我们不用一直循环到20,因为很多时候,到了后面根本不需要处理,因为已经跳到外面去了,因此我们可以将其改成:

for(int i=1;1<<i<dep[x];i++)
fa[x][i]=fa[fa[x][i-1][i-1];

其中,1<<i表示2^i。

还有这一句

for(int i=20;i>=0;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];

显然也没必要从20开始减,实际上我们只需要从log(dep[x])开始减,由于cmath库中的log2运算太慢,我们使用O(n)的方法递推出log数组。

for(int i=1;i<=n;i++)
log[i]=log[i>>1]+1;//注意边界条件log[0]=-1;

因此倍增往上跳的时候只需这样即可:

for(int i=log[dep[x]];i>=0;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];

完整代码:

#include<cstdio>
#include<iostream> using namespace std; const int N=5e5+10; int frt[N],nxt[2*N],v[2*N]; int dep[N],fa[N][20],log[N]; int h,p,n,m,tot; void add(int x,int y)//加边 { v[++tot]=y; nxt[tot]=frt[x];frt[x]=tot; } void dfs(int x,int father) { int i; dep[x]=dep[father]+1; fa[x][0]=father; for(i=1;1<<i<dep[x];i++)//优化1 fa[x][i]=fa[fa[x][i-1]][i-1]; for(i=frt[x];i;i=nxt[i]) if(v[i]!=father) dfs(v[i],x); } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); int l=dep[x]-dep[y],i=0; while(l)//快速幂优化 { if(l&1) x=fa[x][i]; i++; l>>=1; } if(x==y) return x;//特判 for(i=log[dep[x]];i>=0;i--)//优化2 if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } int main() { int a,b,i; scanf("%d%d%d",&n,&m,&p); log[0]=-1; for(i=1;i<n;i++) scanf("%d%d",&a,&b),log[i]=log[i>>1]+1,//O(n)方法递推出log[] add(a,b),add(b,a);//无向图,加双向边 log[n]=log[n>>1]+1; dfs(p,0);//预处理 while(m--) { scanf("%d%d",&a,&b); printf("%d\n",lca(a,b));//愉快输出答案~ } }

这样一来最慢的一个点只跑了500ms,两个简单的优化,便少跑了200ms,说明一些小的优化也是十分重要的!

(3)树链剖分法

顾名思义,树链剖分即为将树剖分成一条一条的链

那么如何来剖分呢?

算法流程

我们将树中的边分为轻边与重边,如下图所示,加粗的是重边,其余的是轻边。

如何判断一条边是轻边还是重边呢?

我们规定u节点的所有儿子节点v中,找出size(v)(即为以v为根节点的子树的大小)最大的v',那么边(u,v')是重边,其余的是轻边。

这样一来,可以得出轻重边的一些性质:

1:如果边(u,v)为轻边,那么 size(v)\leqsize(v)≤ 1\over221 size(u)size(u)

因为如果 size(v)>size(v)1\over221 size(u)size(u) ,那么size(v)一定是最大的,那么(u,v)就是重边。

2:从根节点到某一叶节点的路径上最多有 \log(n)log(n) 条轻边,因为根据性质一,每走过一条轻边,子树的节点数至少减少一半,因此最多走过 \log(n)log(n) 条轻边便走到了叶节点。

3:我们将一段连续的重边称为重路径,很显然,重路径的起点与终点也与轻边相连,因此重路径的数量也至多有 \log(n)log(n) 条。

在树链剖分的过程中需要计算以下几个值:

fa[x]:x节点的父亲。

dep[x]:x节点所处的深度。

size[x]:以x节点为根的子树的大小。

top[x]:x节点所处的重路径的顶部节点。

son[x]:x节点的重儿子,即(x,son[x])为一条重边。

这5个值可以用两遍dfs完成,第一遍dfs求前四个值,代码如下:

void dfs1(int x) { int tmp=0; size[x]=1; for(int i=frt[x];i;i=nxt[i]) if(v[i]!=fa[x]) { fa[v[i]]=x; dep[v[i]]=dep[x]+1; dfs1(v[i]); size[x]+=size[v[i]]; if(size[v[i]]>tmp) tmp=size[v[i]],son[x]=v[i]; } }

第二遍dfs:

void dfs2(int x) { if(son[x]) { top[son[x]]=top[x]; dfs2(son[x]); } for(int i=frt[x];i;i=nxt[i]) if(v[i]!=fa[x]&&v[i]!=son[x]) { top[v[i]]=v[i]; dfs2(v[i]); } }

那么如何求lca呢

假如u和v的top相同,说明u和v在同一条重路径上,那么此时lca(u,v)一定是u,v中深度较小的。若u和v的top不同,那么lca(u,v)有可能在其中一个节点的重路经上,也有可能在别的重路径上,但显然不可能在top深度较大的重路径上,于是我们挑出u,v中top深度较大的点,假设是u,我们将u跳到fa[top[u]]的位置,再来看u,v的top值是否相同,如果不同,再这样往复循环,直到u,v的top相同时,便求得了u,v的lca,为u,v中深度较小的那个点。

代码如下:

while(top[x]!=top[y])
{
   if(dep[top[x]]>dep[top[y]])
   x=fa[top[x]];
   else y=fa[top[y]];
}
printf("%d\n",dep[x]<dep[y]?x:y);

完整代码:

#include<cstdio>
#define N 500010 int nxt[N<<1],v[N<<1],frt[N]; int fa[N],top[N],size[N],son[N],dep[N]; int n,m,tot,p; inline void add(int x,int y) { v[++tot]=y; nxt[tot]=frt[x];frt[x]=tot; } void dfs1(int x) { int tmp=0; size[x]=1; for(int i=frt[x];i;i=nxt[i]) if(v[i]!=fa[x]) { fa[v[i]]=x; dep[v[i]]=dep[x]+1; dfs1(v[i]); size[x]+=size[v[i]]; if(size[v[i]]>tmp) tmp=size[v[i]],son[x]=v[i]; } } void dfs2(int x) { if(son[x]) { top[son[x]]=top[x]; dfs2(son[x]); } for(int i=frt[x];i;i=nxt[i]) if(v[i]!=fa[x]&&v[i]!=son[x]) { top[v[i]]=v[i]; dfs2(v[i]); } } int main() { scanf("%d%d%d",&n,&m,&p); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y);add(y,x); } dep[p]=1; dfs1(p); dfs2(p); while(m--) { int x,y; scanf("%d%d",&x,&y); while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) x=fa[top[x]]; else y=fa[top[y]]; } printf("%d\n",dep[x]<dep[y]?x:y); } }

这样一来,最慢的一个点只跑了360ms,更快了!

(4)最高效解法:Tarjan

Tarjan是一种离线的线性时间复杂度的算法。

算法流程

以下图为例:

查询:

4 5

6 7

5 8

10 12

14 18

16 11

7 17

我们用dfs遍历这颗树,首先来到一号节点,令vis[1]=1,令f[1]=1,表示节点已被访问过。进入下一个节点2,令vis[2]=1,f[2]=2,来到3号节点,vis[3]=1,f[3]=3,来到4,f[4]=4,vis[4]=1。

这时我们发现4没有子节点了,令f[4]=自己的父节点3,看一看有没有和4号有关的查询,有一组4-5,但vis[5]=0,不管。回溯至3号节点,同样,三号节点的子节点也都访问过了,并且没有与3相关的hui查询,令f[3]=2,回到2,进入5,令vis[5]=1,f[5]=5,到6,f[6]=6,vis[6]=1,回到5,vis[6]=5。进入7,f[7]=7,到8,f[8]=8。

8没有自节点了,看一看有么有与8相关的查询,有:5-8,vis[5]=1,那么lca(5,8)就可以用并查集得到,为find(5)=5。令f[8]=7,回到7。

7的子节点也都访问过了,发现一组6-7的查询,并且vis[6]=1,那么lca(6,7)=find(6)=5。还有一组7-17的查询,但vis[17]=0,不管。令f[7]=5,回到5。

这时我们发现5的儿子也都走完了,有查询5-8,且vis[8]=1,lca(5,8)=find(8)=5。令f[5]=2,回到2。令f[2]=1,回到1,进入9,f[9]=9,f[10]=10,f[13]=13,f[14]=14。到14号节点时,有一组14-18的查询,但vis[18]=0,不管。令f[14]=13,回到13,进入17,f[17]=17。

这时发现一组查询7-17并且vis[7]=1,那么lca(7,17)=find(7)=1。令f[17]=13,回到13,进入18,vis[18]=18。

发现一组18-14的查询,则lca(18,14)=find(14)=13。令f[18]=13,回到13,令f[13]=10,回到10,发现一组10-12的查询,但vis[12]=0,不管。令f[10]=9,回到9,f[9]=1,回到1,进入11,f[11]=11,f[12]=12。

这时发现一组12-10的查询,vis[10]=1,lca(10,12)=find(10)=1。令f[12]=11,回到11,进入15,f[15]=15,进入16,f[16]=16。

发现查询16-11,lca(16,11)=find(11)=11。令f[16]=15。回到15,f[15]=11,回到1,算法结束。

代码实现

#include<cstdio>
#define N 500010 int frt[N],v[N<<1],nxt[N<<1],head[N],vis[N],f[N]; int n,m,p,tot; struct query { int nxt,v,ans,vis; }a[N<<1]; inline int read()//快读 { int x=0; char ch=getchar(); while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x; } void write(int x)//快输 { if(x>9) write(x/10); putchar(x%10+'0'); } int find(int x)//并查集 { return x==f[x]?x:f[x]=find(f[x]); } inline void addedge(int x,int y)//加边 { v[++tot]=y; nxt[tot]=frt[x];frt[x]=tot; v[++tot]=x; nxt[tot]=frt[y];frt[y]=tot;//反向加边 } inline void addquery(int x,int y)//用一个邻接表存询问 { a[++tot].v=y; a[tot].nxt=head[x];head[x]=tot; a[++tot].v=x; a[tot].nxt=head[y];head[y]=tot; } void dfs(int x)//核心过程 { vis[x]=1; for(int i=frt[x];i;i=nxt[i]) if(!vis[v[i]]) dfs(v[i]),f[v[i]]=x;//等遍历完了v[i]的所有 //子节点后,再令f[v[i]]=x for(int i=head[x];i;i=a[i].nxt)//找关于x的查询 if(vis[a[i].v]&&!a[i].vis) { a[i].ans=find(a[i].v); a[i].vis=1; if(i&1) a[i+1].ans=a[i].ans,a[i+1].vis=1; else a[i-1].ans=a[i].ans,a[i-1].vis=1; //赋值相邻的查询 } } int main() { scanf("%d%d%d",&n,&m,&p); for(int i=1;i<n;i++) { addedge(read(),read()); f[i]=i; } tot=0;f[n]=n; for(int i=1;i<=m;i++) addquery(read(),read()); dfs(p); for(int i=1;i<tot;i+=2) write(a[i].ans),putchar('\n'); }

最慢的一个点只跑了280ms。

总结

一:欧拉序

时间复杂度与空间复杂度都巨大,编程实现复杂度也不小,不建议使用。

二:倍增

时间复杂度较优,空间复杂度巨大,比较好理解,代码也比较好实现,适用与初学者。

三:树链剖分

时间复杂度与空间复杂度都极优,代码也好实现,极力推荐,最好的方法。

四:Tarjan

跑得最快,空间也优,但使用情况受限,只针对离线的情况,适用于毒瘤的卡常题。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!