题目:http://codeforces.com/contest/293/problem/E
仍旧是点分治。用容斥,w的限制用排序+两个指针解决, l 的限制就用树状数组。有0的话就都+1,相对大小不变。
切勿每次memset!!!会T得不行。add(sta[ l ].len)即可,但要判一下(l==r)以防不测。(真的有那种数据!)
最后注意树状数组的范围是L(即L+1),不是n。不然可以尝试:
2 10 12
1 5
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
int n,L,W,hd[N],xnt,to[N<<1],nxt[N<<1],w[N<<1];
int mn,rt,f[N],l,r,siz[N],lm;
ll ans;
bool vis[N];
struct Sta{
int w,len;Sta(int w=0,int l=0):w(w),len(l) {}
bool operator< (const Sta &b)const
{return w==b.w?len<b.len:w<b.w;}
}sta[N];
void add(int x,int y,int z)
{
to[++xnt]=y;nxt[xnt]=hd[x];w[xnt]=z;hd[x]=xnt;
to[++xnt]=x;nxt[xnt]=hd[y];w[xnt]=z;hd[y]=xnt;
}
void getrt(int cr,int fa,int s)
{
siz[cr]=1;int mx=0;
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
{
getrt(v,cr,s);siz[cr]+=siz[v];mx=max(mx,siz[v]);
}
mx=max(mx,s-siz[cr]);
if(mx<mn)mn=mx,rt=cr;
}
void add(int x,int k){x++;for(;x<=lm;x+=(x&-x))f[x]+=k;}//x<=L+1!!!
int query(int x){x++;int ret=0;for(;x;x-=(x&-x))ret+=f[x];return ret;}
void dfs(int cr,int fa,int pw,int pl)
{
if(pw>W||pl>L)return;
sta[++r]=Sta(pw,pl);add(sta[r].len,1);
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
dfs(v,cr,pw+w[i],pl+1);
}
ll calc(int cr,int pw,int pl)
{
l=1;r=0;dfs(cr,0,pw,pl);
sort(sta+l,sta+r+1);///////
// printf("l=%d r=%d\n",l,r);
// for(int i=l;i<=r;i++)printf("sta[%d].w=%d .len=%d\n",i,sta[i].w,sta[i].len);
ll ret=0;
while(l<r)
if(sta[l].w+sta[r].w>W)add(sta[r--].len,-1);
else ret+=query(L-sta[l].len)-(sta[l].len<=L-sta[l].len),
// printf("l=%d r=%d query(%d)=%d\n"
// ,l,r,L-sta[l].len,query(L-sta[l].len)),
add(sta[l++].len,-1);
if(l==r)add(sta[l].len,-1);
// memset(f,0,sizeof f);///TLE!!!!!
return ret;
}
void solve(int cr,int s)
{
// printf("cr=%d\n",cr);
vis[cr]=1;ans+=calc(cr,0,0);
// printf("cr=%d ans=%lld\n",cr,ans);
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]])
{
ans-=calc(v,w[i],1);
// printf("(cr=%d ans=%lld)(v=%d)\n",cr,ans,v);
int ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);
mn=N;getrt(v,0,ts);solve(rt,ts);
}
}
int main()
{
scanf("%d%d%d",&n,&L,&W);lm=L+1;
for(int i=2,y,z;i<=n;i++)
{
scanf("%d%d",&y,&z);add(i,y,z);
}
mn=N;getrt(1,0,n);solve(rt,n);
printf("%I64d\n",ans);
return 0;
}
来源:oschina
链接:https://my.oschina.net/u/4358837/blog/3864557