【模板】重链剖分(树链剖分)

一曲冷凌霜 提交于 2020-02-02 09:48:26

我们知道对一列数进行区间或单点加减,乘除和区间求值等操作可以用线段树或树状数组

那么,如何对带权树上一条路径中的数进行这样的操作呢?

此时就用到了线段树的树上版——树链剖分


 

树链剖分的目的在于把树变成一条线段,以方便区间操作

显然,我们无法让每一条树上路径中所有数在这条线段中相邻,但又不能让它们太分散

于是,我们想寻找一个使区间操作均摊复杂度较小的树-线段映射方法

介绍概念:

  • 重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
  • 轻儿子:父亲节点中除了重儿子以外的儿子;
  • 重边:父亲结点和重儿子连成的边;
  • 轻边:父亲节点和轻儿子连成的边;
  • 重链:由多条重边连接而成的路径;
  • 轻链:由多条轻边连接而成的路径;

介绍数组:

fa[i]:点i的父亲

top[i]:点i所在重链的顶点

si[i]:以点i为根的子树的节点数

son[i]:点i的重儿子

dfn[i]:点i的dfs序

dep[i]:点i的深度

 

比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,

2-11就是重链,2-5就是轻链,用红点标记的就是该结点所在重链的起点,也就是下文提到的top结点,

还有每条边的值其实是进行dfs时的执行序号。

(图和说明来自这位大佬的博客

遍历时,我们使用先序遍历,并先遍历重儿子,再遍历轻儿子

于是,dfn重儿子=dfn父亲+1,即一条重链上从上到下dfn相邻且依次增大

我们按照dfn序将节点的权值放到数列上,然后就可以愉快地对它进行线段树操作了!

树链剖分的两个性质:

1,如果(u, v)是一条轻边,那么size(v) < size(u)/2;

2,从根结点到任意结点的路所经过的轻重链的个数必定都小于logn;

可以证明,树链剖分中一次区间操作的时间复杂度为O((log2n)^2);

 


模板题:【模板】重链剖分(luogu)

Description

题目描述

 

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

 

输入格式

 

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

 

输出格式

 

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

Code

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;
const int N=1e5+10;
ll d[N],P,z;
int re[N],n,m,opt,x,y,r,fa[N],top[N],si[N],son[N],dfn[N],tot,rt,sum,dep[N];
struct node
{
    int l,r,lc,rc;
    ll sum,delay;
}f[N*2];
vector <int> link[N];
void dfs(int u,int f)
{
    fa[u]=f,si[u]=1,dep[u]=dep[f]+1;
    int size=link[u].size();
    for(int i=0;i<size;i++)
    {
        int v=link[u][i];
        if(v==f) continue;
        dfs(v,u),si[u]+=si[v];
        if(son[u]==0 || si[son[u]]<si[v]) son[u]=v;
    }
}
void Dfs(int u)
{
    dfn[u]=++tot,re[tot]=u;
    if(!son[u]) return ;
    top[son[u]]=top[u],Dfs(son[u]);
    int size=link[u].size();
    for(int i=0;i<size;i++)
    {
        int v=link[u][i];
        if(v==fa[u] || v==son[u]) continue;
        top[v]=v,Dfs(v);
    }
}
ll gel(int g)
{
    return f[g].r-f[g].l+1;
}
void push_up(int g)
{
    f[g].sum=f[f[g].lc].sum+f[f[g].rc].sum;
    f[g].sum%=P;
}
void push_down(int g)
{
    if(f[g].delay==0) return ;
    int lc=f[g].lc,rc=f[g].rc;
    f[lc].sum+=gel(lc)*f[g].delay;
    f[lc].delay+=f[g].delay;
    f[lc].delay%=P,f[lc].sum%=P;
    f[rc].sum+=gel(rc)*f[g].delay;
    f[rc].delay+=f[g].delay;
    f[rc].delay%=P,f[rc].sum%=P;
    f[g].delay=0;
}
void build(int &g,int l,int r)
{
    g=++sum;
    f[g].l=l,f[g].r=r;
    if(l==r)
    {
        f[g].sum+=d[re[l]];
        return ;
    }
    int mid=(l+r)>>1;
    build(f[g].lc,l,mid);
    build(f[g].rc,mid+1,r);
    push_up(g);
}
void add(int g,int l,int r,ll k)
{
    if(f[g].l>=l && f[g].r<=r)
    {
        f[g].sum+=gel(g)*k,f[g].delay+=k;
        f[g].sum%=P,f[g].delay%=P;
        return ;
    }
    push_down(g);
    int mid=(f[g].l+f[g].r)>>1;
    if(r<=mid) add(f[g].lc,l,r,k);
    else if(l>mid) add(f[g].rc,l,r,k);
    else add(f[g].lc,l,mid,k),add(f[g].rc,mid+1,r,k);
    push_up(g);
}
ll get(int g,int l,int r)
{
    if(f[g].l>=l && f[g].r<=r)
        return f[g].sum%P;
    push_down(g);
    int mid=(f[g].l+f[g].r)>>1;
    if(r<=mid) return get(f[g].lc,l,r)%P;
    else if(l>mid) return get(f[g].rc,l,r)%P;
    else return (get(f[g].lc,l,mid)+get(f[g].rc,mid+1,r))%P;
}
void Add(int x,int y,ll k)
{
    int px=top[x],py=top[y];
    while(px!=py)
        if(dep[px]>=dep[py])
        {
            add(rt,dfn[px],dfn[x],k);
            x=fa[px],px=top[x];
        }
        else
        {
            add(rt,dfn[py],dfn[y],k);
            y=fa[py],py=top[y];
        }
    if(dfn[x]>dfn[y]) add(rt,dfn[y],dfn[x],k);
    else add(rt,dfn[x],dfn[y],k);
}
ll Get(int x,int y)
{
    ll ans=0;
    int px=top[x],py=top[y];
    while(px!=py)
        if(dep[px]>=dep[py])
        {
            ans+=get(rt,dfn[px],dfn[x]),ans%=P;
            x=fa[px],px=top[x];
        }
        else
        {
            ans+=get(rt,dfn[py],dfn[y]),ans%=P;
            y=fa[py],py=top[y];
        }
    if(dfn[x]>dfn[y]) ans+=get(rt,dfn[y],dfn[x]),ans%=P;
    else ans+=get(rt,dfn[x],dfn[y]),ans%=P;
    return ans%P;
}
int main()
{
    scanf("%d%d%d%lld",&n,&m,&r,&P);
    for(int i=1;i<=n;i++)
        scanf("%lld",&d[i]);
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        link[u].push_back(v);
        link[v].push_back(u);
    }
    dfs(r,0);
    top[r]=r,Dfs(r);
    build(rt,1,tot);
//    1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
//    2 x y 表示求树从x到y结点最短路径上所有节点的值之和
//    3 x z 表示将以x为根节点的子树内所有节点值都加上z
//    4 x 表示求以x为根节点的子树内所有节点值之和
    while(m--)
    {
        scanf("%d",&opt);
        if(opt==1)
        {
            scanf("%d%d%lld",&x,&y,&z);
            z%=P;
            if(z==0) continue;
            Add(x,y,z);    
        }
        else if(opt==2)
        {
            scanf("%d%d",&x,&y);
            printf("%lld\n",Get(x,y));
        }
        else if(opt==3)
        {
            scanf("%d%lld",&x,&z);
            z%=P;
            if(z==0) continue;
            add(rt,dfn[x],dfn[x]+si[x]-1,z);
        }
        else if(opt==4)
        {
            scanf("%d",&x);
            printf("%lld\n",get(rt,dfn[x],dfn[x]+si[x]-1));
        }
    }
    return 0;
}

 

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!