太久没写博客了,过来水一发。
题目链接:洛谷
首先我们想到,考虑每个叶节点的权值为根节点权值的概率。首先要将叶节点权值离散化。
假设现在是$x$节点,令$f_i,g_i$分别表示左/右节点的权值$=i$的概率。
若$w_x$来自于左儿子,则
$$P(w_x=i)=f_i*(p_x*\sum_{j=1}^{i-1}g_j+(1-p)*\sum_{j=i+1}^mg_j)$$
右儿子也是一样的。
所以在转移的时候需要顺便维护$f,g$的前/后缀和。
但是我们发现这样直接跑是$O(n^2)$的,肯定不行,但是每个节点的所有dp值都只依赖于两个儿子,而且区间乘法是可以使用lazy_tag的,所以可以使用线段树合并。
(等会儿,好像之前并没有写过。。。)
线段树合并就是对于值域线段树,合并的时候如果两棵树都有这个节点,那么就递归下去,否则直接按照上面的式子转移。
$f,g$的前/后缀和也可以放在参数里面顺便维护了。
1 #include<bits/stdc++.h> 2 #define Rint register int 3 using namespace std; 4 typedef long long LL; 5 const int N = 300003, mod = 998244353, inv = 796898467; 6 int n, v[N], tot, p[N], fa[N], head[N], to[N], nxt[N]; 7 inline void add(int a, int b){ 8 static int cnt = 0; 9 to[++ cnt] = b; nxt[cnt] = head[a]; head[a] = cnt; 10 } 11 int root[N], ls[N << 5], rs[N << 5], seg[N << 5], tag[N << 5], cnt, ans; 12 inline void pushdown(int x){ 13 if(x && tag[x] != 1){ 14 if(ls[x]){ 15 seg[ls[x]] = (LL) seg[ls[x]] * tag[x] % mod; 16 tag[ls[x]] = (LL) tag[ls[x]] * tag[x] % mod; 17 } 18 if(rs[x]){ 19 seg[rs[x]] = (LL) seg[rs[x]] * tag[x] % mod; 20 tag[rs[x]] = (LL) tag[rs[x]] * tag[x] % mod; 21 } 22 tag[x] = 1; 23 } 24 } 25 inline void change(int &x, int L, int R, int pos){ 26 if(!x) tag[x = ++ cnt] = 1; 27 pushdown(x); 28 ++ seg[x]; 29 if(seg[x] >= mod) seg[x] = 0; 30 if(L == R) return; 31 int mid = L + R >> 1; 32 if(pos <= mid) change(ls[x], L, mid, pos); 33 else change(rs[x], mid + 1, R, pos); 34 } 35 inline int merge(int lx, int rx, int L, int R, int pl, int pr, int sl, int sr, int P){ 36 if(!lx && !rx) return 0; 37 int now = ++ cnt, mid = L + R >> 1; tag[now] = 1; 38 pushdown(lx); pushdown(rx); 39 if(!lx){ 40 int v = ((LL) P * sl + (mod + 1ll - P) * sr) % mod; 41 seg[now] = (LL) seg[rx] * v % mod; 42 tag[now] = (LL) tag[rx] * v % mod; 43 ls[now] = ls[rx]; rs[now] = rs[rx]; 44 return now; 45 } 46 if(!rx){ 47 int v = ((LL) P * pl + (mod + 1ll - P) * pr) % mod; 48 seg[now] = (LL) seg[lx] * v % mod; 49 tag[now] = (LL) tag[lx] * v % mod; 50 ls[now] = ls[lx]; rs[now] = rs[lx]; 51 return now; 52 } 53 ls[now] = merge(ls[lx], ls[rx], L, mid, pl, (pr + seg[rs[rx]]) % mod, sl, (sr + seg[rs[lx]]) % mod, P); 54 rs[now] = merge(rs[lx], rs[rx], mid + 1, R, (pl + seg[ls[rx]]) % mod, pr, (sl + seg[ls[lx]]) % mod, sr, P); 55 seg[now] = (seg[ls[now]] + seg[rs[now]]) % mod; 56 return now; 57 } 58 inline void getans(int x, int L, int R){ 59 pushdown(x); 60 if(L == R){ 61 ans = (ans + (LL) seg[x] * seg[x] % mod * v[L] % mod * L % mod) % mod; 62 return; 63 } 64 int mid = L + R >> 1; 65 getans(ls[x], L, mid); 66 getans(rs[x], mid + 1, R); 67 } 68 inline void dfs(int x){ 69 if(!head[x]){ 70 change(root[x], 1, n, p[x]); 71 return; 72 } 73 for(Rint i = head[x];i;i = nxt[i]){ 74 dfs(to[i]); 75 if(!root[x]) root[x] = root[to[i]]; 76 else root[x] = merge(root[x], root[to[i]], 1, n, 0, 0, 0, 0, p[x]); 77 } 78 } 79 int main(){ 80 scanf("%d", &n); 81 for(Rint i = 1;i <= n;i ++){ 82 scanf("%d", fa + i); 83 if(fa[i]) add(fa[i], i); 84 } 85 for(Rint i = 1;i <= n;i ++){ 86 scanf("%d", p + i); 87 if(head[i]) p[i] = (LL) p[i] * inv % mod; 88 else v[++ tot] = p[i]; 89 } 90 sort(v + 1, v + tot + 1); 91 for(Rint i = 1;i <= n;i ++) 92 if(!head[i]) p[i] = lower_bound(v + 1, v + tot + 1, p[i]) - v; 93 dfs(1); 94 getans(root[1], 1, n); 95 printf("%d", ans); 96 }
来源:https://www.cnblogs.com/AThousandMoons/p/10893829.html