# 三模数 NTT
常数大、速度慢、精度高是它的特点。
在考虑三模数 NTT 之前先考虑一下中国剩余定理吧。
已知
求 。
有:
一点疑惑的解答(自言自语):
因为 ,所以 。又因为 ,所以 。所以 ,所以 最小为 ,即 。
进入正题:
所谓的三模数 NTT 指的是 以 为模数(经典 NTT 模数,原根均为 )分别进行 NTT,最后用上文的计算方式计算即可。
因为以上三个模数的乘积很大,答案即使不取模也不会大于该数,所以上式的 就是原答案,直接对题目给出的模数取模即可。
#include <bits/stdc++.h> | |
using namespace std; | |
using ll = long long; | |
#define Big __int128 | |
const int N=3e5+1; | |
const ll mo1=998244353,mo2=1004535809,mo3=469762049,G=3; | |
inline Big Ksm(Big x,Big y,ll mo){ | |
Big res=1; | |
for(;y;y>>=1,x=x*x%mo) | |
if(y&1)res=res*x%mo; | |
return res; | |
} | |
ll MOD; | |
const ll inv1=Ksm(mo1,mo2-2,mo2),inv2=Ksm(mo1*mo2%mo3,mo3-2,mo3); | |
struct Int{ | |
ll a,b,c; | |
Int(ll _x=0){a=b=c=_x;} | |
Int(ll _a,ll _b,ll _c){a=_a,b=_b,c=_c;} | |
inline Int operator + (const Int &x){return Int((ll)(a+x.a)%mo1,(ll)(b+x.b)%mo2,(ll)(c+x.c)%mo3);} | |
inline Int operator - (const Int &x){return Int((ll)(a-x.a+mo1)%mo1,(ll)(b-x.b+mo2)%mo2,(ll)(c-x.c+mo3)%mo3);} | |
inline Int operator * (const Int &x){return Int((ll)a*x.a%mo1,(ll)b*x.b%mo2,(ll)c*x.c%mo3);} | |
inline Int operator * (ll x){return Int((ll)a*x%mo1,(ll)b*x%mo2,(ll)c*x%mo3);} | |
void mulinv(ll x){ | |
a=a*Ksm(x,mo1-2,mo1)%mo1; | |
b=b*Ksm(x,mo2-2,mo2)%mo2; | |
c=c*Ksm(x,mo3-2,mo3)%mo3; | |
} | |
void inv(){ | |
a=Ksm(a,mo1-2,mo1)%mo1; | |
b=Ksm(b,mo2-2,mo2)%mo2; | |
c=Ksm(c,mo3-2,mo3)%mo3; | |
} | |
ll gettrue(){ | |
Big x=(Big)(b-a+mo2)%mo2*inv1%mo2*(Big)mo1+(Big)a; | |
return (((Big)(c-x%mo3+mo3)%mo3*inv2%mo3*(mo1%MOD*mo2%MOD)%MOD+x%MOD)%MOD+MOD)%MOD; | |
} | |
}; // mtt | |
int rev[N]; | |
Int w[N]; | |
void NTT(Int *a,int Len,bool type){ | |
for(int i=0;i<Len;i++){ | |
rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0); | |
if(rev[i]>i)swap(a[rev[i]],a[i]); | |
} | |
for(int d=1;d<Len;d<<=1){ | |
Int W=Int(Ksm(G,(mo1-1)/(d*2),mo1),Ksm(G,(mo2-1)/(d*2),mo2),Ksm(G,(mo3-1)/(d*2),mo3)); | |
if(type)W.inv(); | |
w[0]=Int(1); for(int i=1;i<d;i++)w[i]=w[i-1]*W; | |
for(int fir=0;fir<Len;fir+=d<<1){ | |
int sec=fir+d; | |
for(int i=0;i<d;i++){ | |
Int a0=a[fir+i],a1=w[i]*a[sec+i]; | |
a[fir+i]=a0+a1,a[sec+i]=a0-a1; | |
} | |
} | |
} | |
if(type){for(int i=0;i<Len;i++)a[i].mulinv(Len);} | |
} | |
int n,m; | |
Int f[N],g[N]; | |
int main(){ | |
cin>>n>>m>>MOD; | |
for(int i=0,x;i<=n;i++)cin>>x,x%=MOD,f[i]=Int(x); | |
for(int i=0,x;i<=m;i++)cin>>x,x%=MOD,g[i]=Int(x); | |
int Len=1; | |
while(Len<=(n+m+4))Len<<=1; | |
NTT(f,Len,0),NTT(g,Len,0); | |
for(int i=0;i<Len;i++)f[i]=f[i]*g[i]; | |
NTT(f,Len,1); | |
for(int i=0;i<=n+m;i++)cout<<f[i].gettrue()<<' '; | |
cout<<'\n'; | |
return 0; | |
} |
# 拆系数 FFT
常数小,速度快,精度低( 信仰跑)是它的特色。
如果直接对原数列进行 FFT 的话会炸精度的。考虑拆系数,即 ()。
那么:
如果直接计算的话需要四次 dft,三次 idft,和九次 ntt 的三模数 NTT 差距并不大。
考虑优化,然而 dft/idft 中有什么地方没有用到捏?虚部!考虑将 ,,, 合并在一起进行 dft。
设:
有:
我们可以通过 和 的加减得到我们想要的系数。
注意: 是先除以二再取整!!!取整后一定要先取模!!!(代码 行)。
#include <bits/stdc++.h> | |
#define poly vector<int> | |
using namespace std; | |
const int N=5e5+11; | |
int mo; | |
const int base=32768; | |
namespace Poly{ | |
using db = long double; | |
const db pi=acos(-1); | |
struct cp{ | |
db x,y; | |
cp operator + (const cp &a){return {x+a.x,y+a.y};} | |
cp operator - (const cp &a){return {x-a.x,y-a.y};} | |
cp operator * (const cp &a){return {x*a.x-y*a.y,x*a.y+y*a.x};} | |
}; | |
cp w[N]; int rev[N]; | |
void init_rev(int Len){ | |
for(int i=0;i<Len;i++) | |
rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0); | |
} | |
void FFT(cp *a,int Len,bool type){ | |
for(int i=0;i<Len;i++)if(rev[i]>i)swap(a[rev[i]],a[i]); | |
for(int d=1;d<Len;d<<=1){ | |
cp W={cos(pi/d),sin(pi/d)}; | |
if(type)W.y=-W.y; | |
w[0]={1,0}; | |
for(int i=1;i<d;i++)w[i]=w[i-1]*W; | |
for(int fir=0;fir<Len;fir+=d<<1){ | |
int sec=fir+d; | |
for(int i=0;i<d;i++){ | |
cp a0=a[fir+i],a1=w[i]*a[sec+i]; | |
a[fir+i]=a0+a1,a[sec+i]=a0-a1; | |
} | |
} | |
} | |
if(type)for(int i=0;i<Len;i++)a[i].x/=Len,a[i].y/=Len; | |
} | |
cp f[N],g[N],e[N]; | |
long long C(db x){return (long long)(x/2.+0.5)%mo;} // important!!! | |
poly mul(poly x,poly y){ | |
int tot=x.size()+y.size()-1,Len=1; | |
while(Len<=(tot+2))Len<<=1; | |
init_rev(Len); | |
for(int i=0;i<=Len;i++)f[i]=g[i]=e[i]={0,0}; | |
for(int i=0;i<x.size();i++){ | |
int a0=x[i]/base,a1=x[i]%base; | |
f[i]={a0,a1},g[i]={a0,-a1}; | |
} | |
for(int i=0;i<y.size();i++){ | |
int b0=y[i]/base,b1=y[i]%base; | |
e[i]={b0,b1}; | |
} | |
FFT(f,Len,0),FFT(g,Len,0),FFT(e,Len,0); | |
for(int i=0;i<Len;i++)f[i]=f[i]*e[i],g[i]=g[i]*e[i]; | |
FFT(f,Len,1),FFT(g,Len,1); | |
poly ret(tot,0); | |
for(int i=0;i<tot;i++){ | |
ret[i]=1ll*base*base%mo*(C(f[i].x+g[i].x))%mo; | |
ret[i]+=1ll*base*((C(f[i].y+g[i].y))+(C(f[i].y-g[i].y)))%mo; | |
ret[i]%=mo; | |
ret[i]+=(C(g[i].x-f[i].x))%mo; | |
ret[i]%=mo; | |
} | |
return ret; | |
} | |
} using Poly::mul; | |
int a[N],n,m; | |
int main(){ | |
cin>>n>>m>>mo; | |
poly a(n+1,0),b(m+1,0); | |
for(int i=0;i<=n;i++) cin>>a[i]; | |
for(int i=0;i<=m;i++) cin>>b[i]; | |
a=mul(a,b); | |
for(int i:a)cout<<i<<' '; | |
return 0; | |
} |