和线段树类似,每个结点也要打lazy标记
但是lazy标记和线段树不一样
具体区别在于可持久化后lazy-tag不用往下传递,而是固定在这个区间并不断累加,变成了这个区间固有的性质(有点像分块的标记了)
update就按照这么来
int update(int last,int L,int R,int c,int l,int r){ int now=++size; T[now]=T[last]; if(L<=l && R>=r){ T[now].sum+=(r-l+1)*c; T[now].add+=c; return now; } int mid=l+r>>1; if(L<=mid)T[now].lc=update(T[last].lc,L,R,c,l,mid); if(R>mid)T[now].rc=update(T[last].rc,L,R,c,mid+1,r); pushup(l,r,now); return now; }
查询时由于lazytag固定在区间上。所以向下查询的时候要把上层的lazytag的影响都算上,即递归时传递一个上层区间的 影响值(例如add)
ll query(int now,int L,int R,int add,int l,int r){ if(L<=l && R>=r) return T[now].sum+(ll)add*(r-l+1); int mid=l+r>>1; ll res=0;add+=T[now].add; if(L<=mid)res+=query(T[now].lc,L,R,add,l,mid); if(R>mid)res+=query(T[now].rc,L,R,add,mid+1,r); return res; }
此外还有合并维护时,由于子区间没有收到父区间的影响,所以合并时还要算父区间的lazytag
void pushup(int l,int r,int rt){T[rt].sum=T[T[rt].lc].sum+T[T[rt].rc].sum+T[rt].add*(r-l+1);}
最后是完整代码,其实本题版本回滚时还可以吧size往回滚,以此节省内存
/* 主席树区间更新 */ #include<bits/stdc++.h> using namespace std; #define ll long long #define maxn 100005 ll n,m,a[maxn]; struct Node{int lc,rc;ll sum,add;}T[maxn*25]; int size,rt[maxn]; void pushup(int l,int r,int rt){T[rt].sum=T[T[rt].lc].sum+T[T[rt].rc].sum+T[rt].add*(r-l+1);} int build(int l,int r){ int now=++size; if(l==r){ T[now].lc=T[now].rc=0; T[now].sum=a[l]; return now; } int mid=l+r>>1; T[now].lc=build(l,mid); T[now].rc=build(mid+1,r); pushup(l,r,now); return now; } int update(int last,int L,int R,int c,int l,int r){ int now=++size; T[now]=T[last]; if(L<=l && R>=r){ T[now].sum+=(r-l+1)*c; T[now].add+=c; return now; } int mid=l+r>>1; if(L<=mid)T[now].lc=update(T[last].lc,L,R,c,l,mid); if(R>mid)T[now].rc=update(T[last].rc,L,R,c,mid+1,r); pushup(l,r,now); return now; } ll query(int now,int L,int R,int add,int l,int r){ if(L<=l && R>=r) return T[now].sum+(ll)add*(r-l+1); int mid=l+r>>1; ll res=0;add+=T[now].add; if(L<=mid)res+=query(T[now].lc,L,R,add,l,mid); if(R>mid)res+=query(T[now].rc,L,R,add,mid+1,r); return res; } void init(){ size=0; memset(rt,0,sizeof rt); memset(T,0,sizeof T); } int main(){ while(scanf("%lld%lld",&n,&m)==2){ init(); for(int i=1;i<=n;i++)scanf("%lld",&a[i]); int cur=0,l,r,c;char op[2]; rt[cur]=build(1,n); while(m--){ scanf("%s",op); if(op[0]=='C'){scanf("%d%d%d",&l,&r,&c);rt[++cur]=update(rt[cur-1],l,r,c,1,n);} if(op[0]=='Q'){scanf("%d%d",&l,&r);cout<<query(rt[cur],l,r,0,1,n)<<'\n';} if(op[0]=='H'){ scanf("%d%d%d",&l,&r,&c); cout<<query(rt[c],l,r,0,1,n)<<'\n'; } if(op[0]=='B'){scanf("%d",&c);cur=c;} } // puts(""); } }
来源:https://www.cnblogs.com/zsben991126/p/10764596.html