Description
小凸和小方相约玩密室逃脱,这个密室是一棵有 \(n\) 个节点的完全二叉树,每个节点有一个灯泡。点亮所有灯泡即可逃出密室。每个灯泡有个权值 \(Ai\) ,每条边也有个权值 \(bi\) 。点亮第 1 个灯泡不需要花费,之后每点亮 1 个新的灯泡 \(V\) 的花费,等于上一个被点亮的灯泡 \(U\) 到这个点 \(V\) 的距离 \(Du,v\),乘以这个点的权值 \(Av\) 。在点灯的过程中,要保证任意时刻所有被点亮的灯泡必须连通,在点亮一个灯泡后必须先点亮其子树所有灯泡才能点亮其他灯泡。
请告诉他们,逃出密室的最少花费是多少。
Input
第1行包含1个数 \(n\) ,代表节点的个数
第2行包含 \(n\) 个数,代表每个节点的权值 \(ai\) 。( \(i=1,2,…,n\) )
第3行包含 \(n-1\) 个数,代表每条边的权值 \(bi\) ,第 \(i\) 号边是由第 \((i+1)/2\) 号点连向第 \(i+1\) 号点的边。( \(i=l,2,...N-1\) )
Output
输出包含1个数,代表最少的花费。
Sample Input
3
5 1 2
2 1
Sample Output
5
HINT
对于 \(100\%\) 的数据,\(1 \leq N \leq 2 \times 10^5\),\(1<Ai,Bi \leq 10^5\)
想法
明显的树形 \(DP\) 。
但注意题中的2个坑点 !!!!!!
第一个点亮的节点不一定是1号点!
“完全二叉树”的意思的第 \(i\) 个点的父亲是 \(i/2\) ,但不保证所有非叶子节点都有两个孩子!
先假设第一个点亮的点是 1 。
那么树形 \(dp\) 的状态为:
\(dp[i][j]\) 表示当前 \(i\) 已被点亮,开始点亮以 \(i\) 为根的子树,将其全点亮后,最后一个点跑到 \(j\) 去点亮 \(j\) 的最少花费。
转移也挺显然的,考虑左右子哪个先点亮就行了,记忆化搜索。
由于在这种情况下,对每个 \(i\) ,有用的 \(dp[i][j]\) 中的 \(j\) 为其所有祖先的另一个孩子,不超过 \(O(logn)\) 个,所以总状态数 \(O(nlogn)\) ,不会超时。
交一发,\(WA\) 了。
于是开始换根。
对于先点亮的那个点,还是要先把它的子树点亮,然后再点亮它的父节点。
这时对每个 \(i\) ,有用的 \(dp[i][j]\) 中的 \(j\) 除了所有祖先的另一个孩子,还有它所有的祖先,但还是 \(O(logn)\) 级别,总复杂度 \(O(nlogn)\)。
记忆化搜索,然后超时了 \(qwq\)
那就不记忆化了(用 \(map\) 常数过大【捂脸】)
重新设状态——
\(f[i][j]\) 表示 \(dp[i][y]\) ,其中 \(y\) 为 \(i\) 的第 \(j+1\) 个祖先。
\(g[i][j]\) 表示 \(dp[i][z]\) ,其中 \(z\) 为 \(i\) 的第 \(j+1\) 个祖先的另一个孩子。
\(O(nlogn)\) 时间能把这些值都算出来,然后再换根。
交一发, \(WA\) 了。
发现不一定每个非叶子节点都有2个孩子,于是又改了改细节。终于 \(A\) 掉了!
代码
细节极多 【害怕】
#include<cstdio> #include<iostream> #include<algorithm> using namespace std; int read(){ int x=0; char ch=getchar(); while(!isdigit(ch)) ch=getchar(); while(isdigit(ch)) x=x*10+ch-'0',ch=getchar(); return x; } const int N = 200005; typedef long long ll; int n,a[N],b[N]; ll f[N][20],g[N][20]; ll ans; void dfs(int x,ll cur){ //换根 int l=x*2,r=x*2+1; if(x!=1){ ll now; if(r<=n) now=min(1ll*a[l]*b[l]+g[l][0]+f[r][1],1ll*a[r]*b[r]+g[r][0]+f[l][1]); else if(l==n) now=1ll*a[l]*b[l]+f[l][1]; else now=f[x][0]; ans=min(ans,now+cur); } if(l>n) return; if(l==n) dfs(l,cur+1ll*b[x]*a[x/2]); else{ dfs(l,cur+1ll*a[r]*b[r]+f[r][1]); dfs(r,cur+1ll*a[l]*b[l]+f[l][1]); } } int main() { n=read(); for(int i=1;i<=n;i++) a[i]=read(); for(int i=2;i<=n;i++) b[i]=read(); for(int i=n;i>0;i--){ if(i*2>n){ int x=i/2,last=(i&1) ? i-1 : i+1; ll s=b[i]; for(int j=0;x>=0;j++,x/=2){ f[i][j]=s*a[x]; g[i][j]=1ll*(s+b[last])*a[last]; s+=b[x]; last=(x&1) ? x-1 : x+1; if(x==0) break; } continue; } else if(i*2==n){ for(int j=0,x=i/2;x>=0;j++,x/=2){ f[i][j]=1ll*a[n]*b[n]+f[n][j+1]; g[i][j]=1ll*a[n]*b[n]+g[n][j+1]; if(x==0) break; } continue; } int l=i*2,r=l+1; for(int j=0,x=i/2;x>=0;j++,x/=2){ f[i][j]=min(1ll*a[l]*b[l]+g[l][0]+f[r][j+1],1ll*a[r]*b[r]+g[r][0]+f[l][j+1]); g[i][j]=min(1ll*a[l]*b[l]+g[l][0]+g[r][j+1],1ll*a[r]*b[r]+g[r][0]+g[l][j+1]); if(x==0) break; } } ans=f[1][0]; dfs(1,0); printf("%lld\n",ans); return 0; }