[笔记]点分治

一笑奈何 提交于 2019-11-28 22:47:25

基本思路:点分治,是一种针对可带权树上简单路径统计问题的算法。对于一个节点,只解决经过这棵子树的根节点的路径,对于子节点问题下推子树。

//当初的主要问题是vis[]在干什么qwq,终于知道了 
#include<iostream>
#include<cstdio>
#include<algorithm>
#define R register int
using namespace std;
#define ull unsigned long long
#define ll long long
#define pause (for(R i=1;i<=10000000000;++i))
#define In freopen("NOIPAK++.in","r",stdin)
#define Out freopen("out.out","w",stdout)
namespace Fread {
static char B[1<<15],*S=B,*D=B;
#ifndef JACK
#define getchar() (S==D&&(D=(S=B)+fread(B,1,1<<15,stdin),S==D)?EOF:*S++)
#endif
inline int g() {
  R ret=0,fix=1; register char ch; while(!isdigit(ch=getchar())) fix=ch=='-'?-1:fix;
  if(ch==EOF) return EOF; do ret=ret*10+(ch^48); while(isdigit(ch=getchar())); return ret*fix;
} inline bool isempty(const char& ch) {return (ch<=36||ch>=127);}
inline void gs(char* s) {
  register char ch; while(isempty(ch=getchar()));
  do *s++=ch; while(!isempty(ch=getchar()));
}
} using Fread::g; using Fread::gs;
namespace Jack {
const int N=10010,Inf=0x3f3f3f3f;
int n,m,cnt,sum,rt,tot;
int vr[N<<1],nxt[N<<1],w[N<<1],fir[N],sz[N],mx[N],d[N],a[N],b[N],q[110];
bool ans[110],vis[N];
inline bool cmp(int a,int b) {return d[a]<d[b];}//按路径长度排序 
inline void add(int u,int v,int ww) {vr[++cnt]=v,nxt[cnt]=fir[u],w[cnt]=ww,fir[u]=cnt;}
inline void getrt(int u,int fa) { sz[u]=1,mx[u]=0;//找根节点 
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(v==fa||vis[v]) continue;//若是father(vis避免扫回父亲),就continue 
    getrt(v,u); sz[u]+=sz[v];//合并子树的size 
    mx[u]=max(mx[u],sz[v]);//取max 
  } mx[u]=max(mx[u],sum-sz[u]);
  if(!rt||mx[u]<mx[rt]) rt=u;//选根节点 
}
inline void dfs(int u,int fa,int top) {
  a[++tot]=u;//将子树中的点添加到队列中 
  b[u]=top;//记录所属次级子树(即本次分治节点的子树)的根节点 
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(v==fa||vis[v]) continue;
    d[v]=d[u]+w[i]; dfs(v,u,top);
  }
}
inline void calc(int u) {//计算经过u的路径条数 
  tot=0; a[++tot]=u;//初始化队列 
  d[u]=0; b[u]=u;
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]) continue;//不访问已经分治过的father 
    d[v]=w[i]; dfs(v,u,v);
  } sort(a+1,a+tot+1,cmp);//按到当前根的距离排序 
  for(R i=1;i<=m;++i) {
    if(ans[i]) continue;
    R l=1,r=tot; //双指针扫一遍 
    while(l<r) {
      if(d[a[l]]+d[a[r]]>q[i]) --r;//过大则左移右指针 
      else if(d[a[l]]+d[a[r]]<q[i]) ++l;//过小右移左指针 
      else if(b[a[l]]==b[a[r]]) //同属于一棵子树 
        if(d[a[r]]==d[a[r-1]]) --r;//右边权值相等左移右指针 
        else ++l; 
      else {ans[i]=true; break;}  
    }
  }
}
inline void solve(int u) { vis[u]=true; //已经过,打标记 
  calc(u);
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]) continue;//vis[u],表示已经分治过的父亲。 
    sum=sz[v]; rt=0;
    getrt(v,0);
    solve(rt);//传入重心,solve 
  }
}
void main() {
  n=g(),m=g(); for(R i=1,u,v,w;i<n;++i) 
    u=g(),v=g(),w=g(),add(u,v,w),add(v,u,w);
  for(R i=1;i<=m;++i) q[i]=g();
  mx[rt]=sum=n;
  getrt(1,0);
  solve(rt);
  for(R i=1;i<=m;++i) 
    ans[i]?printf("AYE\n"):printf("NAY\n");
}
}

signed main() {
  Jack::main();
}

后来又重新学了学点分治,改了改写法。
之前的写法其实求子树的\(size\)那里是不对的(就是会让下一次传进去的\(sum\)是错误的)

#include<bits/stdc++.h> 
#define R register int
using namespace std;
namespace Luitaryi {
template<class I> inline I g(I& x) { x=0; register I f=1;
  register char ch; while(!isdigit(ch=getchar())) f=ch=='-'?-1:f;
  do x=x*10+(ch^48); while(isdigit(ch=getchar())); return x*=f;
} const int N=10010,M=100,Inf=1e+9;
int n,m,q[M],cnt,c,sum,rt,tot; bool ans[M],vis[N],mem[10000010];
int vr[N<<1],nxt[N<<1],fir[N],w[N<<1],d[N],mx[N],sz[N],buf[N],dis[N];
inline void add(int u,int v,int ww) {
  vr[++c]=v,nxt[c]=fir[u],w[c]=ww,fir[u]=c;
  vr[++c]=u,nxt[c]=fir[v],w[c]=ww,fir[v]=c;
}
inline void getsz(int u,int fa) {
  sz[u]=1,mx[u]=0; for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]||v==fa) continue;
    getsz(v,u); sz[u]+=sz[v];
    mx[u]=max(mx[u],sz[v]);
  } mx[u]=max(mx[u],sum-sz[u]);
  if(mx[u]<mx[rt]) rt=u;
}
inline void getdis(int u,int fa) { dis[++cnt]=d[u]; 
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]||v==fa) continue;
    d[v]=d[u]+w[i]; getdis(v,u);
  }
}
inline void solve(int u,int fa) { tot=0;
  mem[0]=true,buf[++tot]=0,vis[u]=true;
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]||v==fa) continue;
    d[v]=w[i]; getdis(v,u);
    for(R i=1;i<=cnt;++i) 
      for(R j=1;j<=m;++j) if(q[j]>=dis[i])
        ans[j]|=mem[q[j]-dis[i]];
    for(R k=1;k<=cnt;++k) buf[++tot]=dis[k],mem[dis[k]]=true;
    cnt=0;
  } while(tot) mem[buf[tot]]=false,--tot;
  for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
    if(vis[v]||v==fa) continue;
    sum=sz[v]; rt=0,mx[rt]=Inf;
    getsz(v,u),getsz(rt,-1);
    solve(rt,u);
  }
}
inline void main() {
  g(n),g(m); for(R i=1,u,v,w;i<n;++i) g(u),g(v),g(w),add(u,v,w);
  for(R i=1;i<=m;++i) g(q[i]); sum=n,rt=0,mx[rt]=Inf;
  getsz(1,-1),getsz(rt,-1); solve(rt,-1);
  for(R i=1;i<=m;++i) ans[i]?puts("AYE"):puts("NAY");
} 
} signed main() {Luitaryi::main(); return 0;}

2019.08.29
71

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!