如何写树链剖分

試著忘記壹切 提交于 2019-12-04 11:06:41

树链剖分模板

树链剖分可以把树上的点划分成一条条连续的“链”,“链”是一条简单路径,上面的每个点满足祖先后代关系,从根节点到每个点都只需经过最多\(\log_2n\)条链。 链上的每一个点 dfs序也是连续的, 故可以配合其他数据结构解决很多树上查询问题(近乎所有静态树问题)

先定义siz[u]为点u的子树大小。

int siz[MAXN]; //每个点的子树大小
int dep[MAXN]; //每个点的深度
int mxson[MAXN];//每个点的重儿子

链分为重链轻链。 考虑每个点,它只会和它的一个儿子组成链,这个儿子叫做 重儿子, 其他儿子叫做 轻儿子。 重儿子满足: siz[mxson[u]]最大.

第一次dfs:

任务清单:

  1. 求出每个节点 usiz[u], dep[u], fa[u]
  2. 求出每个节点的重儿子(注意叶节点没有重儿子!)
void dfs1(int u, int faa){
    dep[u] = dep[faa] + 1;
    siz[u] = 1;
    fa[u] = faa;
    int mxsiz = 0;
    for(int e = h[u]; e; e = nxt[e]){
        int v = ev[e];
        if(v == faa) continue;
        dfs1(v, u);
        if(mxsiz < siz[v]) {
            mxson[u] = v;
            mxsiz = siz[v];
        } 
        siz[u] += siz[v];
    }
}

第二次dfs: 确定dfs序和整棵树的剖分情况.

任务清单:

  1. 求出每个点的“新编号”, 所有在线段树上的操作都要使用 新编号 (因为只有使用新编号,每条链上的点\(dfs\)序才是连续的。
  2. 求出每个点utop[u](下面有解释)
int top[u]; //u所在的链的顶端
int dfn[MAXN]; //每个点的dfs序(第二次dfs后) 作为每个点的新编号
int timestamp = 0;
void dfs2(int u, int topf){
    dfn[u] = ++timestamp;
    top[u] = topf;
    if(mxson[u] == 0) return; //VERY IMPORTANT!!!!
    //上面这句代码不加会RE! (没有儿子的节点不可以dfs!)
    //或者加一句if(u == 0) return; 也可以
    dfs2(mxson[u], topf);
    for(int e = h[u]; e; e = nxt[e]){
        int v = ev[e];
        if(v != fa[u] && v != mxson[u]) dfs2(v, v);
    }
}

两遍\(dfs\)后, 对整棵树的剖分已经完成.

易错点

注意!! 注意!! 注意!!

siz[u], top[u], fa[u], dep[u], mxson[u] 中的 \(u\) 是每个点的 旧编号!!!

树链剖分的应用

举个例子: 通过树链剖分实现链上求和等操作.

链上询问

一边求 \(LCA\) 一边使用线段树修改/求和. (因为每条链上的点的\(dfs\)序连续)

所有链上操作都可以基于该模板。 瞪大眼睛,注释非常重要!

#define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深
int qchain(int u, int v) {
    ll ret = 0;
    //和倍增求LCA一样, 这里求LCA也是让两个节点轮流向上"跳"。 
    //此时,要求跳的那个点   所在链的深度   较大(因为跳完之后,u = fa[top[u]])
    while(top[u] != top[v]) {
        swapdep(u, v);
        upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) );
        u = fa[top[u]];
    }
    //两个点在同一条链上以后, 就要根据  点自己的深度  来判断谁先谁后了。(查询[l, r]区间和时, l <= r )
    if(dep[u] < dep[v]) swap(u, v);
    upd(ret, query(1, 1, n, dfn[v], dfn[u]) );
    return ret;
}

链上修改同理.

子树询问

每个点的子树\(dfs\)序连续, 同样使用线段树修改/求和

注意 siz[u] 中的u是每个点的 初始编号

int qsub(int u){
    return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1);
}
void updsub(int u, int val){
    update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val);
}

CODE (luogu P3384)

#include<bits/stdc++.h>
using namespace std;

#define ll long long
void write(ll x) {
    if(x < 0) {
        putchar('-');
        x = -x;
    } 
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
ll read(){
    ll ret = 0, f = 1; char c = getchar();
    for(; !isdigit(c); c = getchar()) if('-'==c) f = -1;
    for(; isdigit(c); c = getchar()) ret = ret*10 + c - '0';
    return ret * f;
}
#define nl puts("")
#define bs putchar(' ')
const int _maxn = 150005;
int dep[_maxn], siz[_maxn], fa[_maxn], mxson[_maxn];
int dfn[_maxn], top[_maxn];
int n, m, r, p;
int a[_maxn];
int h[_maxn], nxt[_maxn*2], ev[_maxn*2], cnte;
void adde(int u, int v){
    cnte++;
    ev[cnte] = v;
    nxt[cnte] = h[u];
    h[u] = cnte;
}
void dfs1(int u, int faa){
    dep[u] = dep[faa] + 1;
    siz[u] = 1;
    fa[u] = faa;
    int mxsiz = 0;
    for(int e = h[u]; e; e = nxt[e]){
        int v = ev[e];
        if(v == faa) continue;
        dfs1(v, u);
        if(mxsiz < siz[v]) {
            mxson[u] = v;
            mxsiz = siz[v];
        } 
        siz[u] += siz[v];
    }
}
int timestamp = 0;
void dfs2(int u, int topf){
    dfn[u] = ++timestamp;
    top[u] = topf;
    if(mxson[u] == 0) return; //VERY IMPORTANT!!!!
    dfs2(mxson[u], topf);
    for(int e = h[u]; e; e = nxt[e]){
        int v = ev[e];
        if(v != fa[u] && v != mxson[u]) dfs2(v, v);
    }
}

//segtree
#define il inline
    ll sum[_maxn*4], tag[_maxn*4];
    il int ls(int o) {return o*2;}
    il int rs(int o) {return o*2+1;}
    void upd(ll &x, ll y) {
        x = (p+x+y) % p;
    }
    void _op(int o, int l, int r, int x){
        upd(tag[o], x);
        upd(sum[o], (ll)(r-l+1) * x % p);
    }
    il void pushdown(int o, int l, int r) {
        if(tag[o] == 0) return;
        int mid = (l+r)>>1;
        _op(ls(o), l, mid, tag[o]);
        _op(rs(o), mid+1, r, tag[o]);
        tag[o] = 0;
    }
    void update(int o, int l, int r, int ql, int qr, int v) {
        pushdown(o, l, r);
        if(ql <= l && r <= qr) {
            upd(tag[o], v);
            upd(sum[o], (ll)(r-l+1)*v%p);
            return;
        }
        int mid = (l+r)>>1;
        if(ql <= mid) update(ls(o), l, mid, ql, qr, v);
        if(mid < qr) update(rs(o), mid+1, r, ql, qr, v);
        sum[o] = (sum[ls(o)] + sum[rs(o)]) % p;
    }
    ll query(int o, int l, int r, int ql, int qr) {
        pushdown(o, l, r);
        if(ql <= l && r <= qr) {
            return sum[o];
        }
        int mid = (l+r)>>1;
        ll ret = 0;
        if(ql <= mid) upd(ret, query(ls(o), l, mid, ql, qr) );
        if(mid < qr) upd(ret, query(rs(o), mid+1, r, ql, qr) );
        return ret;
    }
//segtree ends here.

//下面一句话非常重要!!
#define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深
int qchain(int u, int v) {
    ll ret = 0;
    while(top[u] != top[v]) {
        swapdep(u, v);
        upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) );
        u = fa[top[u]];
    }
    if(dep[u] < dep[v]) swap(u, v);
    upd(ret, query(1, 1, n, dfn[v], dfn[u]) );
    return ret;
}
void updchain(int u, int v, int val) {
    while(top[u] != top[v]) {
        swapdep(u, v);
        update(1, 1, n, dfn[top[u]], dfn[u], val);
        u = fa[top[u]];
    }
    if(dep[u] < dep[v]) swap(u, v);
    update(1, 1, n, dfn[v], dfn[u], val);
}
int qsub(int u){
    return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1);
}
void updsub(int u, int val){
    update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val);
}
signed main(){
    #ifndef ONLINE_JUDGE
    freopen(".in","r",stdin);
    freopen(".out", "w", stdout);
    #endif
    n = read(), m = read(), r = read(), p = read();
    for(int i = 1; i <= n; i++) {
        a[i] = read();
    }
    for(int i = 1; i < n; i++) {
        int u = read(), v= read();
        adde(u, v); adde(v, u);
    }

    dfs1(r, 0);
    dfs2(r, r);
// for(int i = 1; i <= n; i++) {
//     printf("%d mxs=%d fa=%d siz=%d dfn=%d top=%d\n", i, mxson[i], fa[i],siz[i], dfn[i], top[i]);
// }
    for(int i = 1; i <= n; i++) {
        update(1, 1, n, dfn[i], dfn[i], a[i]);
        assert(top[i] != 0);
    }
    while(m--){
        int cas, x, y, z;
        cas = read(), x = read();
        switch(cas){
            case 1: 
                y = read(), z = read();
                updchain(x, y, z);
                break;
            case 2:
                y = read();
                write(qchain(x, y)); nl;
                break;
            case 3:
                z = read();
                updsub(x, z);
                break;
            case 4:
                write(qsub(x)); nl;
        }
    }
    fclose(stdin), fclose(stdout);
    return 0;
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!