这是一份甚至不能稳定通过的代码
我们要求的是树上有多少条路径被给定路径完全覆盖,我们显然可以把这个问题转化为对于每一个点求出经过这个点的所有给定路径并的大小,这样我们没有区分\(u<v\),所以最后我们还需要\(/2\)
考虑一下如何求树链并的大小,一个非常经典的做法是虚树
我们只需要把所有树链的两个端点都拿出来,构建一棵虚树,虚树上的节点个数就是树链并的大小
当然我们并不能对于每一个点暴力找一下经过这个点的路径之后算一下虚树的大小,我们考虑刚才那个做法的本质是什么
有经验的话就能看出来上面那个做法本质上是求虚树上所有\(dfs\)序相邻的点的距离
于是我们可以使用\(dsu\ on\ tree\)套上一个\(set\)来动态维护虚树的大小
但是由于一些我还没调出来的\(bug\)这份代码以及随机选根,这份代码能获得90到100的好成绩
代码
#include<set> #include<ctime> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define re register #define LL long long #define mp std::make_pair #define max(a,b) ((a)>(b)?(a):(b)) #define min(a,b) ((a)<(b)?(a):(b)) #define set_it std::set<pii>::iterator inline int read() { char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar(); while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x; } const int maxn=1e5+5; typedef std::pair<int,int> pii; std::set<pii> s; std::vector<int> v[maxn],d[maxn],t[maxn]; struct E{int v,nxt;}e[maxn<<1];LL tot=0; int dfn[maxn],son[maxn],sum[maxn],head[maxn],deep[maxn],top[maxn],fa[maxn]; int n,m,num,__,rt,Son,T,ans=0,tax[maxn],xx[maxn],yy[maxn],ma[maxn]; inline void add(int x,int y) { e[++num].v=y;e[num].nxt=head[x];head[x]=num; } void dfs1(int x) { sum[x]=1; for(re int i=head[x];i;i=e[i].nxt) { if(deep[e[i].v]) continue; deep[e[i].v]=deep[x]+1,fa[e[i].v]=x; dfs1(e[i].v);sum[x]+=sum[e[i].v]; if(sum[e[i].v]>sum[son[x]]) son[x]=e[i].v; } } void dfs2(int x,int topf) { top[x]=topf,dfn[x]=++__; if(son[x]) dfs2(son[x],topf); for(re int i=head[x];i;i=e[i].nxt) if(!top[e[i].v]) dfs2(e[i].v,e[i].v); } inline int LCA(int x,int y) { while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]]) std::swap(x,y); x=fa[top[x]]; } if(deep[x]<deep[y]) return x;return y; } inline int dis(int x,int y) {return deep[x]+deep[y]-2*deep[LCA(x,y)];} inline void ins(int x,int v) { set_it it; it=s.find(mp(dfn[x],x)); if(it==s.begin()) { ++it; if(it==s.end()) return; ans+=v*dis(x,(*it).second); return; } ++it; if(it==s.end()) { --it;--it; ans+=v*dis(x,(*it).second); return; } int y=(*it).second; --it,--it; int z=(*it).second; ans-=v*dis(y,z); ans+=v*(dis(x,y)+dis(x,z)); } inline void Add(int i,int v) { if(v==1) { if(!ma[xx[i]]) s.insert(mp(dfn[xx[i]],xx[i])),ins(xx[i],1); if(!ma[yy[i]]) s.insert(mp(dfn[yy[i]],yy[i])),ins(yy[i],1); ma[xx[i]]++,ma[yy[i]]++; return; } if(ma[xx[i]]==1) ins(xx[i],-1),s.erase(mp(dfn[xx[i]],xx[i])); if(ma[yy[i]]==1) ins(yy[i],-1),s.erase(mp(dfn[yy[i]],yy[i])); ma[xx[i]]--,ma[yy[i]]--; } void calc(int x,int now) { for(re int i=0;i<v[x].size();i++) { if(deep[t[x][i]]>deep[now]||tax[v[x][i]]==T) continue; Add(v[x][i],1);tax[v[x][i]]=T; } for(re int i=head[x];i;i=e[i].nxt) { if(deep[e[i].v]<deep[x]||Son==e[i].v) continue; calc(e[i].v,now); } } inline int Dis() { set_it it=s.begin(); int k=(*it).second; it=s.end();--it; return dis(k,(*it).second); } void dfs(int x,int k) { for(re int i=head[x];i;i=e[i].nxt) if(son[x]!=e[i].v&&deep[x]<deep[e[i].v]) dfs(e[i].v,0); if(son[x]) dfs(son[x],1); Son=son[x],calc(x,x);Son=0; for(re int i=0;i<d[x].size();i++) if(tax[d[x][i]]==T) Add(d[x][i],-1),tax[d[x][i]]=0; if(s.size()>=2) tot+=(Dis()+ans+2)/2,--tot; if(!k) { set_it it; for(it=s.begin();it!=s.end();++it) ma[(*it).second]=0; s.clear();++T;ans=0; } } int main() { srand(time(0)); n=read(),m=read(); rt=rand()%n+1; for(re int x,y,i=1;i<n;i++) { x=read(),y=read(),add(x,y),add(y,x); } T=1;deep[rt]=1,dfs1(rt),dfs2(rt,rt); for(re int l,i=1;i<=m;i++) { xx[i]=read(),yy[i]=read();l=LCA(xx[i],yy[i]); v[xx[i]].push_back(i);v[yy[i]].push_back(i); t[xx[i]].push_back(l),t[yy[i]].push_back(l); d[fa[l]].push_back(i); } dfs(rt,1);printf("%lld\n",tot/2ll); return 0; }