【模板】树链剖分

百般思念 提交于 2020-01-20 00:36:48

树链剖分:

用于解决一系列维护静态树上信息的问题。这些问题看起来非常像一些区间操作搬到了树上。

(例如:一棵带权树,需要维护修改权值操作以及从$u$到$v$简单路径上的权值和)

树链剖分就是通过某种策略(一般是轻、重边剖分)将原树链划分成若干条链,每条链相当于一个序列,此时就可以用区间数据结构(一般是线段树)维护这些链。

 

需要维护的值:

$f(x)$:$x$在树中的父亲。

$dep(x)$:$x$在树中的深度。

$siz(x)$:$x$的子树大小。

$son(x)$:$u$的重儿子:在$u$的所有儿子中$siz$值最大的儿子,$u\rightarrow v$为重边。

($u$的轻儿子:在$u$的所有儿子中除了重儿子以外的儿子,$u\rightarrow v$为轻边。)

$top(x)$:$x$所在重路径的顶部节点。

$seg(x)$:$x$在线段树中的位置(下标)。

$rnk(x)$:线段树中$x$位置对应的树中节点编号,即有$rnk(seg(x))=x$。

 

轻重边的一些性质:

1、如果$u\rightarrow v$为轻边,则$siz(v)<=siz(u)/2$。

证明:反证法,若存在$siz(v)>siz(u)/2$且存在$siz(v_0)>siz(v)$,那么$siz(v)+siz(v_0)>siz(u)$,即子节点的$siz$和大于父节点的$siz$。

2、从根到任何点$u$的路径上轻边的条数不超过$log(N)$。

证明:由1可知从根到$u$的路径上每经过一条轻边,当前子树的节点个数至少会少$\frac{1}{2}$,所以至多减少$log(N)$次$siz$值为0,到达叶节点。

3、从根到任何点$u$的路径上轻边、重边的条数均不超过$log(N)$。

证明:每条重链的起点和终点都连接一条轻边,由2可知轻边条数不超过$log(N)$,所以重链条数也不超过$log(N)$。

 

实现步骤:

1、一遍$dfs$得到前4个值,再一遍$dfs$将树的节点重新排序,使一条重链上的点$dfs$序连续。

2、使用线段树维护新树的$dfs$序序列,查询时沿重链走到两点的$lca$并计算答案。

 

模板题目:loj10138

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>

using namespace std;
#define MAXN 100005
#define MAXM 500005
#define INF 0x7fffffff
#define ll long long

int hd[MAXN],to[MAXN<<1],top[MAXN];
int A[MAXN],nxt[MAXN<<1],cnt,tot;
int f[MAXN],siz[MAXN],son[MAXN];
int seg[MAXN],rnk[MAXN],dep[MAXN];
struct node{int l,r,sum,mx;}tr[MAXN<<2];
char str[10];

inline int read(){
    int x=0,f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar())
        if(c=='-')
            f=-1;
    for(;isdigit(c);c=getchar())
        x=x*10+c-'0';
    return x*f;
}

inline void add(int u,int v){
    to[++cnt]=v,nxt[cnt]=hd[u];
    hd[u]=cnt;return;
}

inline void pushup(int k){
    tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
    tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
    return;
}

inline void dfs1(int u,int fa,int d){
    dep[u]=d;f[u]=fa;siz[u]=1;
    for(int i=hd[u];i;i=nxt[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs1(v,u,d+1);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])
            son[u]=v;
    }
    return;
}

inline void dfs2(int u,int fa,int tp){
    top[u]=tp;seg[u]=++tot;rnk[tot]=u;
    if(son[u]) dfs2(son[u],u,tp);
    for(int i=hd[u];i;i=nxt[i]){
        int v=to[i];
        if(v==fa || v==son[u]) continue;
        dfs2(v,u,v);
    }
    return;
}

inline void build(int L,int R,int k){
    tr[k].l=L,tr[k].r=R;
    if(L==R){
        tr[k].mx=tr[k].sum=A[rnk[L]];
        return;
    }
    int mid=(L+R)>>1;
    build(L,mid,k<<1);
    build(mid+1,R,k<<1|1);
    pushup(k);return;
}

inline void update(int x,int y,int k){
    if(tr[k].l==tr[k].r){
        tr[k].mx=tr[k].sum=y;
        return;
    }
    int mid=(tr[k].l+tr[k].r)>>1;
    if(x<=mid) update(x,y,k<<1);
    else update(x,y,k<<1|1);
    pushup(k);return;
}

inline int qmx(int L,int R,int k){
    if(L<=tr[k].l && tr[k].r<=R)
        return tr[k].mx;
    int mid=(tr[k].l+tr[k].r)>>1;
    if(L<=mid && R>mid) 
        return max(qmx(L,R,k<<1),qmx(L,R,k<<1|1));
    else if(R<=mid) return qmx(L,R,k<<1);
    else return qmx(L,R,k<<1|1);
}

inline int qsum(int L,int R,int k){
    if(L<=tr[k].l && tr[k].r<=R)
        return tr[k].sum;
    int mid=(tr[k].l+tr[k].r)>>1;
    if(L<=mid && R>mid) 
        return qsum(L,R,k<<1)+qsum(L,R,k<<1|1);
    else if(R<=mid) return qsum(L,R,k<<1);
    else return qsum(L,R,k<<1|1);
}

inline int solve1(int u,int v){
    int ans=-INF;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        ans=max(ans,qmx(seg[top[u]],seg[u],1));
        u=f[top[u]];
    }
    if(dep[u]<dep[v]) swap(u,v);
    ans=max(ans,qmx(seg[v],seg[u],1));
    return ans;
}

inline int solve2(int u,int v){
    int ans=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        ans+=qsum(seg[top[u]],seg[u],1);
        u=f[top[u]];
    }
    if(dep[u]<dep[v]) swap(u,v);
    ans+=qsum(seg[v],seg[u],1);
    return ans;
}

int main(){
    int N=read();
    for(int i=1;i<N;i++){
        int u=read(),v=read();
        add(u,v);add(v,u);
    }
    for(int i=1;i<=N;i++) A[i]=read();
    dfs1(1,0,1);dfs2(1,0,1);build(1,N,1);
    int M=read();
    while(M--){
        cin>>str;int x=read(),y=read();
        if(str[0]=='C') update(seg[x],y,1);
        else if(str[1]=='M') printf("%d\n",solve1(x,y));
        else printf("%d\n",solve2(x,y));
    }
    return 0;
}

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!