问题描述
小 C 有一个集合 \(S\),里面的元素都是小于 \(M\) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 \(N\) 的数列,数列中的每个数都属于集合 \(S\)。
小 C 用这个生成器生成了许多这样的数列。但是小 C 有一个问题需要你的帮助:给定整数 \(x\),求所有可以生成出的,且满足数列中所有数的乘积 \(\bmod M\) 的值等于 \(x\) 的不同的数列的有多少个。小 C 认为,两个数列 \(\{A_i\}\) 和 \(\{B_i\}\) 不同,当且仅当至少存在一个整数 \(i\),满足 \(A_i \neq B_i\)。另外,小 C 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案 \(\mod 1004535809\) 的值就可以了。
题解
一道好题!
第一类部分分 - \(O(nm^2)\) \(\mathrm{DP}\)
设 \(dp[i][j]\) 代表选取了 \(i\) 个之后,膜 \(M\) 为 \(j\) 的方案数。
转移方程: \(dp[i][j]=\sum\limits_{k=1}^{|S|}{\sum\limits_{p=0}^{M-1}{dp[i-1][S_k \times p \bmod M]}}\)
边界条件: \(dp[1][j]=[j \in S]\)
时间复杂度 \(O(nm^2)\) ,期望得分 \(10\) 分。
第二类部分分 - \(O(m^2 \log_2 n)\) \(\mathrm{DP}\)
设 \(f[i][j]\) 代表选取 \(2^i\) 个之后,膜 \(M\) 为 \(j\) 的方案数。
预处理 \(f\) 之后,可以类似于快速幂地求解。
正解
对于第二类部分分,发现在 \(\sum\) 下是 \(A \times B = C\) ,不是很好处理。
考虑对其取对数,令 \(a=log_g A\),则为 \(a+b=c\) ,卷积形式,可以用 NTT 优化。
至于映射关系,在膜 \(M\) 意义下,根据原根的性质,取对数可以转化为 \(g^{a} \bmod M\) , \(a\) 的取值范围为 \([0,M-1)\) 。
然后 NTT 的一些实现细节,还需要再想想。
\(\mathrm{Code}\)
10pts
#include<bits/stdc++.h> using namespace std; template <typename Tp> void read(Tp &x){ x=0;char ch=1;int fh=1; while(ch!='-'&&(ch>'9'||ch<'0')) ch=getchar(); if(ch=='-') fh=-1,ch=getchar(); while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); x*=fh; } const int maxS=8000+7; const int mod=1004535809; int n,m,x,S; int a[maxS]; void Init(void){ read(n);read(m);read(x);read(S); for(int i=1;i<=S;i++) read(a[i]); } int dp[1007][maxS]; void Work(void){ for(int i=1;i<=S;i++) dp[1][a[i]]++; for(int i=2;i<=n;i++){ for(int j=1;j<=S;j++){ for(int k=0;k<m;k++){ int p=(long long)k*a[j]%m; dp[i][p]=(dp[i][p]+dp[i-1][k])%mod; } } } printf("%d\n",dp[n][x]); } int main(){ Init();Work(); return 0; }
60pts
#include<bits/stdc++.h> using namespace std; template <typename Tp> void read(Tp &x){ x=0;char ch=1;int fh=1; while(ch!='-'&&(ch>'9'||ch<'0')) ch=getchar(); if(ch=='-') fh=-1,ch=getchar(); while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); x*=fh; } const int maxS=8000+7; const int mod=1004535809; int n,m,x,S; int a[maxS]; int dp[maxS],opt[maxS]; //bool exist[maxS]; void Init(void){ read(n);read(m);read(x);read(S); for(int i=1;i<=S;i++) read(a[i]),dp[a[i]]=1; } int tmp[maxS]; int power(int x,int p,int mod){ int res(1); while(p){ if(p&1) res=(long long)res*x%mod;p>>=1; x=(long long)x*x%mod; } return res; } void mul(int *f,int *g,int *res){ for(int i=0;i<m;i++){ // if(!exist[i]) continue; for(int j=0;j<m;j++){ int p=(long long)i*j%m; tmp[p]=(tmp[p]+(long long)f[i]*g[j]%mod)%mod; } } for(int i=0;i<m;i++) res[i]=tmp[i],tmp[i]=0; } void fpow(int p){ while(p){ if(p&1) mul(dp,opt,opt); mul(dp,dp,dp);p>>=1; } } void Work(void){ opt[1]=1;fpow(n); printf("%d\n",opt[x]); } int main(){ Init(); Work(); return 0; }
100pts
#include<bits/stdc++.h> using namespace std; template <typename Tp> void read(Tp &x){ x=0;char ch=1;int fh=1; while(ch!='-'&&(ch>'9'||ch<'0')) ch=getchar(); if(ch=='-') fh=-1,ch=getchar(); while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); x*=fh; } const int maxS=16000+7; const int mod=1004535809; int n,m,x,S; int a[maxS]; int dp[maxS<<1],opt[maxS<<1]; int power(int x,int p,int mod){ int res(1); while(p){ if(p&1) res=(long long)res*x%mod;p>>=1; x=(long long)x*x%mod; } return res; } bool check(int g,int p){ int q=p-1; for(int i=2;(long long)i*i<=q;i++){ if(q%i==0&&(power(g,i,p)==1||power(g,q/i,p)==1)) return false; } return true; } int getG(int k){ for(int i=2;i<=100;i++) if(check(i,k)) return i; return -1; } int pos[maxS]; int Gm,Gd=3; int invG=power(Gd,mod-2,mod); int tr[maxS<<1]; int lim; void NTT(int *f,int type){ for(int i=0;i<lim;i++) if(i<tr[i]) swap(f[i],f[tr[i]]); for(int dlen=2;dlen<=lim;dlen<<=1){ int len=dlen>>1,w; if(type==1) w=power(Gd,(mod-1)/dlen,mod); else w=power(invG,(mod-1)/dlen,mod); for(int k=0;k<lim;k+=dlen){ int buf=1; for(int l=0;l<len;l++){ int LF=f[k+l],RF=(long long)buf*f[len+k+l]%mod; f[k+l]=(LF+RF)%mod,f[k+l+len]=(LF-RF+mod)%mod; buf=(long long)buf*w%mod; } } } if(type==-1){ int inv=power(lim,mod-2,mod); for(int i=0;i<lim;i++) f[i]=(long long)f[i]*inv%mod; } } void Init(void){ read(n);read(m);read(x);read(S); Gm=getG(m); // printf("** GM = %d , invG = %d\n",Gm,invG); for(int i=0;i<m-1;i++){ int Pos=power(Gm,i,m); pos[Pos]=i; } for(int i=1;i<=S;i++){ read(a[i]); if(a[i]) dp[pos[a[i]]]++; } // for(int i=1;i<=m;i++){ // printf("%d %d\n",dp[i],pos[i]); // } } int tmp[maxS],FF[maxS<<1],GG[maxS<<1]; void mul(int *f,int *g,int *res){ for(int i=0;i<lim;i++) FF[i]=f[i],GG[i]=g[i]; NTT(FF,1);NTT(GG,1); for(int i=0;i<lim;i++) FF[i]=(long long)FF[i]*GG[i]%mod; NTT(FF,-1); for(int i=0;i<m-1;i++) res[i]=(FF[i]+FF[i+m-1])%mod; } void fpow(int p){ while(p){ if(p&1) mul(opt,dp,opt); mul(dp,dp,dp);p>>=1; } } void Work(void){ for(lim=1;lim<=m*2;lim<<=1); // printf("** lim = %d\n",lim); for(int i=0;i<lim;i++) tr[i]=(tr[i>>1]>>1)|((i&1)?lim>>1:0); opt[pos[1]]=1;fpow(n); printf("%d\n",opt[pos[x]]); } int main(){ // printf("**%d\n",getG(12289)); Init(); Work(); return 0; }
来源:https://www.cnblogs.com/liubainian/p/12151558.html