题目链接
做法
\[ dep(x) + dep(y) - dep(LCA(x, y)) - dep'(LCA'(x, y))\\\\ = \frac{1}{2} (dep(x) + dep(y) - 2dep(LCA(x, y)) + dep(x) + dep(y) - 2dep'(LCA'(x, y)))\\\\ = \frac{1}{2}(dis(x, y) + dep(x) + dep(y) - 2dep'(LCA'(x, y))) \]
考虑对第一棵树边分治。设当前分治的中心边为 $ (U, V) $ ,选择 $ U $ 侧节点 $ X $ ,选择 $ V $ 侧节点 $ Y $ ,则令 $ X $ 为一类节点,贡献为 $ e1(X) = dep(X) + dis(V, X) $ ;令 $ Y $ 为二类节点,贡献为 $ e2(Y) = dep(Y) + dis(V, Y) $ 。枚举第二棵树的 $ LCA $ ,设 $ X, Y $ 在第二棵树中的 $ LCA $ 为 $ lca $ ,则对答案的贡献为 $ \frac{1}{2}(e1(X) + e2(Y) - 2dep'(lca)) $ 。由于需要正确的时间复杂度,所以需要对第二棵树建虚树进行 $ DP $ 。
不优秀的实现会导致时间复杂度为 $ O(n \log^2 n) $ ,由于每次建虚树的时间复杂度应为 $ O(k) $ ,所以需要用 $ RMQ $ 实现 $ O(1) $ 的 $ LCA $ ;另外每次建虚树需要将点排序,将排序放在分治之前,然后按照分治将数列分成两段,每次建虚树直接调用数组(或者先分治下去再归并排序然后建虚树)。更改后时间复杂度为 $ O(n \log n) $ 。
注意 $ x $ 可以与 $ y $ 相同,而边分治未考虑这一点,所以还要考虑 $ x = y $ 的情况。
// Maximises  dep(x) + dep(y) - dep(LCA(x, y)) - dep'(LCA'(x, y))  over all
// pairs (x, y), where dep/LCA refer to the first input tree and dep'/LCA' to
// the second one (both rooted at vertex 1, depths are weighted).
//
// The expression equals
//     1/2 * (dis(x, y) + dep(x) + dep(y) - 2 * dep'(LCA'(x, y))),
// so the program
//   1. turns the first tree into a bounded-degree tree by inserting auxiliary
//      nodes (ids n+1..m) connected with zero-weight edges,
//   2. runs edge divide-and-conquer on that tree; for the removed edge every
//      original vertex gets value = dep + distance to the far end of the edge,
//   3. builds a virtual tree of the involved vertices on the second tree and
//      combines the best "side 1" and "side 2" values at every LCA',
//   4. finally also considers x == y, which the divide-and-conquer never pairs.
//
// NOTE(review): getdfn/rebuild/getrt/Find/Dfs are recursive with depth up to
// the tree height; a path-shaped input needs a large stack.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

using ll = long long;
using pil = std::pair<int, ll>;

const ll INF = 100000000000000000LL;  // 1e17, used as "minus infinity" for maxima
const int N = 800010;                 // capacity for nodes (original + auxiliary) and the Euler tour

int n, m;                             // n: original vertices, m: vertices after rebuild()
ll ans = -INF;                        // running maximum of the doubled objective
ll w[N];                              // weighted depth in the first tree (original vertices)
ll fw[N];                             // weight of the edge from a vertex to its first-tree parent (0 for auxiliary nodes)
std::vector<pil> e1[N], e2[N];        // adjacency lists of the first / second input tree
std::vector<int> E[N];                // children lists of the current virtual tree (second tree)
int ar[N], len;                       // scratch list of children used by rebuild()/build()
int dep[N], dfn[N], idx, a[N];        // second tree: unweighted depth, first Euler index, Euler tour
int st[20][N], lg[N];                 // sparse table over the Euler tour and floor(log2)
ll Dep[N];                            // weighted depth in the second tree
int cnt = 1;                          // edge ids start at 2 so that i and i ^ 1 are the two directions
int to[N + N], nxt[N + N], hed[N];    // forward-star storage of the rebuilt first tree
ll val[N + N];                        // edge weights of the rebuilt first tree
bool used[N + N];                     // edges already removed by the divide-and-conquer
int comp_size, rte, mn, sz[N];        // state of the split-edge search (see getrt / pick_edge)
ll value[N], f1[N], f2[N];            // per-vertex value and virtual-tree DP maxima per side
int flag[N], sta[N], top;             // side marker (1 / 2, 0 = none) and virtual-tree stack

// Reads one decimal integer from stdin into x. A '-' immediately before the
// first digit negates the value. On end of input x is left as 0.
template <typename T>
void gi(T &x) {
    x = 0;
    int c = getchar();
    int pre = 0;
    // Keep the byte in an int so EOF is distinguishable and truncated input
    // cannot loop forever.
    while (c != EOF && (c < '0' || c > '9')) {
        pre = c;
        c = getchar();
    }
    for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
    if (pre == '-') x = -x;
}

// Adds the undirected edge x - y with weight z to the rebuilt first tree.
inline void addedge(int x, int y, ll z) {
    to[++cnt] = y, nxt[cnt] = hed[x], hed[x] = cnt, val[cnt] = z;
    to[++cnt] = x, nxt[cnt] = hed[y], hed[y] = cnt, val[cnt] = z;
}

// Orders vertices by their first position in the second tree's Euler tour.
inline bool cmp(const int &x, const int &y) {
    return dfn[x] < dfn[y];
}

// Returns whichever of x, y has the smaller second-tree depth (x on ties).
inline int shallower(int x, int y) {
    return dep[x] <= dep[y] ? x : y;
}

// Euler tour of the second tree: fills dfn, a, dep and the weighted depth Dep.
void getdfn(int u, int ff) {
    dfn[u] = ++idx;
    a[idx] = u;
    dep[u] = dep[ff] + 1;
    for (const auto &edge : e2[u]) {
        const int v = edge.first;
        if (v == ff) continue;
        Dep[v] = Dep[u] + edge.second;
        getdfn(v, u);
        a[++idx] = u;  // re-visit u after each child so range minima give the LCA
    }
}

// LCA of x and y in the second tree via a range-minimum query on the Euler tour.
inline int LCA(int x, int y) {
    int l = dfn[x], r = dfn[y];
    if (l > r) std::swap(l, r);
    const int k = lg[r - l + 1];
    return shallower(st[k][l], st[k][r - (1 << k) + 1]);
}

// Builds a balanced binary tree of fresh auxiliary nodes whose leaves are
// ar[l..r]; returns its root (0 for an empty range, the vertex itself for one).
int build(int l, int r) {
    if (l > r) return 0;
    if (l == r) return ar[l];
    const int mid = (l + r) >> 1;
    const int u = ++m;
    const int ls = build(l, mid);
    const int rs = build(mid + 1, r);
    if (ls) addedge(u, ls, fw[ls]);
    if (rs) addedge(u, rs, fw[rs]);
    return u;
}

// Re-hangs the children of u (first tree) below auxiliary nodes so the degree
// stays bounded, computes w[] for the children and recurses into them.
void rebuild(int u, int ff) {
    len = 0;
    for (const auto &edge : e1[u]) {
        if (edge.first == ff) continue;
        ar[++len] = edge.first;
        fw[edge.first] = edge.second;
    }
    const int mid = (1 + len) >> 1;
    const int ls = build(1, mid);
    const int rs = build(mid + 1, len);
    if (ls) addedge(u, ls, fw[ls]);
    if (rs) addedge(u, rs, fw[rs]);
    // ar/len are globals; they are no longer needed once the edges are added,
    // so the recursive calls below may overwrite them.
    for (const auto &edge : e1[u]) {
        if (edge.first == ff) continue;
        w[edge.first] = w[u] + edge.second;
        rebuild(edge.first, u);
    }
}

// Computes subtree sizes sz[] inside the current component and records in rte
// the edge (ed = edge leading into u) whose removal is the most balanced with
// respect to comp_size. The search root is called with ed == -1.
void getrt(int u, int ff, int ed) {
    sz[u] = 1;
    for (int i = hed[u]; i; i = nxt[i]) {
        if (to[i] == ff || used[i]) continue;
        getrt(to[i], u, i);
        sz[u] += sz[to[i]];
    }
    const int imbalance = std::abs(comp_size - 2 * sz[u]);
    if (imbalance < mn) {
        mn = imbalance;
        rte = ed;
    }
}

// Returns the most balanced unused edge of the component that contains `root`
// (searched away from `parent`) and has `total` nodes, or -1 when the
// component is a single node. Leaves sz[] valid for that component.
int pick_edge(int root, int parent, int total) {
    comp_size = total;
    mn = total + 1;
    rte = -1;  // never hand a stale edge from an earlier search to solve()
    getrt(root, parent, -1);
    return rte;
}

// Marks every original vertex on this side of the removed edge with `opt` and
// stores value = (distance d accumulated from the edge) + first-tree depth.
void Find(int u, int ff, int opt, ll d) {
    if (u <= n) {
        value[u] = d + w[u];
        flag[u] = opt;
    }
    for (int i = hed[u]; i; i = nxt[i]) {
        if (to[i] == ff || used[i]) continue;
        Find(to[i], u, opt, d + val[i]);
    }
}

// DP over the virtual tree: f1/f2 are the best values of side-1 / side-2
// vertices in the subtree; pairs from different child subtrees (or u itself)
// meet at u, which is then their second-tree LCA. Clears E[u] and flag[u].
void Dfs(int u) {
    f1[u] = f2[u] = -INF;
    if (flag[u] == 1) f1[u] = std::max(f1[u], value[u]);
    if (flag[u] == 2) f2[u] = std::max(f2[u], value[u]);
    for (const int v : E[u]) {
        Dfs(v);
        const ll best_pair = std::max(f1[u] + f2[v], f2[u] + f1[v]);
        ans = std::max(ans, best_pair - Dep[u] - Dep[u]);
        f1[u] = std::max(f1[u], f1[v]);
        f2[u] = std::max(f2[u], f2[v]);
    }
    E[u].clear();
    flag[u] = 0;
}

// Builds the virtual tree (always rooted at vertex 1) of the vertices in p,
// which must be sorted by dfn, and runs the DP on it.
void Solve(const std::vector<int> &p) {
    sta[top = 1] = 1;
    for (const int v : p) {
        if (v == 1) continue;
        const int lca = LCA(v, sta[top]);
        if (lca == sta[top]) {
            sta[++top] = v;
            continue;
        }
        for (; top > 1 && dep[sta[top - 1]] >= dep[lca]; --top)
            E[sta[top - 1]].push_back(sta[top]);
        if (lca != sta[top]) {
            E[lca].push_back(sta[top]);
            sta[top] = lca;
        }
        sta[++top] = v;
    }
    for (; top > 1; --top) E[sta[top - 1]].push_back(sta[top]);
    Dfs(1);
}

// Edge divide-and-conquer step.
//   u     - edge chosen by pick_edge for this component (-1: nothing to split)
//   total - number of nodes (original + auxiliary) in the component
//   p     - original vertices of the component, sorted by dfn
void solve(int u, int total, const std::vector<int> &p) {
    if (u == -1 || used[u]) return;
    used[u] = used[u ^ 1] = true;
    const int rt1 = to[u], rt2 = to[u ^ 1];
    // sz[] still holds the sizes from the search that selected u, and to[u] is
    // the endpoint that was below the edge in that search. Balancing the next
    // level against the real component sizes (not the number of original
    // vertices) is what keeps the recursion depth logarithmic.
    const int size1 = sz[rt1];
    const int size2 = total - size1;

    Find(rt1, rt2, 1, val[u]);  // side 1 pays for the removed edge itself
    Find(rt2, rt1, 2, 0);

    // Split before Solve(): Dfs resets flag[]. Stable split keeps dfn order.
    std::vector<int> ls, rs;
    for (const int v : p) (flag[v] == 1 ? ls : rs).push_back(v);

    Solve(p);

    solve(pick_edge(rt1, rt2, size1), size1, ls);
    solve(pick_edge(rt2, rt1, size2), size2, rs);
}

int main() {
    gi(n);
    if (n < 1) return 0;  // empty or malformed input: nothing to compute
    m = n;
    for (int i = 2; i <= n; ++i) {
        int x, y;
        ll z;
        gi(x), gi(y), gi(z);
        e1[x].emplace_back(y, z);
        e1[y].emplace_back(x, z);
    }
    for (int i = 2; i <= n; ++i) {
        int x, y;
        ll z;
        gi(x), gi(y), gi(z);
        e2[x].emplace_back(y, z);
        e2[y].emplace_back(x, z);
    }

    // Second tree: Euler tour + sparse table for O(1) LCA.
    getdfn(1, 0);
    for (int i = 1; i <= idx; ++i) st[0][i] = a[i];
    lg[0] = -1;
    for (int i = 1; i <= idx; ++i) lg[i] = lg[i >> 1] + 1;
    for (int j = 1; (1 << j) <= idx; ++j)
        for (int i = 1; i + (1 << j) - 1 <= idx; ++i)
            st[j][i] = shallower(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);

    // First tree: bounded-degree rebuild, then divide and conquer.
    rebuild(1, 0);
    std::vector<int> p;
    p.reserve(n);
    for (int i = 1; i <= n; ++i) p.push_back(i);
    std::sort(p.begin(), p.end(), cmp);  // sorted once; every split preserves the order
    solve(pick_edge(1, 0, m), m, p);

    // The DP maximised twice the objective (see the identity at the top).
    ans /= 2;
    // x == y is allowed and contributes dep(x) - dep'(x).
    for (int i = 1; i <= n; ++i) ans = std::max(ans, w[i] - Dep[i]);
    printf("%lld\n", ans);
    return 0;
}