luoguP5666 树的重心 树状数组

£可爱£侵袭症+ 提交于 2020-01-26 20:12:23

这道题在家里仔细想想还是挺好想的...

考场的时候还是要镇定,给每道题要安排足够的思考时间. 

code: 

#include <cstdio> 
#include <cstring> 
#include <vector>
#include <string>   
#include <algorithm>  
#define N 500004
#define ll long long 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;  
int n;   
namespace BIT
{    
    int C[N]; 
    int lowbit(int t) { return t&(-t); }     
    void clr() { for(int i=0;i<N;++i) C[i]=0; }
    void update(int x,int v) 
    { 
        if(x<=0) return; 
        while(x<N) C[x]+=v,x+=lowbit(x);   
    }   
    int ask(int x) 
    {     
        int re=0; 
        for(int i=x;i>0;i-=lowbit(i))  re+=C[i]; 
        return re; 
    }
    int query(int l,int r) 
    {
        l=max(1,l),r=min(r,n);      
        return l>r?0:ask(r)-ask(l-1);   
    } 
}; 
struct node
{     
    int id,size;     
    node(int id=0,int size=0):id(id),size(size){}  
};   
ll ans;   
vector<node>G[N];    
bool cmp(node a,node b) { return a.size>b.size; }    
int edges,cnt_max,det,cn;       
int hd[N],to[N<<1],nex[N<<1],size[N<<1],st[N],ed[N],cn_nd[N],ori[N];       
void addedge(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;   
}
void clr() 
{
    int i,j; 
    BIT::clr();      
    for(i=1;i<=n;++i) G[i].clear();   
    for(i=0;i<=edges;++i) to[i]=nex[i]=0; 
    for(i=0;i<N;++i) hd[i]=size[i]=ori[i]=cn_nd[i]=0;        
    n=edges=ans=0;    
}     
void dfs(int u,int ff) 
{
    size[u]=1,st[u]=++cn;  
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i];   
        if(v==ff) continue;     
        dfs(v,u);  
        size[u]+=size[v];   
        G[u].push_back(node(v,size[v]));    
    }     
    ed[u]=cn;   
    if(ff)  G[u].push_back(node(ff,n-size[u]));    
    sort(G[u].begin(),G[u].end(),cmp);          
}
void calc_max(int u,int ff) 
{          
    size[u]=1;          
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        if(v==ff) continue;    
        calc_max(v,u);    
        size[u]+=size[v];  
    } 
    if(size[u]<=det) ++cnt_max;           
}
void calc_sec(int u,int ff) 
{
    size[u]=1;                      
    int flag=G[u][0].id==ff?-1:1;  
    if(G[u].size()==1) G[u].push_back(node(0,0));         
    // if(u==3) printf("qaq  %d %d\n",2*G[u][0].size-n,n-2*G[u][1].size);                                        
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        if(v==ff) continue;            
        int de=0;                                        
        de-=BIT::query(2*G[u][0].size-n,n-2*G[u][1].size);           
        // 计算子树    
        calc_sec(v,u),size[u]+=size[v];    
        // 计算完毕    
        de+=BIT::query(2*G[u][0].size-n,n-2*G[u][1].size);     

        if(flag==-1) 
        {    
            cn_nd[u]-=de;    
        }  
        else
        {
            if(v==G[u][0].id) cn_nd[u]=de;           
        }
    }              
    BIT::update(size[u],1);   
}
void dfs_sec(int u,int ff) 
{       
    BIT::update(size[u],-1);     
    BIT::update(n-size[u],1);   
    if(G[u][0].id==ff)  
    {
        // printf("qaq\n"); 
        cn_nd[u]+=BIT::query(2*G[u][0].size-n,n-2*G[u][1].size);               
        // printf("qaq\n"); 
    }
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i];   
        if(v==ff) continue;          
        dfs_sec(v,u);      
    }
    BIT::update(size[u],1);   
    BIT::update(n-size[u],-1); 
}
void work() 
{ 
    int i,j; 
    clr();   
    scanf("%d",&n); 
    for(i=1;i<n;++i) 
    {
        int x,y; 
        scanf("%d%d",&x,&y);   
        addedge(x,y),addedge(y,x);  
    }
    dfs(1,0);       
    // 计算由最大儿子的贡献   
    for(i=1;i<=n;++i) 
    {     
        if(G[i][0].size*2<=n) 
        {         
            cnt_max=0;  
            det=n-G[i][0].size*2;  
            for(j=1;j<G[i].size();++j)                        
                calc_max(G[i][j].id,i);    
            ans+=(ll)i*cnt_max;    
            ori[i]+=cnt_max;  
        }
    }
    // 计算次大儿子的贡献   
    calc_sec(1,0);        
    dfs_sec(1,0);     
    ans=0;  
    for(i=1;i<=n;++i) ans+=(ll)i*(ori[i]+cn_nd[i]);  
    printf("%lld\n",ans); 
}
int main() 
{ 
    // setIO("input");     
    int i,j,T; 
    scanf("%d",&T); 
    while(T--) work();   
    return 0;
} 

  

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!