树链剖分模板

坚强是说给别人听的谎言 提交于 2020-03-01 17:44:12

树剖例题

然后发现可以替代LCA中查询两点距离,特意来保存下代码模板

我代码中qulen函数 就是查询两点间的距离。

学习博客

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+10;
vector<int>G[N];
int a[N];
int sum[2][4*N];
int n,q;
int sz[N],son[N],f[N],d[N],top[N],id[N],cnt;
void dfs1(int u,int fat,int dep)
{
    f[u]=fat;
    d[u]=dep;
    sz[u]=1;
    for(int v:G[u])
    {
        if(v==fat) continue;
        dfs1(v,u,dep+1);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int t)
{
    top[u]=t;
    id[u]=++cnt;
    if(!son[u]) return ;
    dfs2(son[u],t);
    for(int v:G[u])
    {
        if(v==son[u]||v==f[u]) continue;
        dfs2(v,v);
    }
}
void up(int id,int l,int r,int pos,int x,int ty)
{
    if(l>r) return ;
    if(l==r){
        sum[ty][id]=x;
        return ;
    }
    int mid=l+r>>1;
    if(pos<=mid) up(id<<1,l,mid,pos,x,ty);
    else up(id<<1|1,mid+1,r,pos,x,ty);
    sum[ty][id]=(sum[ty][id<<1]^sum[ty][id<<1|1]);
}
int qu(int id,int l,int r,int ql,int qr,int ty){
    if(ql<=l&&r<=qr){
        return sum[ty][id];
    }
    int mid=l+r>>1;
    int res=0;
    if(ql<=mid) res^=qu(id<<1,l,mid,ql,qr,ty);
    if(qr>mid) res^=qu(id<<1|1,mid+1,r,ql,qr,ty);
    return res;
}
int getsum(int x,int y,int ty)
{
    int ans=0,fx=top[x],fy=top[y];
    while(fx!=fy){
        if(d[fx]>d[fy]){
            ans^=qu(1,1,n,id[fx],id[x],ty);
            x=f[fx],fx=top[x];
        }
        else{
            ans^=qu(1,1,n,id[fy],id[y],ty);
            y=f[fy],fy=top[y];
        }
    }
    if(id[x]>id[y]) ans^=qu(1,1,n,id[y],id[x],ty);
    else  ans^=qu(1,1,n,id[x],id[y],ty);
    return ans;
}
int qulen(int x,int y)
{
    int ans=0,fx=top[x],fy=top[y];
    while(fx!=fy){
        if(d[fx]>d[fy]) {
            ans+=d[x]-d[fx]+1;
            x=f[fx],fx=top[x];
        }
        else {
            ans+=d[y]-d[fy]+1;
            y=f[fy],fy=top[y];
        }
    }
    if(d[x]>d[y])ans+=d[x]-d[y]+1;
    else ans+=d[y]-d[x]+1;
    return ans;
}
int main()
{
    scanf("%d%d",&n,&q);
	for(int i=1;i<=n;++i) {
        scanf("%d",&a[i]);
	}
	for(int i=1;i<=n-1;++i)
	{
	    int u,v;
        scanf("%d%d",&u,&v);
	    G[u].push_back(v);
	    G[v].push_back(u);
	}
	dfs1(1,0,1);
	dfs2(1,1);

	for(int i=1;i<=n;++i){
        up(1,1,n,id[i],a[i],d[i]%2);
	}

	while(q--)
    {
        int ty,u,v;
        scanf("%d%d%d",&ty,&u,&v);
        if(ty==1){
            up(1,1,n,id[u],v,d[u]%2);
        }
        else{
            if(d[u]%2!=d[v]%2){
                int ans1=getsum(u,v,d[u]%2);
                int ans2=getsum(u,v,(d[u]%2)^1);
                printf("%d\n",ans1^ans2);
            }
            else{
                printf("%d\n",getsum(u,v,(d[u]%2)^1));
            }
        }
    }
    return 0;
}

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!