[BZOJ3277] 串
Description
现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。
Solution
首先将所有串连接起来,预处理出后缀数组和高度数组。
显然直接主席树可以很容易做到 \(O(n \log^2 n)\) 。对于每一个后缀的位置,二分一个 LCP 长度,找到这个 LCP 长度对应的区间,检查这个区间是否合法来调节二分边界。
注意在这个做法里,瓶颈不在于主席树,因为主席树的功能完全可以用双指针预处理一个数组来替代。瓶颈在于,实质上使用了一个二分套二分的做法。
但我们有更好的做法。
引理:按照原始顺序,如果第 \(i\) 个后缀有 \(x\) 个前缀能被 \(k\) 个串包含,那么第 \(i+1\) 个后缀至少有 \(x-1\) 个前缀能被 \(k\) 个串包含。
那么我们先用双指针预处理 \(jmp[i]\) 代表按照后缀排序,最大的 \(j\) 使得 \([j,i]\) 这个后缀区间合法。
到第 \(i\) 个后缀的时候我们就从后缀 \(i-1\) 的答案开始向上枚举,用二分+ST表找出它左右边第一个高度比当前枚举值小的位置,判断这个区间的合法性来决定是否继续枚举,均摊时间复杂度 \(O(nlogn)\) 。
\(O(\log n)\) 解法
#include <bits/stdc++.h> using namespace std; #define int long long const int N = 400005; int n,m=N/2,sa[N],y[N],u[N],v[N],o[N],r[N],h[N],T,nstr,k; int str[N],Log2[N],bel[N],buf[N],bcnt,jmp[N],mx[N],ans[N],tow[N]; char tstr[N]; struct St { int a[N][21]; void build(int *src,int n) { for(int i=1;i<=n;i++) a[i][0]=src[i]; for(int i=1;i<=20;i++) for(int j=1;j<=n-(1<<i)+1;j++) a[j][i]=min(a[j][i-1],a[j+(1<<(i-1))][i-1]); } int query(int l,int r) { if(l>r) return 0; int j=Log2[r-l+1]; return min(a[l][j],a[r-(1<<j)+1][j]); } } st; int lbound(int cen,int val) { int l=1,r=cen; while(r>l) { int mid=(l+r)/2; if(st.query(mid+1,cen)>=val) r=mid; else l=mid+1; } return l; } int rbound(int cen,int val) { int l=cen+1,r=n+1; while(r>l) { int mid=(l+r)/2; if(st.query(cen+1,mid)>=val) l=mid+1; else r=mid; } return l-1; } signed main(){ for(int i=1;i<=200000;i++) Log2[i]=log2(i); scanf("%lld%lld",&nstr,&k); for(int i=1;i<=nstr;i++) { scanf("%s",tstr); int len=strlen(tstr); for(int j=0;j<len;j++) str[j+n+1]=tstr[j],bel[j+n+1]=i,tow[j+n+1]=n+len; n+=len+1; str[n]=127+i; } for(int i=1;i<=n;i++) u[str[i]]++; for(int i=1;i<=m;i++) u[i]+=u[i-1]; for(int i=n;i>=1;i--) sa[u[str[i]]--]=i; r[sa[1]]=1; for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]); for(int l=1;r[sa[n]]<n;l<<=1) { memset(u,0,sizeof u); memset(v,0,sizeof v); memcpy(o,r,sizeof r); for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++; for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1]; for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i; for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i]; r[sa[1]]=1; for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l])); } { int i,j,k=0; for(int i=1;i<=n;h[r[i++]]=k) for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++); } st.build(h,n); bcnt=1; buf[bel[sa[n]]]++; for(int i=n,j=n;i>=1;--i) { while(bcnt<k && j>0) { --j; if(buf[bel[sa[j]]]==0) ++bcnt; buf[bel[sa[j]]]++; } jmp[i]=j; if(buf[bel[sa[i]]]==1) --bcnt; buf[bel[sa[i]]]--; } // for(int i=1;i<=n;i++) cout<<jmp[i]<<" "; cout<<endl; for(int i=1;i<=n;i++) { for(int j=max(1ll,mx[i-1]);j<=n;j++) { int lb=lbound(r[i],j), rb=rbound(r[i],j); //cout<<i<<" "<<r[i]<<" "<<j<<" "<<lb<<" "<<rb<<endl; if(jmp[rb]<lb || j>tow[i]-i+1) { mx[i]=j-1; break; } } } //for(int i=1;i<=n;i++) cout<<mx[i]<<" "; //cout<<endl; for(int i=1;i<=n;i++) { ans[bel[i]]+=mx[i]; } for(int i=1;i<=nstr;i++) printf("%lld ",ans[i]); }
\(O(\log^2 n)\) 解法 (TLE)
#include <bits/stdc++.h> using namespace std; #define int long long const int N = 400005; int n,m=N/2,sa[N],y[N],u[N],v[N],o[N],r[N],h[N],jmp[N],buf[N],bel[N],bcnt; int nstr,k; int str[N],ans[N],tow[N],LOG2[N]; char tstr[N]; struct St { int a[N][21]; void build(int *src,int n) { for(int i=1;i<=n;i++) a[i][0]=src[i]; for(int i=1;i<=20;i++) for(int j=1;j<=n-(1<<i)+1;j++) a[j][i]=min(a[j][i-1],a[j+(1<<(i-1))][i-1]); } int query(int l,int r) { if(l>r) return 0; int j=LOG2[r-l+1]; return min(a[l][j],a[r-(1<<j)+1][j]); } } st; int lbound(int cen,int val) { int l=1,r=cen; while(r-l) { int mid=(l+r)/2; if(st.query(mid+1,cen)>=val) r=mid; else l=mid+1; } return l; } int rbound(int cen,int val) { int l=cen+1,r=n+1; while(r-l) { int mid=(l+r)/2; if(st.query(cen+1,mid)>=val) l=mid+1; else r=mid; } return l-1; } signed main(){ for(int i=1;i<=200000;i++) LOG2[i]=log2(i); scanf("%d%d",&nstr,&k); for(int i=1;i<=nstr;i++) { scanf("%s",tstr); int tstrlength = strlen(tstr); for(int j=0;j<tstrlength;j++) str[n+j+1]=tstr[j],bel[n+j+1]=i,tow[n+j+1]=n+tstrlength; n+=tstrlength+1; str[n]=127+i; } for(int i=1;i<=n;i++) u[str[i]]++; for(int i=1;i<=m;i++) u[i]+=u[i-1]; for(int i=n;i>=1;i--) sa[u[str[i]]--]=i; r[sa[1]]=1; for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]); for(int l=1;r[sa[n]]<n;l<<=1) { memset(u,0,sizeof u); memset(v,0,sizeof v); memcpy(o,r,sizeof r); for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++; for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1]; for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i; for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i]; r[sa[1]]=1; for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l])); } { int i,j,k=0; for(int i=1;i<=n;h[r[i++]]=k) for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++); } st.build(h,n); buf[bel[sa[n]]]=1; bcnt++; for(int i=n,j=n;i>=1;--i) { while(bcnt<k && j>0) { --j; if(buf[bel[sa[j]]]==0) bcnt++; buf[bel[sa[j]]]++; } jmp[i]=j; buf[bel[sa[i]]]--; if(buf[bel[sa[i]]]==0) bcnt--; } for(int i=1;i<=n;i++) { int l=1,r=tow[sa[i]]-sa[i]+2; while(r>l) { int mid=(l+r)/2; int lb=lbound(i,mid),rb=rbound(i,mid); if(jmp[rb]>=lb) l=mid+1; else r=mid; } //cout<<i<<" "<<l-1<<endl; ans[bel[sa[i]]]+=l-1; } for(int i=1;i<=nstr;i++) printf("%lld ",ans[i]); }