题意
给一颗带边权的树,有两种操作
- \(C~e_i~w_i\),将第\(e_i\)条边的边权改为\(w_i\)。
- \(Q~v_i\),询问距\(v_i\)点最远的点的距离。
分析
官方题解做法:动态维护直径,然后再支持询问两个点的距离,后者可以 dfs 序 + lca + 树状数组。动态维护直径可以用点分治(点分树),具体做法是,考虑过分治中心的最长路径,我们只需要查询分别以分治中心的每个儿子为根,所在子树的最长链,从中再找到最长和次长即可,这个星状图可以用 set 维护。每个子树则可以使用 dfs 序+线段树维护。复杂度 O(nlog2n)。
其实我们不用考虑直径,同样维护以分治中心的每个儿子为根,所在子树的最长链,用个可删堆\(ch[x]\)维护\(x\)的每个儿子的子树的最长链,查询距\(v\)点最远的距离时,有两种情况
- \(v\)点作为分治中心,此时答案为\(ch[v]\)。
- 跳\(v\)的点分树中的祖先,先将\(ch[fa[v]]\)中包括\(v\)的子树的最长链删去,答案为\(dis(fa[v],v)+ch[fa[v]]\),这个\(dis\)可以用树状数组+\(lca\)+dfs序维护。
所有答案取最大值就是最终答案了。每个子树中的链我是用dfs序+动态开点线段树维护的,细节很多,因为网上没这题点分树做法的博客,自己也刚学,很多处理细节都是自己YY的...写的很繁,感觉只有我能看懂(
Code
#include<bits/stdc++.h> #define fi first #define se second #define pb push_back #define lson l,mid,p<<1 #define rson mid+1,r,p<<1|1 #define ll long long using namespace std; const int inf=1e9; const int mod=1e9+7; const int maxn=1e5+10; typedef pair<int,int> pii; int n,q; int sz[maxn],mxp[maxn],vis[maxn],f[maxn],sum,rt; ll tr[maxn]; int e[maxn]; int id[maxn]; vector<pii>g[maxn]; vector<int>son[maxn]; unordered_map<int,int>in[maxn],out[maxn],pd[maxn]; void add(int x,ll k){ while(x<=n) tr[x]+=k,x+=x&-x; } ll dor(int x){ ll ret=0; while(x) ret+=tr[x],x-=x&-x; return ret; } ll a[maxn*140],tag[maxn*140]; int ls[maxn*140],rs[maxn*140],rtt[maxn],tot; void pdu(int p,ll k){a[p]+=k,tag[p]+=k;} void up(int dl,int dr,int l,int r,int &p,ll k){ if(dl>dr||!dl) return; a[++tot]=a[p],ls[tot]=ls[p],tag[tot]=tag[p],rs[tot]=rs[p],p=tot; if(l==dl&&r==dr){ a[p]+=k;tag[p]+=k; return; }int mid=l+r>>1; pdu(ls[p],tag[p]);pdu(rs[p],tag[p]);tag[p]=0; if(dr<=mid) up(dl,dr,l,mid,ls[p],k); else if(dl>mid) up(dl,dr,mid+1,r,rs[p],k); else up(dl,mid,l,mid,ls[p],k),up(mid+1,dr,mid+1,r,rs[p],k); a[p]=max(a[ls[p]],a[rs[p]]); } ll qy(int dl,int dr,int l,int r,int p){ if(l==dl&&r==dr) return a[p]; int mid=l+r>>1; pdu(ls[p],tag[p]);pdu(rs[p],tag[p]);tag[p]=0; if(dr<=mid) return qy(dl,dr,l,mid,ls[p]); else if(dl>mid) return qy(dl,dr,mid+1,r,rs[p]); else return max(qy(dl,mid,l,mid,ls[p]),qy(mid+1,dr,mid+1,r,rs[p])); } struct heap { priority_queue<ll> A, B; // heap=A-B void insert(ll x) { A.push(x); } void erase(ll x) { B.push(x); } ll top() { while (!B.empty() && A.top() == B.top()) A.pop(), B.pop(); return A.top(); } void pop() { while (!B.empty() && A.top() == B.top()) A.pop(), B.pop(); A.pop(); } ll top2() { ll t = top(), ret; pop(); ret = top(); A.push(t); return ret; } int size() { return A.size() - B.size(); } }ch[maxn]; struct LCA{ int sz[maxn],d[maxn],f[maxn],top[maxn],son[maxn],in[maxn],out[maxn],p[maxn],id[maxn],num; ll dist[maxn]; void dfs1(int u){ sz[u]=1;d[u]=d[f[u]]+1; for(pii x:g[u]){ if(x.fi==f[u]) continue; f[x.fi]=u;dist[x.fi]=dist[u]+e[x.se],id[x.se]=x.fi; dfs1(x.fi); sz[u]+=sz[x.fi]; if(sz[x.fi]>sz[son[u]]) son[u]=x.fi; } } void dfs2(int u,int t){ top[u]=t;in[u]=++num;p[num]=u; if(son[u]) dfs2(son[u],t); for(pii x:g[u]){ if(x.fi==f[u]||x.fi==son[u]) continue; dfs2(x.fi,x.fi); } out[u]=num; } int lca(int x,int y){ while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]]) swap(x,y); x=f[top[x]]; } if(d[x]<d[y]) swap(x,y); return y; } }L; ll dis(int x,int y){ return dor(L.in[x])+dor(L.in[y])-2*dor(L.in[L.lca(x,y)]); } void getrt(int u,int fa){ sz[u]=1;mxp[u]=0; for(pii x:g[u]){ if(x.fi==fa||vis[x.fi]) continue; getrt(x.fi,u); sz[u]+=sz[x.fi]; mxp[u]=max(mxp[u],sz[x.fi]); } mxp[u]=max(mxp[u],sum-sz[u]); if(mxp[u]<mxp[rt]) rt=u; } void calc(int u,int fa,int fart){ son[rt].pb(u); in[rt][u]=son[rt].size(); if(fart!=-1) up(in[f[rt]][u],in[f[rt]][u],1,son[fart].size(),rtt[rt],dis(fart,u)); for(pii x:g[u]){ if(x.fi==fa||vis[x.fi]) continue; pd[rt][x.se]=x.fi;id[x.se]=x.fi; calc(x.fi,u,fart); } out[rt][u]=son[rt].size(); } void solve(int u){ vis[u]=1; for(pii x:g[u]){ if(vis[x.fi]) continue; sum=sz[x.fi];mxp[rt=0]=inf; getrt(x.fi,0);getrt(rt,0); f[rt]=u; calc(rt,0,u); ch[u].insert(a[rtt[rt]]); solve(rt); } ch[u].insert(0); } int main(){ //ios::sync_with_stdio(false); //freopen("in","r",stdin); scanf("%d",&n); for(int i=1,x,y;i<n;i++){ scanf("%d%d%d",&x,&y,&e[i]); g[x].pb(pii(y,i));g[y].pb(pii(x,i)); } L.dfs1(1);L.dfs2(1,1); for(int i=1;i<=n;i++) add(L.in[i],L.dist[i]),add(L.in[i]+1,-L.dist[i]); sum=mxp[rt]=n; getrt(1,0);calc(rt,0,-1);solve(rt); scanf("%d",&q); while(q--){ char c[2]; int x,j; ll y; scanf("%s",c); if(c[0]=='C'){ scanf("%d%lld",&j,&y); int gt=L.id[j]; ll ret=e[j]; x=id[j]; add(L.in[gt],y-ret),add(L.out[gt]+1,ret-y); for(int i=x;f[i];i=f[i]){ ch[f[i]].erase(a[rtt[i]]); up(in[f[i]][pd[f[i]][j]],out[f[i]][pd[f[i]][j]],1,son[f[i]].size(),rtt[i],y-ret); ch[f[i]].insert(a[rtt[i]]); } e[j]=y; }else{ scanf("%d",&x); ll ans=ch[x].top(); for(int i=x;f[i];i=f[i]){ ch[f[i]].erase(a[rtt[i]]); ans=max(ans,dis(f[i],x)+ch[f[i]].top()); ch[f[i]].insert(a[rtt[i]]); } printf("%lld\n",ans); } } return 0; }