From 33e18b4cd390311782609f3b46d0741430d7bdb4 Mon Sep 17 00:00:00 2001 From: nicoo Date: Sat, 30 May 2020 10:11:05 +0200 Subject: [PATCH] factor::numeric::Montgomery: Add debug assertions In debug mode, checks that all arithmetic operations coincide with the plain-u64 versions, as long as the latter does not overflow. --- src/uu/factor/src/numeric.rs | 60 +++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/src/uu/factor/src/numeric.rs b/src/uu/factor/src/numeric.rs index d9f0ed7c9..ce25f81bf 100644 --- a/src/uu/factor/src/numeric.rs +++ b/src/uu/factor/src/numeric.rs @@ -29,7 +29,8 @@ pub(crate) trait Arithmetic: Copy + Sized { fn mul(&self, a: Self::I, b: Self::I) -> Self::I; fn pow(&self, mut a: Self::I, mut b: u64) -> Self::I { - let mut result = self.from_u64(1u64); + let (_a, _b) = (a, b); + let mut result = self.one(); while b > 0 { if b & 1 != 0 { result = self.mul(result, a); @@ -37,6 +38,15 @@ pub(crate) trait Arithmetic: Copy + Sized { a = self.mul(a, a); b >>= 1; } + + // Check that r (reduced back to the usual representation) equals + // a^b % n, unless the latter computation overflows + debug_assert!(self + .to_u64(_a) + .checked_pow(_b as u32) + .map(|r| r % self.modulus() == self.to_u64(result)) + .unwrap_or(true)); + result } @@ -79,10 +89,9 @@ impl Arithmetic for Montgomery { type I = Wrapping; fn new(n: u64) -> Self { - Montgomery { - a: inv_mod_u64(n).wrapping_neg(), - n, - } + let a = inv_mod_u64(n).wrapping_neg(); + debug_assert_eq!(n.wrapping_mul(a), 1_u64.wrapping_neg()); + Montgomery { a, n } } fn modulus(&self) -> u64 { @@ -91,7 +100,10 @@ impl Arithmetic for Montgomery { fn from_u64(&self, x: u64) -> Self::I { // TODO: optimise! - Wrapping((((x as u128) << 64) % self.n as u128) as u64) + assert!(x < self.n); + let r = Wrapping((((x as u128) << 64) % self.n as u128) as u64); + debug_assert_eq!(x, self.to_u64(r)); + r } fn to_u64(&self, n: Self::I) -> u64 { @@ -99,11 +111,43 @@ impl Arithmetic for Montgomery { } fn add(&self, a: Self::I, b: Self::I) -> Self::I { - a + b + let r = a + b; + + // Check that r (reduced back to the usual representation) equals + // a+b % n + #[cfg(debug_assertions)] + { + let a_r = self.to_u64(a); + let b_r = self.to_u64(b); + let r_r = self.to_u64(r); + let r_2 = (((a_r as u128) + (b_r as u128)) % (self.n as u128)) as u64; + debug_assert_eq!( + r_r, r_2, + "[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}", + r, r_r, r_2, a_r, b_r, a, b, self.n, self.a + ); + } + r } fn mul(&self, a: Self::I, b: Self::I) -> Self::I { - Wrapping(self.reduce((a * b).0)) + let r = Wrapping(self.reduce((a * b).0)); + + // Check that r (reduced back to the usual representation) equals + // a*b % n + #[cfg(debug_assertions)] + { + let a_r = self.to_u64(a); + let b_r = self.to_u64(b); + let r_r = self.to_u64(r); + let r_2 = (((a_r as u128) * (b_r as u128)) % (self.n as u128)) as u64; + debug_assert_eq!( + r_r, r_2, + "[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}", + r, r_r, r_2, a_r, b_r, a, b, self.n, self.a + ); + } + r } }