diff --git a/Libraries/LibCrypto/BigInt/UnsignedBigInteger.cpp b/Libraries/LibCrypto/BigInt/UnsignedBigInteger.cpp index d22ce57a19..987abce7bd 100644 --- a/Libraries/LibCrypto/BigInt/UnsignedBigInteger.cpp +++ b/Libraries/LibCrypto/BigInt/UnsignedBigInteger.cpp @@ -64,11 +64,47 @@ UnsignedBigInteger UnsignedBigInteger::add(const UnsignedBigInteger& other) return result; } +UnsignedBigInteger UnsignedBigInteger::sub(const UnsignedBigInteger& other) +{ + UnsignedBigInteger result; + + if (*this < other) { + dbg() << "WARNING: bigint subtraction creates a negative number!"; + return UnsignedBigInteger::create_invalid(); + } + + u8 borrow = 0; + for (size_t i = 0; i < other.length(); ++i) { + ASSERT(!(borrow == 1 && m_words[i] == 0)); + + if (m_words[i] - borrow < other.m_words[i]) { + u64 after_borrow = static_cast(m_words[i] - borrow) + (UINT32_MAX + 1); + result.m_words.append(static_cast(after_borrow - static_cast(other.m_words[i]))); + borrow = 1; + } else { + result.m_words.append(m_words[i] - borrow - other.m_words[i]); + borrow = 0; + } + } + + for (size_t i = other.length(); i < length(); ++i) { + ASSERT(!(borrow == 1 && m_words[i] == 0)); + result.m_words.append(m_words[i] - borrow); + borrow = 0; + } + + return result; +} + bool UnsignedBigInteger::operator==(const UnsignedBigInteger& other) const { if (trimmed_length() != other.trimmed_length()) { return false; } + if (is_invalid() != other.is_invalid()) { + return false; + } + for (size_t i = 0; i < trimmed_length(); ++i) { if (m_words[i] != other.words()[i]) return false; @@ -76,6 +112,23 @@ bool UnsignedBigInteger::operator==(const UnsignedBigInteger& other) const return true; } +bool UnsignedBigInteger::operator<(const UnsignedBigInteger& other) const +{ + if (trimmed_length() < other.trimmed_length()) { + return true; + } + if (trimmed_length() > other.trimmed_length()) { + return false; + } + + size_t length = trimmed_length(); + if (length == 0) { + return false; + } + + return m_words[length - 1] < other.m_words[length - 1]; +} + size_t UnsignedBigInteger::trimmed_length() const { size_t num_leading_zeroes = 0; @@ -86,4 +139,11 @@ size_t UnsignedBigInteger::trimmed_length() const return length() - num_leading_zeroes; } +UnsignedBigInteger UnsignedBigInteger::create_invalid() +{ + UnsignedBigInteger invalid(0); + invalid.invalidate(); + return invalid; +} + } diff --git a/Libraries/LibCrypto/BigInt/UnsignedBigInteger.h b/Libraries/LibCrypto/BigInt/UnsignedBigInteger.h index cf3b280c8c..17aea1e5a1 100644 --- a/Libraries/LibCrypto/BigInt/UnsignedBigInteger.h +++ b/Libraries/LibCrypto/BigInt/UnsignedBigInteger.h @@ -35,9 +35,12 @@ public: UnsignedBigInteger(u32 x) { m_words.append(x); } UnsignedBigInteger() {} + static UnsignedBigInteger create_invalid(); + const AK::Vector& words() const { return m_words; } UnsignedBigInteger add(const UnsignedBigInteger& other); + UnsignedBigInteger sub(const UnsignedBigInteger& other); size_t length() const { return m_words.size(); } @@ -45,15 +48,26 @@ public: size_t trimmed_length() const; bool operator==(const UnsignedBigInteger& other) const; + bool operator<(const UnsignedBigInteger& other) const; + + void invalidate() { m_is_invalid = true; } + bool is_invalid() const { return m_is_invalid; } private: AK::Vector m_words; + + // Used to indicate a negative result, or a result of an invalid operation + bool m_is_invalid { false }; }; } inline const LogStream& operator<<(const LogStream& stream, const Crypto::UnsignedBigInteger value) { + if (value.is_invalid()) { + stream << "Invalid BigInt"; + return stream; + } for (int i = value.length() - 1; i >= 0; --i) { stream << value.words()[i] << "|"; } diff --git a/Userland/test-crypto.cpp b/Userland/test-crypto.cpp index b47988a643..127bf55989 100644 --- a/Userland/test-crypto.cpp +++ b/Userland/test-crypto.cpp @@ -304,6 +304,7 @@ void hmac_sha512_test_process(); void bigint_test_fibo500(); void bigint_addition_edgecases(); +void bigint_subtraction(); int aes_cbc_tests() { @@ -797,21 +798,26 @@ int bigint_tests() { bigint_test_fibo500(); bigint_addition_edgecases(); + bigint_subtraction(); return 0; } +Crypto::UnsignedBigInteger bigint_fibonacci(size_t n) +{ + Crypto::UnsignedBigInteger num1(0); + Crypto::UnsignedBigInteger num2(1); + for (size_t i = 0; i < n; ++i) { + Crypto::UnsignedBigInteger t = num1.add(num2); + num2 = num1; + num1 = t; + } + return num1; +} void bigint_test_fibo500() { { I_TEST((BigInteger | Fibonacci500)); - Crypto::UnsignedBigInteger num1(0); - Crypto::UnsignedBigInteger num2(1); - for (int i = 0; i < 500; ++i) { - Crypto::UnsignedBigInteger t = num1.add(num2); - num2 = num1; - num1 = t; - } - bool pass = (num1.words() == AK::Vector { 315178285, 505575602, 1883328078, 125027121, 3649625763, 347570207, 74535262, 3832543808, 2472133297, 1600064941, 65273441 }); + bool pass = (bigint_fibonacci(500).words() == AK::Vector { 315178285, 505575602, 1883328078, 125027121, 3649625763, 347570207, 74535262, 3832543808, 2472133297, 1600064941, 65273441 }); if (pass) PASS; @@ -838,3 +844,53 @@ void bigint_addition_edgecases() } } } + +void bigint_subtraction() +{ + { + I_TEST((BigInteger | Simple Subtraction 1)); + Crypto::UnsignedBigInteger num1(80); + Crypto::UnsignedBigInteger num2(70); + + if (num1.sub(num2) == Crypto::UnsignedBigInteger(10)) { + PASS; + } else { + FAIL(Incorrect Result); + } + } + { + I_TEST((BigInteger | Simple Subtraction 2)); + Crypto::UnsignedBigInteger num1(50); + Crypto::UnsignedBigInteger num2(70); + + if (num1.sub(num2).is_invalid()) { + PASS; + } else { + FAIL(Incorrect Result); + } + } + { + I_TEST((BigInteger | Subtraction with borrow)); + Crypto::UnsignedBigInteger num1(UINT32_MAX); + Crypto::UnsignedBigInteger num2(1); + Crypto::UnsignedBigInteger num3 = num1.add(num2); + Crypto::UnsignedBigInteger result = num3.sub(num2); + if (result == num1) { + PASS; + } else { + FAIL(Incorrect Result); + } + } + { + I_TEST((BigInteger | Subtraction with large numbers)); + Crypto::UnsignedBigInteger num1 = bigint_fibonacci(343); + Crypto::UnsignedBigInteger num2 = bigint_fibonacci(218); + Crypto::UnsignedBigInteger result = num1.sub(num2); + if ((result.add(num2) == num1) + && (result.words() == Vector { 811430588, 2958904896, 1130908877, 2830569969, 3243275482, 3047460725, 774025231, 7990 })) { + PASS; + } else { + FAIL(Incorrect Result); + } + } +}