https://loj.ac/problem/10141
题目描述
给出一棵树,维护两个操作:\(①\)把\(a\)到\(b\)的路径上的节点全部染成颜色\(c\);\(②\)询问节点\(a\)到节点\(b\)的路径上的颜色段的个数(连续相同颜色算同一个颜色段)。
思路
树上的修改和询问操作,很容易想到树链剖分,不过这里的重点是线段树的维护。我们考虑对于一段序列而言,当我们进行合并时中间这段会接起来,所以我们需要维护每个区间的最左端点的颜色、最右端点的颜色和区间内的颜色段数,合并时维护这三个值,如果合并的左区间的右端点颜色和右区间的左端点颜色相同就要把颜色段和加起来减\(1\)。
代码
#include <bits/stdc++.h> using namespace std; const int N=1e5+10; int nxt[N<<1],to[N<<1],tot,head[N]; void add_edge(int x,int y) { nxt[++tot]=head[x]; head[x]=tot; to[tot]=y; } int siz[N],fa[N],dep[N],son[N],top[N]; int seg[N<<2],rev[N<<2],sum[N<<2],lflag[N<<2],rflag[N<<2],lazy[N<<2]; int ans,col[N]; void dfs1(int u,int father) { siz[u]=1;fa[u]=father; dep[u]=dep[father]+1; for(int i=head[u];i;i=nxt[i]) { int v=to[i]; if(v==father)continue ; dfs1(v,u); siz[u]+=siz[v]; if(siz[v]>siz[son[u]])son[u]=v; } } void dfs2(int u,int father) { if(son[u]) { seg[son[u]]=++seg[0]; rev[seg[0]]=son[u]; top[son[u]]=top[u]; dfs2(son[u],u); } for(int i=head[u];i;i=nxt[i]) { int v=to[i]; if(top[v])continue ; seg[v]=++seg[0]; rev[seg[0]]=v; top[v]=v; dfs2(v,u); } } void pushup(int k) { sum[k]=sum[k<<1]+sum[k<<1|1]; if(rflag[k<<1]==lflag[k<<1|1])sum[k]--; lflag[k]=lflag[k<<1]; rflag[k]=rflag[k<<1|1]; } void build(int k,int l,int r) { lazy[k]=0; if(l==r) { sum[k]=1; rflag[k]=lflag[k]=col[rev[l]]; return ; } int mid=l+r>>1; build(k<<1,l,mid);build(k<<1|1,mid+1,r); pushup(k); } void pushdown(int k) { if(!lazy[k])return ; int x=lazy[k]; sum[k<<1]=sum[k<<1|1]=1; lazy[k<<1]=lazy[k<<1|1]=x; lflag[k<<1]=lflag[k<<1|1]=rflag[k<<1]=rflag[k<<1|1]=x; lazy[k]=0; } void change(int k,int l,int r,int x,int y,int val) { if(r<x||l>y)return ; if(l>=x&&r<=y) { sum[k]=1;lazy[k]=val; rflag[k]=lflag[k]=val; return ; } int mid=l+r>>1; pushdown(k); if(x<=mid)change(k<<1,l,mid,x,y,val); if(y>mid)change(k<<1|1,mid+1,r,x,y,val); pushup(k); } void query(int k,int l,int r,int x,int y) { if(r<x||l>y)return; if(l>=x&&r<=y) { ans+=sum[k]; return ; } int mid=l+r>>1; pushdown(k); if(x<=mid)query(k<<1,l,mid,x,y); if(y>mid)query(k<<1|1,mid+1,r,x,y); if(x<=mid&&y>mid&&rflag[k<<1]==lflag[k<<1|1])ans--; } int get(int k,int l,int r,int pos) { if(l==r) return lflag[k]; pushdown(k); int mid=l+r>>1; if(mid>=pos)return get(k<<1,l,mid,pos); else return get(k<<1|1,mid+1,r,pos); } void ask(int x,int y) { int fx=top[x],fy=top[y]; while(fx!=fy) { if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy); query(1,1,seg[0],seg[fx],seg[x]); if(get(1,1,seg[0],seg[fx])==get(1,1,seg[0],seg[fa[fx]]))ans--; x=fa[fx];fx=top[x]; } if(dep[x]>dep[y])swap(x,y); query(1,1,seg[0],seg[x],seg[y]); } void add(int x,int y,int v) { int fx=top[x],fy=top[y]; while(fx!=fy) { if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy); change(1,1,seg[0],seg[fx],seg[x],v); x=fa[fx];fx=top[x]; } if(dep[x]>dep[y])swap(x,y); change(1,1,seg[0],seg[x],seg[y],v); } int read() { int res=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();} while(ch>='0'&&ch<='9'){res=(res<<3)+(res<<1)+(ch^48);ch=getchar();} return res*w; } void write(int x) { if(x<0){putchar('-');x=-x;} if(x>9)write(x/10); putchar(x%10+'0'); } void writeln(int x) { write(x); putchar('\n'); } int main() { int n=read(),m=read(); for(int i=1;i<=n;i++) col[i]=read(); for(int i=1;i<n;i++) { int x=read(),y=read(); add_edge(x,y);add_edge(y,x); } dfs1(1,0); seg[0]=seg[1]=top[1]=rev[1]=1; dfs2(1,0); build(1,1,seg[0]); // for(int i=1;i<=n;i++) // printf("%d %d\n",i,seg[i]); /* printf("\n"); for(int i=1;i<=13;i++) printf("%d %d\n",i,sum[i]); printf("\n");*/ while(m--) { char op; scanf(" %c",&op); if(op=='Q') { int x=read(),y=read(); ans=0; ask(x,y); writeln(ans); } else { int x=read(),y=read(),v=read(); add(x,y,v); } /* for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) { ans=0; ask(i,j); printf("%d %d %d\n",i,j,ans); }*/ } return 0; }