1 树剖介绍
树链剖分是用来处理树上路径问题。(如路径和)
这里以P2590 [ZJOI2008]树的统计为模板来讲解。
因为有修改所以暴力是肯定不行的。这样我们就要请出我们今天的主角——树链剖分了。
树链剖分其实就是把一颗数剖成很多条链,给链上的节点重新编号使其编号连续。这样就可在这条链上用其它处理线段的数据结构(一般是线段树)处理了。
常用的有轻重链剖分。就是把树剖成多条重链和轻链。
对于每一个节点,他会有一个重儿子和若干个轻儿子。重儿子就延续当前的重链,轻儿子则作为新的一条重链的开始。重儿子是指子树大小最大的儿子,其他的都是轻儿子。
这样这棵树就被剖成了很多条链。显然每个节点都属于一条重链。
重儿子指的就是子树大小最大的儿子,轻儿子是其它儿子。
如图,红边是重链,蓝边是轻链,星星是重链开始
我们还需要重新编号,其实只要按照重儿子dfs的dfs序即可。
为了后面的查询,我们还得存储一个dep代表深度,fa代表在树上的父亲以及bl代表当前重链的开始。
我们来看具体的代码实现。
代码有两个dfs(具体看注释)。
void dfs1(int x) {
sz[x] = 1;//sz是存储子树大小的
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == fa[x]) continue;
//预处理深度和父节点
dep[y] = dep[x] + 1;
fa[y] = x;
dfs1(y, x);
sz[x] += sz[y];//统计子树大小
}
}
void dfs2(int x, int chain) {
int k = 0;//重儿子
dfn[x] = ++len;
bl[x] = chain;//chain是树链开始
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (dep[y] < dep[x]) continue;
if (sz[y] > sz[k]) {
k = y;
}
}//查找重儿子
if (k) dfs2(k, chain);//如果有重儿子则先递归重儿子
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == fa[x] || y == k) continue;
dfs2(y, y);//轻儿子是重链的开始
}
}
那么如何查询呢?
其实也很简单
我们以查询“和”为例来介绍。
假设我们我们要查询 $u->v$ 的路径上的“和”。
我们首先判断它们在不在同一条重链上
如果在,直接线段树查询这一条链,返回即可。
如果不在,则需要其中一个节点 $u$ 统计从 $bl[u]->u$ 这条链上的答案,然后 $u$ 跳到 $fa[bl[u]]$。
记住这个 $u$ 必须是 $bl的dep$大的一个,否则有可能得不到正确答案。
比如上图,如果跳 $v$ 则会跳到 $root$ 就的不到答案而且会死循环。
代码实现
int Qsum(int u, int v) {
int sum = 0;
while (bl[u] != bl[v]) {
if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
sum += ask_sum(1, pos[bl[u]], pos[u]);//ask是线段树查询区间和
u = fa[bl[u]];
}
if (dep[u] > dep[v]) {
swap(u, v);
}
sum += ask_sum(1, pos[u], pos[v]);//在一条链上也别忘统计
return sum;
}
注意:线段树维护就是正常的区间维护(对dfn)。
时间复杂度分析:
在重链上线段树求答案是 $O(\log{N})$ 的,而对于跳轻链,因为轻儿子的子树大小至多是一半,所以级别是 $O(\log{N})$ 的。
完整代码:
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 30010, inf = 0x7f7f7f7f;
struct seg_tree{
int maxi, val, l, r;
}st[4 * N];
struct node{
int pre, to;
}edge[2 * N];
int head[N], tot;
int n;
int a, b, dep[N], fa[N], bl[N], sz[N], w[N], pos[N];
int len, QQ;
void dfs1(int x, int f) {
sz[x] = 1;
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == f) continue;
dep[y] = dep[x] + 1;
fa[y] = x;
dfs1(y, x);
sz[x] += sz[y];
}
}
void dfs2(int x, int chain) {
int k = 0;
pos[x] = ++len;
bl[x] = chain;
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (dep[y] < dep[x]) continue;
if (sz[y] > sz[k]) {
k = y;
}
}
if (k) dfs2(k, chain);
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (dep[y] < dep[x] || y == k) continue;
dfs2(y, y);
}
}
void build(int x, int l, int r) {
st[x].l = l, st[x].r = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(x << 1, l, mid);
build(x << 1 | 1, mid + 1, r);
}
void change(int x, int p, int v) {
int l = st[x].l, r = st[x].r;
if (l == r) {
st[x].val = st[x].maxi = v;
return;
}
int mid = (l + r) >> 1;
if (p <= mid) change(x << 1, p, v);
else change(x << 1 | 1, p, v);
st[x].val = st[x << 1].val + st[x << 1 | 1].val;
st[x].maxi = max(st[x << 1].maxi, st[x << 1 | 1].maxi);
}
int ask_max(int x, int L, int R) {
int l = st[x].l, r = st[x].r;
if (L <= l && r <= R) {
return st[x].maxi;
}
if (l > R || r < L) return -inf;
return max(ask_max(x << 1, L, R), ask_max(x << 1 | 1, L, R));
}
int ask_sum(int x, int L, int R) {
int l = st[x].l, r = st[x].r;
if (L <= l && r <= R) {
return st[x].val;
}
if (l > R || r < L) return 0;
return ask_sum(x << 1, L, R) + ask_sum(x << 1 | 1, L, R);
}
int Qsum(int u, int v) {
int sum = 0;
while (bl[u] != bl[v]) {
if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
sum += ask_sum(1, pos[bl[u]], pos[u]);
u = fa[bl[u]];
}
if (dep[u] > dep[v]) {
swap(u, v);
}
sum += ask_sum(1, pos[u], pos[v]);
return sum;
}
int Qmax(int u, int v) {
int maxi = -inf;
while (bl[u] != bl[v]) {
if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
maxi = max(maxi, ask_max(1, pos[bl[u]], pos[u]));
u = fa[bl[u]];
}
if (dep[u] > dep[v]) {
swap(u, v);
}
maxi = max(maxi, ask_max(1, pos[u], pos[v]));
return maxi;
}
void add(int u, int v) {
edge[++tot] = node{head[u], v};
head[u] = tot;
}
int main() {
cin >> n;
for (int i = 1, a, b; i < n; i++) {
cin >> a >> b;
add(a, b);
add(b, a);
}
dfs1(1, 0);
dfs2(1, 1);
build(1, 1, len);
for (int i = 1; i <= n; i++) {
cin >> w[i];
change(1, pos[i], w[i]);
}
cin >> QQ;
while (QQ--) {
string opt;
int u, v;
cin >> opt >> u >> v;
if (opt == "QMAX") {
cout << Qmax(u, v) << "\n";
} else if (opt == "QSUM") {
cout << Qsum(u, v) << "\n";
} else {
change(1, pos[u], v);
}
}
return 0;
}
练习:
因为时间有限,所以部分习题不配有题解。如有需要,请在下方留言谢谢。
2 子树问题
我们来看这道P3384 【模板】轻重链剖分,它还叫我们输出子树和,这可怎么办呢?
我们发现按照上述方式给节点编号那么一棵子树的编号也是连续的(即一段区间),那么我们直接区间查询修改即可。
练习:
3 LCA
理解了上面讲的东西那这个就很好想了。
一直跳到两点在同一重链上,然后返回较浅的点,$O(\log{N})$。
int LCA(int u, int v) {
while (bl[u] != bl[v]) {
if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
u = fa[bl[u]];
}
return dep[u] < dep[v] ? u : v;
}
树剖求LCA的常数较小,可用来卡常。
习题:
来源:oschina
链接:https://my.oschina.net/u/4312789/blog/3234232