乘法逆元求法小结
逆元的定义
对于整数a和p,如果a与p互质,我们定义x为a的逆元,计为a−1,满足a∗x≡1 (mod p)。
求解乘法逆元
1 拓展欧几里得算法(egcd)
拓展的欧几里得算法原理是基于贝祖定理,对非零整数a,b,存在r和s使得
ar+bs=gcd(a,b)
因为a,p互质,所以ar+bp=gcd(a,p)=1,而由方程a∗x≡1 (mod p),等价于ax=kp+1(ax+(-k)p=1),k为一个整数,所以ar+bp=1中的r就相当于是x,也就是乘法逆元。求r即求得乘法逆元。
python代码如下
1 2 3 4 5 6 7 8 9 10
| def mul_reverse_element(a,m): //a,m 该算法求得a在mod m下的乘法逆元 r0,r1,s0,s1=1,0,0,1 b=m while(b): q,a,b=a//b,b,a%b r0,r1=r1,r0-q*r1 if a!=1: return -1; //逆元不存在 else: return (r0+m)%m; //保证乘法逆元为正的
|
2 快速幂+费尔马小定理
费尔马小定理:若p为素数,a为正整数,且a,p互质,有:
ap−1≡1 (mod p)
所以
a∗x≡1 ≡ap−1 (mod p)
x≡ap−2 (mod p)
接着用快速幂计算ap−2 mod p即可
代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
|
#include<iostream> using namespace std; long long int quick_pow(long long int x,long long int pow,long long int r){ long long int ans=1,base=x; while(pow>0){ if(pow&1){ ans*=base; ans%=r; } base*=base; base%=r; pow>>=1; } return ans%r; } int main() { long long int a,p; int n=10; while(n--){ cin>>a>>p; cout<<quick_pow(a,p-2,p)<<endl; } return 0; }
|
3 阶乘法逆元求解O(n)复杂度
首先设f[n]为n的阶乘,ans[n]为n的逆元,给出结论:
ans[n]=ans[n!]∗(n−1)!
ans[(n−1)!]=ans[n!]∗n
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
#include<bits/stdc++.h> using namespace std; typedef long long ll; ll quick_pow(ll x,ll pow,ll r){ ll ans=1,base=x; while(pow>0){ if(pow&1){ ans*=base; ans%=r; } base*=base; base%=r; pow>>=1; } return ans%r; } ll f[3000050]; ll ans[3000050]; int inv[5050]; int main() { int n,p; cin>>n>>p; f[0]=1; for(ll i=1;i<=n;i++){ f[i]=(f[i-1]*i)%p; } ll last_reverse=quick_pow(f[n],p-2,p); ll tem; for(int i=n;i>=1;i--){ ans[i]=(last_reverse*f[i-1])%p; tem=(last_reverse*i)%p; last_reverse=tem; } for(int i=1;i<=n;i++){ printf("%lld\n",ans[i]); } return 0; }
|
4 线性递推算法
对于等式a∗x≡1 (mod p),设p=k*x+r(1<r<x),等式同时mod p
k∗x+r≡0 (mod p)
同乘以i−1,r−1得到:
k∗r−1+i−1≡0 (mod p)
i−1≡−k∗r−1 mod( p)
i−1≡−⌊p/i⌋∗(p mod i)−1 (mod p)
代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| #include<bits/stdc++.h> using namespace std; long long int inv[3000500]; int main() { inv[1]=1; int n; int p; cin>>n>>p; cout<<inv[1]<<endl; for(int i=2;i<=n;i++){ inv[i]=(p-p/i)*inv[p%i]%p; printf("%lld\n",inv[i]); } return 0; }
|
仅作为学习记录 欢迎指正.