多项式求逆:http://blog.miskcoo.com/2015/05/polynomial-inverse
注意:直接在点值表达下做$B(x) \equiv 2B'(x) - A(x)B'^2(x) \pmod {x^n}$是可以的,但是一定要注意,这一步中有一个长度为n的和两个长度为(n/2)的多项式相乘,因此要在DFT前就扩展FFT点值表达的“长度”到2n,否则会出错(调了1.5个小时)
版本1:
1 #prag\ 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=2097152; 17 int rev[N]; 18 void init(int len) 19 { 20 int bit=0,i; 21 while((1<<(bit+1))<=len) ++bit; 22 for(i=0;i<len;++i) 23 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 24 } 25 ll poww(ll a,ll b) 26 { 27 ll base=a,ans=1; 28 for(;b;b>>=1,base=base*base%md) 29 if(b&1) 30 ans=ans*base%md; 31 return ans; 32 } 33 void dft(int *a,int len,int idx)//要求len为2的幂 34 { 35 int i,j,k,t1,t2;ll wn,wnk; 36 for(i=0;i<len;++i) 37 if(i<rev[i]) 38 swap(a[i],a[rev[i]]); 39 for(i=1;i<len;i<<=1) 40 { 41 wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 42 for(j=0;j<len;j+=(i<<1)) 43 { 44 wnk=1; 45 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 46 { 47 t1=a[k];t2=a[k+i]*wnk%md; 48 a[k]+=t2; 49 (a[k]>=md) && (a[k]-=md); 50 a[k+i]=t1-t2; 51 (a[k+i]<0) && (a[k+i]+=md); 52 } 53 } 54 } 55 if(idx==-1) 56 { 57 ll ilen=poww(len,md-2); 58 for(i=0;i<len;++i) 59 a[i]=a[i]*ilen%md; 60 } 61 } 62 int f[N],g[N],t1[N]; 63 int n,n1; 64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) 65 { 66 g[0]=poww(f[0],md-2); 67 for(int i=2,j;i<(len<<1);i<<=1) 68 { 69 init(i<<1); 70 memcpy(t1,f,sizeof(int)*i); 71 memset(t1+i,0,sizeof(int)*i); 72 memset(g+(i>>1),0,sizeof(int)*(i+(i>>1))); 73 dft(t1,i<<1,1);dft(g,i<<1,1); 74 for(j=0;j<(i<<1);++j) 75 g[j]=ll(g[j])*(2+ll(md-g[j])*t1[j]%md)%md; 76 dft(g,i<<1,-1); 77 } 78 } 79 int main() 80 { 81 int i,t; 82 scanf("%d",&n);n1=n; 83 for(i=0;i<n;++i) 84 scanf("%d",g+i); 85 for(t=1;t<n;t<<=1); 86 n=t; 87 p_inv(g,f,n); 88 for(i=0;i<n1;++i) 89 printf("%d ",f[i]); 90 return 0; 91 }
资料:https://www.luogu.org/blog/user7035/duo-xiang-shi-zong-jie
里面有一个迷之优化(代码好像和文字表述的不一样,很玄学,看不懂,被坑了...)
牛顿迭代得到式子:$B(x) \equiv B'(x)-B'(x)(A(x)B'(x)-1) \pmod {x^n}$,其中B'(x)是上一次迭代的结果,B(x)是这一次的结果,A(x)是原多项式,n是这一次迭代得到的结果长度(设它是2的幂);设上一次迭代得到的结果长度为m=n/2
看右边的$A(x)B'(x)-1$,可以知道它第0到m-1项都是0,现在只需要求它与B'(x)的乘积的前n位,可以把它”左移“m位,这样它和B'(x)长度都只有m,因此只需要做长度为n(而不是2n)的NTT,然后再”右移”回去
如果与B'(x)相乘时不做长度为2n的NTT而做长度为n的NTT,那么可以发现结果刚好相当于正常结果(做长度为2n的NTT的结果取前n位)将前一半和后一半交换(未验证)
(可以直接用算A(x)B'(x)时求出的B'(x)的DFT)(当然这样NTT次数从3次变成了5次...)
版本2:(实测的确比版本1快)(另外把longlong都改成了unsignedlonglong)
1 #prag\ 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=262144; 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md)) 18 int rev[N]; 19 void init(int len) 20 { 21 int bit=0,i; 22 while((1<<(bit+1))<=len) ++bit; 23 for(i=0;i<len;++i) 24 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 25 } 26 ull poww(ull a,ull b) 27 { 28 ull base=a,ans=1; 29 for(;b;b>>=1,base=base*base%md) 30 if(b&1) 31 ans=ans*base%md; 32 return ans; 33 } 34 void dft(int *a,int len,int idx)//要求len为2的幂 35 { 36 int i,j,k,t1,t2;ull wn,wnk; 37 for(i=0;i<len;++i) 38 if(i<rev[i]) 39 swap(a[i],a[rev[i]]); 40 for(i=1;i<len;i<<=1) 41 { 42 wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 43 for(j=0;j<len;j+=(i<<1)) 44 { 45 wnk=1; 46 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 47 { 48 t1=a[k];t2=a[k+i]*wnk%md; 49 a[k]+=t2; 50 (a[k]>=md) && (a[k]-=md); 51 a[k+i]=t1-t2; 52 (a[k+i]<0) && (a[k+i]+=md); 53 } 54 } 55 } 56 if(idx==-1) 57 { 58 ull ilen=poww(len,md-2); 59 for(i=0;i<len;++i) 60 a[i]=a[i]*ilen%md; 61 } 62 } 63 int t1[N],t2[N]; 64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) ;要求len是2的幂 65 { 66 g[0]=poww(f[0],md-2); 67 for(int i=2,j;i<(len<<1);i<<=1) 68 { 69 memcpy(t1,f,sizeof(int)*i); 70 memcpy(t2,g,sizeof(int)*(i>>1)); 71 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 72 init(i); 73 dft(t1,i,1);dft(t2,i,1); 74 for(j=0;j<i;++j) 75 t1[j]=ull(t1[j])*t2[j]%md; 76 dft(t1,i,-1); 77 for(j=0;j<(i>>1);++j) 78 t1[j]=t1[j+(i>>1)]; 79 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 80 dft(t1,i,1); 81 for(j=0;j<i;++j) 82 t1[j]=ull(t1[j])*t2[j]%md; 83 dft(t1,i,-1); 84 for(j=i>>1;j<i;++j) 85 delto(g[j],t1[j-(i>>1)]); 86 } 87 } 88 int f[N],g[N]; 89 int n,n1; 90 int main() 91 { 92 int i,t; 93 scanf("%d",&n);n1=n; 94 for(i=0;i<n;++i) 95 scanf("%d",g+i); 96 for(t=1;t<n;t<<=1); 97 n=t; 98 p_inv(g,f,n); 99 for(i=0;i<n1;++i) 100 printf("%d ",f[i]); 101 return 0; 102 }
版本3:基于此题版本2,改了疑似bug
1 #prag\ 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=262144; 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md)) 18 int rev[N]; 19 void init(int len) 20 { 21 int bit=0,i; 22 while((1<<(bit+1))<=len) ++bit; 23 for(i=0;i<len;++i) 24 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 25 } 26 ull poww(ull a,ull b) 27 { 28 ull base=a,ans=1; 29 for(;b;b>>=1,base=base*base%md) 30 if(b&1) 31 ans=ans*base%md; 32 return ans; 33 } 34 void dft(int *a,int len,int idx)//要求len为2的幂 35 { 36 int i,j,k,t1,t2;ull wn,wnk; 37 for(i=0;i<len;++i) 38 if(i<rev[i]) 39 swap(a[i],a[rev[i]]); 40 for(i=1;i<len;i<<=1) 41 { 42 wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 43 for(j=0;j<len;j+=(i<<1)) 44 { 45 wnk=1; 46 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 47 { 48 t1=a[k];t2=a[k+i]*wnk%md; 49 a[k]+=t2; 50 (a[k]>=md) && (a[k]-=md); 51 a[k+i]=t1-t2; 52 (a[k+i]<0) && (a[k+i]+=md); 53 } 54 } 55 } 56 if(idx==-1) 57 { 58 ull ilen=poww(len,md-2); 59 for(i=0;i<len;++i) 60 a[i]=a[i]*ilen%md; 61 } 62 } 63 int t1[N],t2[N]; 64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素) ;要求len是2的幂 65 { 66 g[0]=poww(f[0],md-2); 67 for(int i=2,j;i<(len<<1);i<<=1) 68 { 69 memcpy(t1,f,sizeof(int)*i); 70 memcpy(t2,g,sizeof(int)*(i>>1)); 71 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 72 init(i); 73 dft(t1,i,1);dft(t2,i,1); 74 for(j=0;j<i;++j) 75 t1[j]=ull(t1[j])*t2[j]%md; 76 dft(t1,i,-1); 77 for(j=0;j<(i>>1);++j) 78 t1[j]=t1[j+(i>>1)]; 79 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 80 dft(t1,i,1); 81 for(j=0;j<i;++j) 82 t1[j]=ull(t1[j])*t2[j]%md; 83 dft(t1,i,-1); 84 for(j=i>>1;j<i;++j) 85 g[j]=md-t1[j-(i>>1)]; 86 } 87 } 88 int f[N],g[N]; 89 int n,n1; 90 int main() 91 { 92 int i,t; 93 scanf("%d",&n);n1=n; 94 for(i=0;i<n;++i) 95 scanf("%d",g+i); 96 for(t=1;t<n;t<<=1); 97 n=t; 98 p_inv(g,f,n); 99 for(i=0;i<n1;++i) 100 printf("%d ",f[i]); 101 return 0; 102 }
来源:https://www.cnblogs.com/hehe54321/p/10353385.html