裸的求公共祖先的题目。
思路:用被增法求最近公共祖先,用fa[i][j]表示从i开始,向上走2j 步能走到的所有节点,其中1 <= j <= logn(下取整)
例如上图,f[6][0] = 4, f[6][1] = 2, f[6][2] = -1表示不存在。
做法:1.首先我们需要预处理一个fa数组,采用递推的方式。fa[i][j]表示从i开始向上走2j 步,那么我们可以拆成两部分,先走2j -1 步再走2j - 1 步,也就是fa[fa[i][j - 1][j - 1]。于是我们便得到递推公式fa[i][i] = fa[fa[i][j - 1]][j - 1]。只需将j从小到大枚举一遍就可以了。
ps:本质是用二进制拼凑路径长度,和多重背包的二进制优化思想大致。
2.同时我们还需要预处理一个depth数组,来表示当前点的深度,例如depth[1] = 1.depth[6] = 4。就是到根节点的距离+1。同时我们再设置两个“哨兵“,如果从i跳过了根节点,那么fa[i][j] = 0, depth[0] = 0。
3.预处理完两个数组后,进行操作,先将两个点跳到同一层,我们就统一将a看做较低的节点。b看做较高的节点。
4.两个节点所在层数相同后,如果还不是公共祖先,就一起往上跳到最近公共祖先的下一层。(比较好判断,因为当f[a][k] = f[b][k]时,表示跳到了公共节点,但不一定是最近的,而跳到第一个公共节点前一层,再往上就一定是最近的公共节点,即fa[a][0]就是答案。
预处理的复杂度是O(logn),查询也是O(logn)。
1 #include <iostream> 2 #include <algorithm> 3 #include <cstring> 4 #include <queue> 5 6 using namespace std; 7 8 const int N = 40010, M = 2 * N; 9 10 int e[M], ne[M], h[M], idx; 11 int depth[N], fa[N][16]; 12 13 void add(int a, int b) 14 { 15 e[idx] = b, ne[idx] = h[a], h[a] = idx ++; 16 } 17 18 void bfs(int root) 19 { 20 memset(depth, 0x3f, sizeof depth); 21 queue<int> q; 22 23 q.push(root); 24 depth[0] = 0, depth[root] = 1;//初始化 25 26 while(q.size()) 27 { 28 int t = q.front(); 29 q.pop(); 30 31 for(int i = h[t] ; ~i ; i = ne[i]) 32 { 33 int j = e[i]; 34 if(depth[j] > depth[t] + 1) 35 { 36 depth[j] = depth[t] + 1; 37 fa[j][0] = t; 38 q.push(j); 39 for(int k = 1 ; k <= 15 ; k ++) 40 fa[j][k] = fa[fa[j][k - 1]][k - 1]; 41 } 42 } 43 } 44 } 45 46 int lca(int a, int b) 47 { 48 if(depth[a] < depth[b])swap(a, b);//默认a都是深度较深的节点 49 for(int k = 15 ; k >= 0 ; k --) 50 if(depth[fa[a][k]] >= depth[b])//将a跳到和b同一层,因为默认depth都是>=1的,所以depth[] = 0时不会满足条件。“哨兵”的用处。 51 a = fa[a][k]; 52 if(a == b)return a;//相同了就返回其中一个,就是最近祖先节点 53 for(int k = 15 ; k >= 0 ; k --) 54 if(fa[a][k] != fa[b][k]) 55 { 56 a = fa[a][k]; 57 b = fa[b][k]; 58 } 59 return fa[a][0]; 60 } 61 62 63 int main(){ 64 int n, m, root = 0; 65 cin >> n; 66 memset(h, -1, sizeof h); 67 68 while(n --) 69 { 70 int a, b; 71 cin >> a >> b; 72 if(b == -1)root = a; 73 else add(a, b), add(b, a); 74 } 75 76 bfs(root); 77 78 cin >> m; 79 while(m --) 80 { 81 int a, b; 82 cin >> a >> b; 83 int p = lca(a, b); 84 if(p == a)puts("1"); 85 else if(p == b)puts("2"); 86 else puts("0"); 87 } 88 return 0; 89 }
求树上两点间的距离,可以转换成求两点间的公共祖先的方法来解决。
这里用到tarjan算法(对求公共祖先的向上标记法的优化)。
j
将树上的节点分为三类:1.已经遍历过的节点(绿色)。2.正在遍历的节点(红色)。3.还未遍历的节点(紫色)。
我们可以发现,求红色部分和绿色部分的最近公共祖先,就是红色部分上的那些节点,也就是橙色部分圈起来的子树,可以合并到他的祖宗节点去。(集合合并)所以我们可以用并查集做。
除此之外,我们还需预处理一个dis数组,dis[i]表示根节点到i的距离。
之后求x和y之间的距离,只需要dis[x] + dis[y] - 2 * dis[lca],lca是x和y的最近公共祖先。
并查集合并和查询的操作都是O(1)所以算法复杂度是线性的O(n + m)。
代码:
1 #include <iostream> 2 #include <algorithm> 3 #include <cstring> 4 #include <vector> 5 6 using namespace std; 7 8 typedef pair<int, int> PII; 9 10 const int N = 10010, M = 2 * N; 11 12 int e[M], ne[M], w[M], h[M], idx; 13 int ans[M], dis[N], st[N], f[N]; 14 vector<PII> query[N]; 15 int n, m; 16 17 void add(int a, int b, int c) 18 { 19 e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++; 20 } 21 22 int find(int x) 23 { 24 return f[x] == x ? x : f[x] = find(f[x]); 25 } 26 27 28 void dfs(int u, int fa) 29 { 30 for(int i = h[u] ; ~i ; i = ne[i]) 31 { 32 int j = e[i]; 33 if(j == fa)continue; 34 dis[j] = dis[u] + w[i]; 35 dfs(j, u); 36 } 37 } 38 39 void tarjan(int u) 40 { 41 st[u] = true; 42 for(int i = h[u] ; ~i ; i = ne[i]) 43 { 44 int j = e[i]; 45 if(!st[j]) 46 { 47 tarjan(j); 48 f[j] = u;//将子树合并到祖先节点 49 } 50 } 51 52 for(auto item : query[u]) 53 { 54 int y = item.first, id = item.second; 55 if(st[y] == 2)//标记为2说明是完成搜索的 56 { 57 int lca = find(y); 58 ans[id] = dis[u] + dis[y] - 2 * dis[lca]; 59 } 60 } 61 62 st[u] = 2;//完成搜索后标记为2 63 } 64 65 int main(){ 66 cin >> n >> m; 67 68 memset(h, -1, sizeof h); 69 70 for(int i = 0 ; i < n - 1; i ++) 71 { 72 int a, b, c; 73 cin >> a >> b >> c; 74 add(a, b, c), add(b, a, c); 75 } 76 77 for(int i = 0 ; i < m ; i ++) 78 { 79 int a, b; 80 cin >> a >> b; 81 query[a].push_back({b, i});//存入和这个点相关的点和查询的下标 82 query[b].push_back({a, i}); 83 } 84 85 for(int i = 1 ; i <= n ; i ++)f[i] = i; 86 87 dfs(1, -1);//预处理dis数组 88 tarjan(1); 89 90 for(int i = 0 ; i < m ; i ++)cout << ans[i] << endl; 91 92 return 0; 93 }