树链剖分模板
树链剖分可以把树上的点划分成一条条连续的“链”,“链”是一条简单路径,上面的每个点满足祖先后代关系,从根节点到每个点都只需经过最多\(\log_2n\)条链。 链上的每一个点 dfs序也是连续的, 故可以配合其他数据结构解决很多树上查询问题(近乎所有静态树问题)
先定义siz[u]
为点u
的子树大小。
int siz[MAXN]; //每个点的子树大小 int dep[MAXN]; //每个点的深度 int mxson[MAXN];//每个点的重儿子
链分为重链和轻链。 考虑每个点,它只会和它的一个儿子组成链,这个儿子叫做 重儿子, 其他儿子叫做 轻儿子。 重儿子满足: siz[mxson[u]]
最大.
第一次dfs:
任务清单:
- 求出每个节点 u 的 siz[u], dep[u], fa[u]
- 求出每个节点的重儿子(注意叶节点没有重儿子!)
void dfs1(int u, int faa){ dep[u] = dep[faa] + 1; siz[u] = 1; fa[u] = faa; int mxsiz = 0; for(int e = h[u]; e; e = nxt[e]){ int v = ev[e]; if(v == faa) continue; dfs1(v, u); if(mxsiz < siz[v]) { mxson[u] = v; mxsiz = siz[v]; } siz[u] += siz[v]; } }
第二次dfs: 确定dfs序和整棵树的剖分情况.
任务清单:
- 求出每个点的“新编号”, 所有在线段树上的操作都要使用 新编号 (因为只有使用新编号,每条链上的点\(dfs\)序才是连续的。
- 求出每个点
u
的top[u]
(下面有解释)
int top[u]; //u所在的链的顶端 int dfn[MAXN]; //每个点的dfs序(第二次dfs后) 作为每个点的新编号
int timestamp = 0; void dfs2(int u, int topf){ dfn[u] = ++timestamp; top[u] = topf; if(mxson[u] == 0) return; //VERY IMPORTANT!!!! //上面这句代码不加会RE! (没有儿子的节点不可以dfs!) //或者加一句if(u == 0) return; 也可以 dfs2(mxson[u], topf); for(int e = h[u]; e; e = nxt[e]){ int v = ev[e]; if(v != fa[u] && v != mxson[u]) dfs2(v, v); } }
两遍\(dfs\)后, 对整棵树的剖分已经完成.
易错点
注意!! 注意!! 注意!!
siz[u], top[u], fa[u], dep[u], mxson[u]
中的 \(u\) 是每个点的 旧编号!!!
树链剖分的应用
举个例子: 通过树链剖分实现链上求和等操作.
链上询问
一边求 \(LCA\) 一边使用线段树修改/求和. (因为每条链上的点的\(dfs\)序连续)
所有链上操作都可以基于该模板。 瞪大眼睛,注释非常重要!
#define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深 int qchain(int u, int v) { ll ret = 0; //和倍增求LCA一样, 这里求LCA也是让两个节点轮流向上"跳"。 //此时,要求跳的那个点 所在链的深度 较大(因为跳完之后,u = fa[top[u]]) while(top[u] != top[v]) { swapdep(u, v); upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) ); u = fa[top[u]]; } //两个点在同一条链上以后, 就要根据 点自己的深度 来判断谁先谁后了。(查询[l, r]区间和时, l <= r ) if(dep[u] < dep[v]) swap(u, v); upd(ret, query(1, 1, n, dfn[v], dfn[u]) ); return ret; }
链上修改同理.
子树询问
每个点的子树\(dfs\)序连续, 同样使用线段树修改/求和
注意 siz[u] 中的u是每个点的 初始编号
int qsub(int u){ return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1); } void updsub(int u, int val){ update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val); }
CODE (luogu P3384)
#include<bits/stdc++.h> using namespace std; #define ll long long void write(ll x) { if(x < 0) { putchar('-'); x = -x; } if(x > 9) write(x / 10); putchar(x % 10 + '0'); } ll read(){ ll ret = 0, f = 1; char c = getchar(); for(; !isdigit(c); c = getchar()) if('-'==c) f = -1; for(; isdigit(c); c = getchar()) ret = ret*10 + c - '0'; return ret * f; } #define nl puts("") #define bs putchar(' ') const int _maxn = 150005; int dep[_maxn], siz[_maxn], fa[_maxn], mxson[_maxn]; int dfn[_maxn], top[_maxn]; int n, m, r, p; int a[_maxn]; int h[_maxn], nxt[_maxn*2], ev[_maxn*2], cnte; void adde(int u, int v){ cnte++; ev[cnte] = v; nxt[cnte] = h[u]; h[u] = cnte; } void dfs1(int u, int faa){ dep[u] = dep[faa] + 1; siz[u] = 1; fa[u] = faa; int mxsiz = 0; for(int e = h[u]; e; e = nxt[e]){ int v = ev[e]; if(v == faa) continue; dfs1(v, u); if(mxsiz < siz[v]) { mxson[u] = v; mxsiz = siz[v]; } siz[u] += siz[v]; } } int timestamp = 0; void dfs2(int u, int topf){ dfn[u] = ++timestamp; top[u] = topf; if(mxson[u] == 0) return; //VERY IMPORTANT!!!! dfs2(mxson[u], topf); for(int e = h[u]; e; e = nxt[e]){ int v = ev[e]; if(v != fa[u] && v != mxson[u]) dfs2(v, v); } } //segtree #define il inline ll sum[_maxn*4], tag[_maxn*4]; il int ls(int o) {return o*2;} il int rs(int o) {return o*2+1;} void upd(ll &x, ll y) { x = (p+x+y) % p; } void _op(int o, int l, int r, int x){ upd(tag[o], x); upd(sum[o], (ll)(r-l+1) * x % p); } il void pushdown(int o, int l, int r) { if(tag[o] == 0) return; int mid = (l+r)>>1; _op(ls(o), l, mid, tag[o]); _op(rs(o), mid+1, r, tag[o]); tag[o] = 0; } void update(int o, int l, int r, int ql, int qr, int v) { pushdown(o, l, r); if(ql <= l && r <= qr) { upd(tag[o], v); upd(sum[o], (ll)(r-l+1)*v%p); return; } int mid = (l+r)>>1; if(ql <= mid) update(ls(o), l, mid, ql, qr, v); if(mid < qr) update(rs(o), mid+1, r, ql, qr, v); sum[o] = (sum[ls(o)] + sum[rs(o)]) % p; } ll query(int o, int l, int r, int ql, int qr) { pushdown(o, l, r); if(ql <= l && r <= qr) { return sum[o]; } int mid = (l+r)>>1; ll ret = 0; if(ql <= mid) upd(ret, query(ls(o), l, mid, ql, qr) ); if(mid < qr) upd(ret, query(rs(o), mid+1, r, ql, qr) ); return ret; } //segtree ends here. //下面一句话非常重要!! #define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深 int qchain(int u, int v) { ll ret = 0; while(top[u] != top[v]) { swapdep(u, v); upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) ); u = fa[top[u]]; } if(dep[u] < dep[v]) swap(u, v); upd(ret, query(1, 1, n, dfn[v], dfn[u]) ); return ret; } void updchain(int u, int v, int val) { while(top[u] != top[v]) { swapdep(u, v); update(1, 1, n, dfn[top[u]], dfn[u], val); u = fa[top[u]]; } if(dep[u] < dep[v]) swap(u, v); update(1, 1, n, dfn[v], dfn[u], val); } int qsub(int u){ return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1); } void updsub(int u, int val){ update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val); } signed main(){ #ifndef ONLINE_JUDGE freopen(".in","r",stdin); freopen(".out", "w", stdout); #endif n = read(), m = read(), r = read(), p = read(); for(int i = 1; i <= n; i++) { a[i] = read(); } for(int i = 1; i < n; i++) { int u = read(), v= read(); adde(u, v); adde(v, u); } dfs1(r, 0); dfs2(r, r); // for(int i = 1; i <= n; i++) { // printf("%d mxs=%d fa=%d siz=%d dfn=%d top=%d\n", i, mxson[i], fa[i],siz[i], dfn[i], top[i]); // } for(int i = 1; i <= n; i++) { update(1, 1, n, dfn[i], dfn[i], a[i]); assert(top[i] != 0); } while(m--){ int cas, x, y, z; cas = read(), x = read(); switch(cas){ case 1: y = read(), z = read(); updchain(x, y, z); break; case 2: y = read(); write(qchain(x, y)); nl; break; case 3: z = read(); updsub(x, z); break; case 4: write(qsub(x)); nl; } } fclose(stdin), fclose(stdout); return 0; }