题目链接：https://www.luogu.com.cn/problem/P3313
这道题目的做法是树链剖分 + 动态开点线段树：对每一种宗教单独开一棵线段树（下标为树剖序，只为用到的位置开点），路径询问时在起点城市所属宗教的那棵线段树上做树剖查询即可。
在做这道题目之前，我们先来看一道去掉树链剖分之后完全相同的动态开点线段树题目：
https://www.cnblogs.com/codedecision/p/11791200.html
看完那道题就会发现，本题只是在上题的线段树做法外面套了一层树链剖分，
因此这道题目就变得很简单了。
实现代码如下：
// Luogu P3313: heavy-light decomposition + one dynamically allocated
// segment tree per religion.
//
// Every city u owns position seg[u] in the HLD order. The tree of religion k
// stores w[u] at seg[u] for each city u with c[u] == k (0 everywhere else),
// so a path query on tree c[x] only "sees" cities sharing x's religion.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

namespace {

// Exclusive upper bound on city ids and religion ids.
// NOTE(review): assumes 1 <= c[i] < kMaxN, as the problem constraints state.
constexpr int kMaxN = 100010;

int n, q;
int w[kMaxN];               // rating of each city
int c[kMaxN];               // religion of each city
std::vector<int> g[kMaxN];  // adjacency lists of the tree

// Heavy-light decomposition data; fa[root] == 0 means "no parent".
int fa[kMaxN], dep[kMaxN], sz[kMaxN], son[kMaxN], top[kMaxN], seg[kMaxN];
int seg_cnt;

// Fills fa, dep, sz and son (heavy child) without recursion, so a
// path-shaped tree with 1e5 vertices cannot overflow the call stack.
void compute_subtree_info(int root) {
    std::vector<int> order;  // BFS order: every parent precedes its children
    order.reserve(n);
    order.push_back(root);
    fa[root] = 0;
    dep[root] = 1;
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (int v : g[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            order.push_back(v);
        }
    }
    // In reverse BFS order every child is finalized before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        sz[u] += 1;  // children were already accumulated; add u itself
        const int p = fa[u];
        if (p != 0) {
            sz[p] += sz[u];
            if (sz[u] > sz[son[p]]) son[p] = u;  // sz[0] == 0 when son[p] unset
        }
    }
}

// Assigns seg[] so every heavy chain is a contiguous range, and sets top[]
// to each vertex's chain head. Iterative: each popped vertex heads a new
// chain that is walked down through son[]; light children are deferred.
void decompose(int root) {
    std::vector<int> heads(1, root);
    while (!heads.empty()) {
        const int head = heads.back();
        heads.pop_back();
        for (int u = head; u != 0; u = son[u]) {
            top[u] = head;
            seg[u] = ++seg_cnt;
            for (int v : g[u]) {
                if (v != fa[u] && v != son[u]) heads.push_back(v);
            }
        }
    }
}

// Sum and maximum over a range. {0, 0} is the identity because unused
// positions hold 0 and ratings are assumed non-negative (the original code
// also used 0 as the max identity inside the tree).
struct Agg {
    int sum;
    int mx;
};

Agg combine(const Agg& a, const Agg& b) {
    return Agg{a.sum + b.sum, std::max(a.mx, b.mx)};
}

// Segment-tree node kept in an index-based pool; child index 0 = absent.
struct Node {
    int lson;
    int rson;
    Agg val;
};

// pool[0] is an all-zero sentinel that stands in for every missing subtree,
// so pulling up from an absent child needs no special case. It is never
// written: update() replaces index 0 with a fresh node before writing.
std::vector<Node> pool(1);
int root[kMaxN];  // root[k]: pool index of religion k's tree (0 = empty)

int new_node() {
    pool.push_back(Node());  // value-initialized: children 0, sum/max 0
    return static_cast<int>(pool.size()) - 1;
}

// Sets position p of the tree rooted at rt (covering [l, r]) to v, creating
// nodes on demand, and returns the (possibly new) root index. Indices are
// written back after the recursive call rather than held as references,
// because push_back may reallocate pool and invalidate references into it.
int update(int rt, int l, int r, int p, int v) {
    if (rt == 0) rt = new_node();
    if (l == r) {
        pool[rt].val = Agg{v, v};
        return rt;
    }
    const int mid = l + (r - l) / 2;
    if (p <= mid) {
        const int child = update(pool[rt].lson, l, mid, p, v);
        pool[rt].lson = child;
    } else {
        const int child = update(pool[rt].rson, mid + 1, r, p, v);
        pool[rt].rson = child;
    }
    pool[rt].val = combine(pool[pool[rt].lson].val, pool[pool[rt].rson].val);
    return rt;
}

// Aggregates positions [ql, qr] of the tree rooted at rt covering [l, r].
Agg query(int rt, int l, int r, int ql, int qr) {
    if (rt == 0) return Agg{0, 0};
    if (ql <= l && r <= qr) return pool[rt].val;
    const int mid = l + (r - l) / 2;
    Agg res{0, 0};
    if (ql <= mid) res = combine(res, query(pool[rt].lson, l, mid, ql, qr));
    if (qr > mid) res = combine(res, query(pool[rt].rson, mid + 1, r, ql, qr));
    return res;
}

// Stores `value` at HLD position `pos` in religion `color`'s tree.
void assign(int color, int pos, int value) {
    root[color] = update(root[color], 1, n, pos, value);
}

// Aggregates ratings along the tree path u..v, counting only cities whose
// religion equals c[u] (the religion of the starting city).
Agg path_query(int u, int v) {
    const int rt = root[c[u]];
    Agg res{0, 0};
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        res = combine(res, query(rt, 1, n, seg[top[u]], seg[u]));
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) std::swap(u, v);
    return combine(res, query(rt, 1, n, seg[v], seg[u]));
}

}  // namespace

// Reads n cities (rating, religion), n-1 edges and q operations from stdin:
//   CC x y : city x changes religion to y
//   CW x y : city x changes rating to y
//   QS x y : print the rating sum on path x..y for x's religion
//   other  : (QM) print the rating maximum on path x..y for x's religion
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    if (!(std::cin >> n >> q)) return 0;
    for (int i = 1; i <= n; ++i) std::cin >> w[i] >> c[i];
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    compute_subtree_info(1);
    decompose(1);
    for (int i = 1; i <= n; ++i) assign(c[i], seg[i], w[i]);

    std::string op;
    while (q-- > 0) {
        int x = 0;
        int y = 0;
        if (!(std::cin >> op >> x >> y)) break;  // truncated input: stop cleanly
        if (op == "CC") {
            assign(c[x], seg[x], 0);  // remove x from its old religion's tree
            c[x] = y;
            assign(c[x], seg[x], w[x]);
        } else if (op == "CW") {
            assign(c[x], seg[x], y);
            w[x] = y;
        } else if (op == "QS") {
            std::cout << path_query(x, y).sum << '\n';
        } else {  // QM
            std::cout << path_query(x, y).mx << '\n';
        }
    }
    return 0;
}