树链剖分——解决树上路径问题利器

我与影子孤独终老i 提交于 2020-04-14 13:19:10

【推荐阅读】微服务还能火多久?>>>

1 树剖介绍

树链剖分是用来处理树上路径问题。(如路径和)

这里以P2590 [ZJOI2008]树的统计为模板来讲解。

因为有修改所以暴力是肯定不行的。这样我们就要请出我们今天的主角——树链剖分了。

树链剖分其实就是把一颗数剖成很多条链,给链上的节点重新编号使其编号连续。这样就可在这条链上用其它处理线段的数据结构(一般是线段树)处理了。

常用的有轻重链剖分。就是把树剖成多条重链和轻链。

对于每一个节点,他会有一个重儿子和若干个轻儿子。重儿子就延续当前的重链,轻儿子则作为新的一条重链的开始。重儿子是指子树大小最大的儿子,其他的都是轻儿子。

这样这棵树就被剖成了很多条链。显然每个节点都属于一条重链。

重儿子指的就是子树大小最大的儿子,轻儿子是其它儿子。

如图,红边是重链,蓝边是轻链,星星是重链开始

我们还需要重新编号,其实只要按照重儿子dfs的dfs序即可。

为了后面的查询,我们还得存储一个dep代表深度,fa代表在树上的父亲以及bl代表当前重链的开始。

我们来看具体的代码实现。

代码有两个dfs(具体看注释)。

void dfs1(int x) {
    sz[x] = 1;//sz是存储子树大小的
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (y == fa[x]) continue;
                //预处理深度和父节点  
        dep[y] = dep[x] + 1;
        fa[y] = x;
        dfs1(y, x);
        sz[x] += sz[y];//统计子树大小
    }
}
void dfs2(int x, int chain) {
    int k = 0;//重儿子
    dfn[x] = ++len;
    bl[x] = chain;//chain是树链开始
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (dep[y] < dep[x]) continue;
        if (sz[y] > sz[k]) {
            k = y;
        }
    }//查找重儿子
    if (k) dfs2(k, chain);//如果有重儿子则先递归重儿子
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (y == fa[x] || y == k) continue;
        dfs2(y, y);//轻儿子是重链的开始
    }
}

那么如何查询呢?

其实也很简单

我们以查询“和”为例来介绍。

假设我们我们要查询 $u->v$  的路径上的“和”。

我们首先判断它们在不在同一条重链上

如果在,直接线段树查询这一条链,返回即可。

如果不在,则需要其中一个节点 $u$ 统计从 $bl[u]->u$ 这条链上的答案,然后 $u$ 跳到 $fa[bl[u]]$。

记住这个 $u$ 必须是 $bl的dep$大的一个,否则有可能得不到正确答案。

比如上图,如果跳 $v$ 则会跳到 $root$ 就的不到答案而且会死循环。

代码实现

int Qsum(int u, int v) {
    int sum = 0;
    while (bl[u] != bl[v]) {
        if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
        sum += ask_sum(1, pos[bl[u]], pos[u]);//ask是线段树查询区间和
        u = fa[bl[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u, v);
    }
    sum += ask_sum(1, pos[u], pos[v]);//在一条链上也别忘统计
    return sum;
}

注意:线段树维护就是正常的区间维护(对dfn)。

时间复杂度分析:

在重链上线段树求答案是 $O(\log{N})$ 的,而对于跳轻链,因为轻儿子的子树大小至多是一半,所以级别是 $O(\log{N})$ 的。

完整代码:

#include <iostream>
#include <cstdio>
using namespace std;
const int N = 30010, inf = 0x7f7f7f7f;
struct seg_tree{
    int maxi, val, l, r;
}st[4 * N];
struct node{
    int pre, to;
}edge[2 * N];
int head[N], tot;
int n;
int a, b, dep[N], fa[N], bl[N], sz[N], w[N], pos[N];
int len, QQ;
void dfs1(int x, int f) {
    sz[x] = 1;
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (y == f) continue;
        dep[y] = dep[x] + 1;
        fa[y] = x;
        dfs1(y, x);
        sz[x] += sz[y];
    }
}
void dfs2(int x, int chain) {
    int k = 0;
    pos[x] = ++len;
    bl[x] = chain;
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (dep[y] < dep[x]) continue;
        if (sz[y] > sz[k]) {
            k = y;
        }
    }
    if (k) dfs2(k, chain);
    for (int i = head[x]; i; i = edge[i].pre) {
        int y = edge[i].to;
        if (dep[y] < dep[x] || y == k) continue;
        dfs2(y, y);
    }
}
void build(int x, int l, int r) {
    st[x].l = l, st[x].r = r;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(x << 1, l, mid);
    build(x << 1 | 1, mid + 1, r);
}
void change(int x, int p, int v) {
    int l = st[x].l, r = st[x].r;
    if (l == r) {
        st[x].val = st[x].maxi = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (p <= mid) change(x << 1, p, v);
    else change(x << 1 | 1, p, v);
    st[x].val = st[x << 1].val + st[x << 1 | 1].val;
    st[x].maxi = max(st[x << 1].maxi, st[x << 1 | 1].maxi);
}
int ask_max(int x, int L, int R) {
    int l = st[x].l, r = st[x].r;
    if (L <= l && r <= R) {
        return st[x].maxi;
    }
    if (l > R || r < L) return -inf;
    return max(ask_max(x << 1, L, R), ask_max(x << 1 | 1, L, R));
}
int ask_sum(int x, int L, int R) {
    int l = st[x].l, r = st[x].r;
    if (L <= l && r <= R) {
        return st[x].val;
    }
    if (l > R || r < L) return 0;
    return ask_sum(x << 1, L, R) + ask_sum(x << 1 | 1, L, R);
}
int Qsum(int u, int v) {
    int sum = 0;
    while (bl[u] != bl[v]) {
        if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
        sum += ask_sum(1, pos[bl[u]], pos[u]);
        u = fa[bl[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u, v);
    }
    sum += ask_sum(1, pos[u], pos[v]);
    return sum;
}
int Qmax(int u, int v) {
    int maxi = -inf;
    while (bl[u] != bl[v]) {
        if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
        maxi = max(maxi, ask_max(1, pos[bl[u]], pos[u]));
        u = fa[bl[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u, v);
    }
    maxi = max(maxi, ask_max(1, pos[u], pos[v]));
    return maxi;
}
void add(int u, int v) {
    edge[++tot] = node{head[u], v};
    head[u] = tot;
}
int main() {
    cin >> n;
    for (int i = 1, a, b; i < n; i++) {
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, len);
    for (int i = 1; i <= n; i++) {
        cin >> w[i];
        change(1, pos[i], w[i]);
    }
    cin >> QQ;
    while (QQ--) {
        string opt;
        int u, v;
        cin >> opt >> u >> v;
        if (opt == "QMAX") {
            cout << Qmax(u, v) << "\n";
        } else if (opt == "QSUM") {
            cout << Qsum(u, v) << "\n";
        } else {
            change(1, pos[u], v);
        }
    }
    return 0;
} 

练习:

因为时间有限,所以部分习题不配有题解。如有需要,请在下方留言谢谢。
一些须知

P2486 [SDOI2011]染色

2 子树问题

我们来看这道P3384 【模板】轻重链剖分,它还叫我们输出子树和,这可怎么办呢?

我们发现按照上述方式给节点编号那么一棵子树的编号也是连续的(即一段区间),那么我们直接区间查询修改即可。

练习:

P3178 [HAOI2015]树上操作 题解

P2146 [NOI2015]软件包管理器 题解

3 LCA

P3379 【模板】最近公共祖先(LCA)

理解了上面讲的东西那这个就很好想了。

一直跳到两点在同一重链上,然后返回较浅的点,$O(\log{N})$。

int LCA(int u, int v) {
    while (bl[u] != bl[v]) {
        if (dep[bl[u]] < dep[bl[v]]) swap(u, v);
        u = fa[bl[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

树剖求LCA的常数较小,可用来卡常。

习题:

P3258 [JLOI2014]松鼠的新家

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!