树形DP有一个独特的优化,就是通过递归,枚举目前有效的元素个数,求dp[ i ][ j ] (表示 选取以i为根的子树中有选取j个元素的最大取值)
(搭配 siz 数组表示当前该节点的总共子孙数)
1.hdu1561(树形依赖背包裸题)
注意 siz 数组的运用,以及 u 点选择的节点数时要逆向枚举,就像01背包
复杂度看似O(n^3),实际是 O( n^2 ) 左右。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <vector> using namespace std; const int maxn = 250; vector<int> g[maxn]; int dp[maxn][maxn]; int val[maxn]; int siz[maxn]; int n,m; //dp[i][j] 表示 选取以i为根的子树中有选取j个元素的最大取值 void dfs(int u){ siz[u]=1; dp[u][1] = val[u]; for(int i=0; i<g[u].size(); i++){ int v = g[u][i]; dfs(v); //这里的siz[u]不包括siz[v] ,并且是把效率很低的2^n举法用01背包来做 for(int i=siz[u]; i>=1; i--){ //这里就像01背包里,避免由这个点的情况递推这个点的更佳情况 for(int j=1; j<=siz[v]&&i+j<=m; j++){ //就比如要避免刚刚还说是从v取3个点推出的最优 dp[u][i+j] = max(dp[u][i+j], dp[u][i]+dp[v][j]); //后面又从前面的dp值而只从j中取1个点得出错误的更优解 } } siz[u] += siz[v]; } } int main(){ while(scanf("%d%d",&n,&m)!=EOF){ if(n==0&&m==0) break; for(int i=0; i<=n; i++){ for(int j=0; j<=n; j++){ dp[i][j] = 0; } g[i].clear(); } int t; for(int i=1; i<=n; i++){ scanf("%d%d",&t,val+i); g[t].push_back(i); } m++; dfs(0); printf("%d\n", dp[0][m]); } }
2.codeforces 815C (树形dp)
这个选取树上物品可以不需要有父子关系的,但使用优惠券和父子关系有关,所以可以把 dp数组多增加一维,表示是否能够使用优惠券。
只需要设置默认值为 inf ,再这样初始化:
dp[u][0][0]=0;
dp[u][1][0]=c[u];
dp[u][1][1]=c[u]-d[u];
就可以在枚举时考虑到 0 这个元素。
#include <cstdio> #include <algorithm> #include <cstring> #include <vector> using namespace std; const int maxn = 5005; int dp[maxn][maxn][2]; //dp[i][j]表示以i为根的子树中取j个元素的最大值 vector<int> g[maxn]; //再来一维表示是否购买根节点i这个元素,也就是用不用优惠券 int val[maxn],d[maxn]; int siz[maxn]; void dfs(int u){ siz[u] = 1; dp[u][1][1] = val[u]-d[u]; dp[u][1][0] = val[u]; dp[u][0][0] = 0; for(int i=0; i<g[u].size(); i++){ int v=g[u][i]; dfs(v); //这里的siz[u]不包括siz[v] for(int i=siz[u]; i>=0; i--){ //这里的0是为了处理可以不取 for(int j=0; j<=siz[v]; j++){ dp[u][i+j][0] = min(dp[u][i+j][0], dp[u][i][0]+dp[v][j][0]); dp[u][i+j][1] = min(dp[u][i+j][1], dp[u][i][1]+min(dp[v][j][0],dp[v][j][1])); } } siz[u] += siz[v]; } } int main(){ int n,b; scanf("%d%d",&n,&b); scanf("%d%d",val+1,d+1); for(int i=2; i<=n; i++){ int t; scanf("%d%d%d",val+i,d+i,&t); g[t].push_back(i); } memset(dp,0x3f,sizeof(dp)); dfs(1); int ans=n; while(dp[1][ans][1]>b&&dp[1][ans][0]>b){ ans--; //printf("%d %d\n",dp[1][ans][1],dp[1][ans][0] ); } printf("%d\n", ans); }