1
Fork 0
mirror of https://github.com/RGBCube/uutils-coreutils synced 2025-07-29 03:57:44 +00:00

Merge pull request #1529 from nbraud/factor/montgomery

factor: Faster modular arithmetic with the Montgomery transform
This commit is contained in:
Alex Lyon 2020-06-18 09:19:12 -07:00 committed by GitHub
commit 6105cce69a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 309 additions and 150 deletions

View file

@ -20,56 +20,19 @@
use std::env::{self, args};
use std::fs::File;
use std::io::Write;
use std::num::Wrapping;
use std::path::Path;
use std::u64::MAX as MAX_U64;
use self::sieve::Sieve;
#[cfg(test)]
use miller_rabin::is_prime;
#[cfg(test)]
#[path = "src/numeric.rs"]
mod numeric;
use numeric::inv_mod_u64;
mod sieve;
// Inverse of `a` modulo 2^64, found with the extended Euclidean algorithm
// carried out in wrapping 64-bit arithmetic.
// Returns `None` when the algorithm ends on a remainder greater than 1,
// meaning `a` has no inverse.
// precondition: a does not divide 2^64
fn inv_mod_u64(a: u64) -> Option<u64> {
    // Two consecutive (coefficient, remainder) rows of the Euclid table;
    // the initial remainder 0 stands in for 2^64.
    let (mut prev_coef, mut coef) = (0u64, 1u64);
    let (mut prev_rem, mut rem) = (0u64, a);

    while rem != 0 {
        // Only the first pass sees `prev_rem == 0`. There the real dividend
        // is 2^64, and because `a` does not divide 2^64, dividing 2^64 - 1
        // yields the same quotient.
        let dividend = if prev_rem == 0 { MAX_U64 } else { prev_rem };
        let q = dividend / rem;

        let next_coef = prev_coef.wrapping_sub(q.wrapping_mul(coef));
        let next_rem = prev_rem.wrapping_sub(q.wrapping_mul(rem));

        prev_coef = coef;
        coef = next_coef;
        prev_rem = rem;
        rem = next_rem;
    }

    if prev_rem > 1 {
        // not invertible
        None
    } else {
        Some(prev_coef)
    }
}
#[cfg_attr(test, allow(dead_code))]
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
@ -95,7 +58,7 @@ fn main() {
let mut x = primes.next().unwrap();
for next in primes {
// format the table
let outstr = format!("({}, {}, {}),", x, inv_mod_u64(x).unwrap(), MAX_U64 / x);
let outstr = format!("({}, {}, {}),", x, inv_mod_u64(x), std::u64::MAX / x);
if cols + outstr.len() > MAX_WIDTH {
write!(file, "\n {}", outstr).unwrap();
cols = 4 + outstr.len();
@ -116,18 +79,12 @@ fn main() {
}
#[test]
fn test_inverter() {
let num = 10000;
let invs = Sieve::odd_primes().map(|x| inv_mod_u64(x).unwrap());
assert!(Sieve::odd_primes().zip(invs).take(num).all(|(x, y)| {
let Wrapping(z) = Wrapping(x) * Wrapping(y);
is_prime(x) && z == 1
}));
fn test_generator_isprime() {
assert_eq!(Sieve::odd_primes.take(10_000).all(is_prime));
}
#[test]
fn test_generator() {
fn test_generator_10001() {
    // Index 10 000 (0-based) is the 10 001st prime produced by the sieve.
    let prime_10001 = Sieve::primes().nth(10_000);
    assert_eq!(prime_10001, Some(104_743));
}

View file

@ -39,7 +39,7 @@ impl Factors {
}
fn add(&mut self, prime: u64, exp: u8) {
assert!(exp > 0);
debug_assert!(exp > 0);
let n = *self.f.get(&prime).unwrap_or(&0);
self.f.insert(prime, exp + n);
}
@ -47,6 +47,13 @@ impl Factors {
// Records a single occurrence of `prime`; shorthand for `add(prime, 1)`.
fn push(&mut self, prime: u64) {
self.add(prime, 1)
}
// Multiplies the stored factorisation back together: the product of
// `prime ^ exponent` over every entry (1 when there are no entries).
#[cfg(test)]
fn product(&self) -> u64 {
    self.f
        .iter()
        .map(|(p, exp)| p.pow(u32::from(*exp)))
        .product()
}
}
impl ops::MulAssign<Factors> for Factors {
@ -132,3 +139,22 @@ pub fn uumain(args: impl uucore::Args) -> i32 {
}
0
}
#[cfg(test)]
mod tests {
    use super::factor;

    // Factors `n` and checks that multiplying the factors back together
    // yields `n`, naming the offending input on failure.
    fn assert_recombines(n: u64) {
        assert_eq!(
            factor(n).product(),
            n,
            "the factors of {} do not multiply back to it",
            n
        );
    }

    #[test]
    fn factor_recombines_small() {
        // All odd integers from 3 to 19 999
        (1..10_000).map(|i| 2 * i + 1).for_each(assert_recombines);
    }

    #[test]
    fn factor_recombines_overflowing() {
        // 250 odd integers just above 2^32: squaring any of them overflows
        // a u64.
        (0..250)
            .map(|i| 2 * i + 2u64.pow(32) + 1)
            .for_each(assert_recombines);
    }
}

View file

@ -23,9 +23,10 @@ impl Result {
// Deterministic Miller-Rabin primality-checking algorithm, adapted to extract
// (some) dividers; it will fail to factor strong pseudoprimes.
#[allow(clippy::many_single_char_names)]
pub(crate) fn test<A: Arithmetic>(n: u64) -> Result {
pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
use self::Result::*;
let n = m.modulus();
if n < 2 {
return Pseudoprime;
}
@ -37,36 +38,41 @@ pub(crate) fn test<A: Arithmetic>(n: u64) -> Result {
let i = (n - 1).trailing_zeros();
let r = (n - 1) >> i;
for a in BASIS.iter() {
let a = a % n;
if a == 0 {
let one = m.one();
let minus_one = m.minus_one();
for _a in BASIS.iter() {
let _a = _a % n;
if _a == 0 {
break;
}
let a = m.from_u64(_a);
// x = a^r mod n
let mut x = A::pow(a, r, n);
let mut x = m.pow(a, r);
{
// y = ((x²)²...)² i times = x ^ (2ⁱ) = a ^ (r 2ⁱ) = x ^ (n - 1)
let mut y = x;
for _ in 0..i {
y = A::mul(y, y, n)
y = m.mul(y, y)
}
if y != 1 {
if y != one {
return Pseudoprime;
};
}
if x == 1 || x == n - 1 {
if x == one || x == minus_one {
break;
}
loop {
let y = A::mul(x, x, n);
if y == 1 {
return Composite(gcd(x - 1, n));
let y = m.mul(x, x);
if y == one {
return Composite(gcd(m.to_u64(x) - 1, m.modulus()));
}
if y == n - 1 {
if y == minus_one {
// This basis element is not a witness of `n` being composite.
// Keep looking.
break;
@ -81,10 +87,25 @@ pub(crate) fn test<A: Arithmetic>(n: u64) -> Result {
// Used by build.rs' tests
//
// Returns whether `n` is reported prime by the Miller-Rabin `test`, run with
// Montgomery arithmetic.
#[allow(dead_code)]
pub(crate) fn is_prime(n: u64) -> bool {
    // `Montgomery::new` requires an odd modulus (it inverts `n` modulo 2^64,
    // which `inv_mod_u64` asserts is only done for odd numbers), so even
    // inputs are settled here: 2 is the only even prime.
    if n % 2 == 0 {
        return n == 2;
    }

    test::<Montgomery>(Montgomery::new(n)).is_prime()
}
#[cfg(test)]
mod tests {
    use super::is_prime;

    // 2^64 - 59
    const LARGEST_U64_PRIME: u64 = 0xFFFFFFFFFFFFFFC5;

    #[test]
    fn largest_prime() {
        assert!(is_prime(LARGEST_U64_PRIME));
    }

    #[test]
    fn first_primes() {
        use crate::table::{NEXT_PRIME, P_INVS_U64};

        // The first component of every table entry is the prime itself.
        P_INVS_U64.iter().for_each(|entry| {
            let p = entry.0;
            assert!(is_prime(p), "{} reported composite", p);
        });
        assert!(is_prime(NEXT_PRIME));
    }
}

View file

@ -8,10 +8,11 @@
// * that was distributed with this source code.
use std::mem::swap;
use std::num::Wrapping;
use std::u64::MAX as MAX_U64;
pub fn gcd(mut a: u64, mut b: u64) -> u64 {
// This is incorrectly reported as dead code,
// presumably when included in build.rs.
#[allow(dead_code)]
pub(crate) fn gcd(mut a: u64, mut b: u64) -> u64 {
while b > 0 {
a %= b;
swap(&mut a, &mut b);
@ -19,87 +20,243 @@ pub fn gcd(mut a: u64, mut b: u64) -> u64 {
a
}
pub(crate) trait Arithmetic {
fn add(a: u64, b: u64, modulus: u64) -> u64;
fn mul(a: u64, b: u64, modulus: u64) -> u64;
pub(crate) trait Arithmetic: Copy + Sized {
type I: Copy + Sized + Eq;
fn pow(mut a: u64, mut b: u64, m: u64) -> u64 {
let mut result = 1;
fn new(m: u64) -> Self;
fn modulus(&self) -> u64;
fn from_u64(&self, n: u64) -> Self::I;
fn to_u64(&self, n: Self::I) -> u64;
fn add(&self, a: Self::I, b: Self::I) -> Self::I;
fn mul(&self, a: Self::I, b: Self::I) -> Self::I;
fn pow(&self, mut a: Self::I, mut b: u64) -> Self::I {
let (_a, _b) = (a, b);
let mut result = self.one();
while b > 0 {
if b & 1 != 0 {
result = Self::mul(result, a, m);
result = self.mul(result, a);
}
a = Self::mul(a, a, m);
a = self.mul(a, a);
b >>= 1;
}
// Check that r (reduced back to the usual representation) equals
// a^b % n, unless the latter computation overflows
// Temporarily commented out, as u64::checked_pow is not available
// on the minimum supported Rust version, nor is an appropriate method
// for compiling the check conditionally.
//debug_assert!(self
// .to_u64(_a)
// .checked_pow(_b as u32)
// .map(|r| r % self.modulus() == self.to_u64(result))
// .unwrap_or(true));
result
}
// The value 1, converted to this arithmetic's internal representation.
fn one(&self) -> Self::I {
self.from_u64(1)
}
// The value `modulus - 1` (that is, -1 modulo the modulus), converted to
// this arithmetic's internal representation.
fn minus_one(&self) -> Self::I {
self.from_u64(self.modulus() - 1)
}
// The value 0, converted to this arithmetic's internal representation.
fn zero(&self) -> Self::I {
self.from_u64(0)
}
}
pub(crate) struct Big {}
// Parameters for Montgomery arithmetic modulo `n`, with R = 2⁶⁴.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Montgomery {
// -n⁻¹ modulo 2⁶⁴: `new` computes it as the negated result of
// `inv_mod_u64(n)` and debug-asserts that n·a ≡ -1 (mod 2⁶⁴).
a: u64,
// The modulus, as returned by `modulus()`.
n: u64,
}
impl Arithmetic for Big {
fn add(a: u64, b: u64, m: u64) -> u64 {
let Wrapping(msb_mod_m) = Wrapping(MAX_U64) - Wrapping(m) + Wrapping(1);
let msb_mod_m = msb_mod_m % m;
impl Montgomery {
/// computes x/R mod n efficiently
fn reduce(&self, x: u128) -> u64 {
debug_assert!(x < (self.n as u128) << 64);
// TODO: optimise
let Montgomery { a, n } = self;
let m = (x as u64).wrapping_mul(*a);
let nm = (*n as u128) * (m as u128);
let (xnm, overflow) = (x as u128).overflowing_add(nm); // x + n*m
debug_assert_eq!(xnm % (1 << 64), 0);
let Wrapping(res) = Wrapping(a) + Wrapping(b);
if b <= MAX_U64 - a {
res
// (x + n*m) / R
// in case of overflow, this is (2¹²⁸ + xnm)/2⁶⁴ - n = xnm/2⁶⁴ + (2⁶⁴ - n)
let y = (xnm >> 64) as u64 + if !overflow { 0 } else { n.wrapping_neg() };
if y >= *n {
y - n
} else {
(res + msb_mod_m) % m
y
}
}
// computes (a * b) % m using the russian peasant algorithm: for every set
// bit of `b`, the matching doubling of `a` is added to a running total.
// Additions and doublings are allowed to wrap past 2^64; a wrapped value is
// repaired by adding 2^64 mod m back in.
// Only necessary when m >= 2^63; otherwise, just wastes time.
fn mul(mut a: u64, mut b: u64, m: u64) -> u64 {
    // precompute 2^64 mod m, as (2^64 - m) mod m in wrapping arithmetic
    let wrap_fix = MAX_U64.wrapping_sub(m).wrapping_add(1) % m;

    let mut acc = 0u64;
    while b > 0 {
        if b & 1 != 0 {
            let (sum, wrapped) = acc.overflowing_add(a);
            let sum = sum % m;
            acc = if wrapped { (sum + wrap_fix) % m } else { sum };
        }

        // Doubling drops the top bit of `a` exactly when a >= 2^63.
        let top_bit_lost = a >= 1 << 63;
        let doubled = (a << 1) % m;
        a = if top_bit_lost {
            (doubled + wrap_fix) % m
        } else {
            doubled
        };

        b >>= 1;
    }
    acc
}
}
pub(crate) struct Small {}
impl Arithmetic for Montgomery {
// Montgomery transform, R=2⁶⁴
// Provides fast arithmetic mod n (n odd, u64)
type I = u64;
impl Arithmetic for Small {
// computes (a + b) % m using the russian peasant algorithm
// CAUTION: Will overflow if m >= 2^63
fn mul(mut a: u64, mut b: u64, m: u64) -> u64 {
let mut result = 0;
while b > 0 {
if b & 1 != 0 {
result = (result + a) % m;
}
a = (a << 1) % m;
b >>= 1;
}
result
// Sets up Montgomery arithmetic for the modulus `n`, precomputing
// -n⁻¹ mod 2⁶⁴. `n` has to be odd: `inv_mod_u64` asserts it.
fn new(n: u64) -> Self {
    let neg_inv = inv_mod_u64(n).wrapping_neg();
    // n · (-n⁻¹) must come out as -1 modulo 2⁶⁴
    debug_assert_eq!(n.wrapping_mul(neg_inv), 1_u64.wrapping_neg());
    Montgomery { a: neg_inv, n }
}
fn add(a: u64, b: u64, m: u64) -> u64 {
(a + b) % m
// The modulus `n` this instance was built for.
fn modulus(&self) -> u64 {
self.n
}
// Converts the residue `x` into the internal representation, x·2⁶⁴ mod n.
// Panics unless `x` is already reduced (x < n).
fn from_u64(&self, x: u64) -> Self::I {
    // TODO: optimise!
    assert!(x < self.n);
    // The shift and the remainder are carried out on 128-bit integers.
    let modulus = u128::from(self.n);
    let shifted = u128::from(x) << 64;
    let repr = (shifted % modulus) as u64;
    // Converting back must give the original value again.
    debug_assert_eq!(x, self.to_u64(repr));
    repr
}
// Converts `n` from the internal representation back to a plain residue,
// by passing it (widened to 128 bits) through `reduce`.
fn to_u64(&self, n: Self::I) -> u64 {
self.reduce(n as u128)
}
// Adds two values given in the internal representation; the result is
// normalised to [0; n[.
fn add(&self, a: Self::I, b: Self::I) -> Self::I {
    let n = self.n;

    // In case of overflow, a+b = 2⁶⁴ + sum = (2⁶⁴ - n) + sum (working mod n)
    let sum = match a.overflowing_add(b) {
        (wrapped, true) => wrapped + n.wrapping_neg(),
        (exact, false) => exact,
    };

    // Normalise to [0; n[
    let r = if sum >= n { sum - n } else { sum };

    // Cross-check in debug builds: converted back to plain residues, the
    // result must equal (a + b) % n computed on 128-bit integers.
    #[cfg(debug_assertions)]
    {
        let plain_a = self.to_u64(a);
        let plain_b = self.to_u64(b);
        let plain_r = self.to_u64(r);
        let expected = ((u128::from(plain_a) + u128::from(plain_b)) % u128::from(self.n)) as u64;
        debug_assert_eq!(
            plain_r, expected,
            "[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}",
            r, plain_r, expected, plain_a, plain_b, a, b, self.n, self.a
        );
    }

    r
}
// Multiplies two values given in the internal representation: the full
// 128-bit product is handed to `reduce`.
fn mul(&self, a: Self::I, b: Self::I) -> Self::I {
    let wide_product = u128::from(a) * u128::from(b);
    let r = self.reduce(wide_product);

    // Cross-check in debug builds: converted back to plain residues, the
    // result must equal (a * b) % n computed on 128-bit integers.
    #[cfg(debug_assertions)]
    {
        let plain_a = self.to_u64(a);
        let plain_b = self.to_u64(b);
        let plain_r = self.to_u64(r);
        let expected = ((u128::from(plain_a) * u128::from(plain_b)) % u128::from(self.n)) as u64;
        debug_assert_eq!(
            plain_r, expected,
            "[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}",
            r, plain_r, expected, plain_a, plain_b, a, b, self.n, self.a
        );
    }

    r
}
}
// Multiplicative inverse of `a` modulo 2^64, found with the extended Euclid
// algorithm carried out in wrapping 64-bit arithmetic.
// precondition: a is odd (the function panics otherwise)
pub(crate) fn inv_mod_u64(a: u64) -> u64 {
    assert!(a % 2 == 1);

    // Two consecutive (coefficient, remainder) rows of the Euclid table,
    // each satisfying coefficient * a == remainder modulo 2^64; the initial
    // remainder 0 stands in for 2^64.
    let (mut prev_coef, mut coef) = (0u64, 1u64);
    let (mut prev_rem, mut rem) = (0u64, a);

    while rem != 0 {
        // Only the first pass sees `prev_rem == 0`. There the real dividend
        // is 2^64; dividing 2^64 - 1 instead gives the same quotient as long
        // as `a` does not divide 2^64.
        let dividend = if prev_rem == 0 {
            std::u64::MAX
        } else {
            prev_rem
        };
        let q = dividend / rem;

        let next_coef = prev_coef.wrapping_sub(q.wrapping_mul(coef));
        let next_rem = prev_rem.wrapping_sub(q.wrapping_mul(rem));

        prev_coef = coef;
        coef = next_coef;
        prev_rem = rem;
        rem = next_rem;
    }

    // The last non-zero remainder has to be 1 for the inverse to exist.
    assert_eq!(prev_rem, 1);
    prev_coef
}
#[cfg(test)]
mod tests {
    use super::*;

    // The Montgomery tests below run over every odd modulus n in 1..=199 and
    // every pair y <= x < n. Failing assertions name the modulus and
    // operands, so nothing needs to be printed while the tests run.

    #[test]
    fn test_inverter() {
        // All odd integers from 1 to 20 000
        let mut test_values = (0..10_000u64).map(|i| 2 * i + 1);
        assert!(test_values.all(|x| x.wrapping_mul(inv_mod_u64(x)) == 1));
    }

    #[test]
    fn test_montgomery_add() {
        for n in 0..100 {
            let n = 2 * n + 1;
            let m = Montgomery::new(n);
            for x in 0..n {
                let m_x = m.from_u64(x);
                for y in 0..=x {
                    let m_y = m.from_u64(y);
                    assert_eq!(
                        (x + y) % n,
                        m.to_u64(m.add(m_x, m_y)),
                        "n = {}, x = {}, y = {}",
                        n,
                        x,
                        y
                    );
                }
            }
        }
    }

    #[test]
    fn test_montgomery_mult() {
        for n in 0..100 {
            let n = 2 * n + 1;
            let m = Montgomery::new(n);
            for x in 0..n {
                let m_x = m.from_u64(x);
                for y in 0..=x {
                    let m_y = m.from_u64(y);
                    assert_eq!(
                        (x * y) % n,
                        m.to_u64(m.mul(m_x, m_y)),
                        "n = {}, x = {}, y = {}",
                        n,
                        x,
                        y
                    );
                }
            }
        }
    }

    #[test]
    fn test_montgomery_roundtrip() {
        for n in 0..100 {
            let n = 2 * n + 1;
            let m = Montgomery::new(n);
            for x in 0..n {
                let x_ = m.from_u64(x);
                assert_eq!(x, m.to_u64(x_), "n = {}, x = {}", n, x);
            }
        }
    }
}

View file

@ -7,15 +7,15 @@ use crate::miller_rabin::Result::*;
use crate::numeric::*;
use crate::{miller_rabin, Factors};
fn find_divisor<A: Arithmetic>(n: u64) -> u64 {
fn find_divisor<A: Arithmetic>(n: A) -> u64 {
#![allow(clippy::many_single_char_names)]
let mut rand = {
let range = Uniform::new(1, n);
let range = Uniform::new(1, n.modulus());
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
move || range.sample(&mut rng)
move || n.from_u64(range.sample(&mut rng))
};
let quadratic = |a, b| move |x| A::add(A::mul(a, A::mul(x, x, n), n), b, n);
let quadratic = |a, b| move |x| n.add(n.mul(a, n.mul(x, x)), b);
loop {
let f = quadratic(rand(), rand());
@ -25,8 +25,12 @@ fn find_divisor<A: Arithmetic>(n: u64) -> u64 {
loop {
x = f(x);
y = f(f(y));
let d = gcd(n, max(x, y) - min(x, y));
if d == n {
let d = {
let _x = n.to_u64(x);
let _y = n.to_u64(y);
gcd(n.modulus(), max(_x, _y) - min(_x, _y))
};
if d == n.modulus() {
// Failure, retry with a different quadratic
break;
} else if d > 1 {
@ -39,11 +43,8 @@ fn find_divisor<A: Arithmetic>(n: u64) -> u64 {
fn _factor<A: Arithmetic>(mut num: u64) -> Factors {
// Shadow the name, so the recursion automatically goes from “Big” arithmetic to small.
let _factor = |n| {
if n < 1 << 63 {
_factor::<Small>(n)
} else {
_factor::<A>(n)
}
// TODO: Optimise with 32 and 64b versions
_factor::<A>(n)
};
let mut factors = Factors::new();
@ -51,7 +52,8 @@ fn _factor<A: Arithmetic>(mut num: u64) -> Factors {
return factors;
}
match miller_rabin::test::<A>(num) {
let n = A::new(num);
match miller_rabin::test::<A>(n) {
Prime => {
factors.push(num);
return factors;
@ -65,16 +67,12 @@ fn _factor<A: Arithmetic>(mut num: u64) -> Factors {
Pseudoprime => {}
};
let divisor = find_divisor::<A>(num);
let divisor = find_divisor::<A>(n);
factors *= _factor(divisor);
factors *= _factor(num / divisor);
factors
}
pub(crate) fn factor(n: u64) -> Factors {
if n < 1 << 63 {
_factor::<Small>(n)
} else {
_factor::<Big>(n)
}
_factor::<Montgomery>(n)
}