diff --git a/AK/StdLibExtraDetails.h b/AK/StdLibExtraDetails.h index 6790de3f9c..47680425be 100644 --- a/AK/StdLibExtraDetails.h +++ b/AK/StdLibExtraDetails.h @@ -26,7 +26,7 @@ using FalseType = IntegralConstant; using TrueType = IntegralConstant; template -using AddConst = const T; +using AddConst = T const; template struct __AddConstToReferencedType { diff --git a/Tests/LibCrypto/TestASN1.cpp b/Tests/LibCrypto/TestASN1.cpp index 730e8db2f2..57d8f45de6 100644 --- a/Tests/LibCrypto/TestASN1.cpp +++ b/Tests/LibCrypto/TestASN1.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #define EXPECT_DATETIME(sv, y, mo, d, h, mi, s) \ @@ -163,3 +164,77 @@ TEST_CASE(test_generalized_nonexistent_dates) (void)Crypto::ASN1::parse_generalized_time("19230222040506Z"sv); // Gregorian switch; Greece (void)Crypto::ASN1::parse_generalized_time("19261224040506Z"sv); // Gregorian switch; Turkey } + +TEST_CASE(test_encoder_primitives) +{ + auto roundtrip_value = [](auto value) { + Crypto::ASN1::Encoder encoder; + MUST(encoder.write(value)); + auto encoded = encoder.finish(); + Crypto::ASN1::Decoder decoder(encoded); + auto decoded = MUST(decoder.read()); + EXPECT_EQ(decoded, value); + }; + + roundtrip_value(false); + roundtrip_value(true); + + roundtrip_value(Crypto::UnsignedBigInteger { 0 }); + roundtrip_value(Crypto::UnsignedBigInteger { 1 }); + roundtrip_value(Crypto::UnsignedBigInteger { 2 }.shift_left(128)); + roundtrip_value(Crypto::UnsignedBigInteger { 2 }.shift_left(256)); + + roundtrip_value(Vector { 1, 2, 840, 113549, 1, 1, 1 }); + roundtrip_value(Vector { 1, 2, 840, 113549, 1, 1, 11 }); + + roundtrip_value(ByteString { "Hello, World!\n" }); + + roundtrip_value(nullptr); + + roundtrip_value(Crypto::ASN1::BitStringView { { { 0x00, 0x01, 0x02, 0x03 } }, 3 }); +} + +TEST_CASE(test_encoder_constructed) +{ + Crypto::ASN1::Encoder encoder; + /* + * RSAPrivateKey ::= SEQUENCE { + * version Version, -- Version ::= INTEGER { two-prime(0), multi(1) } + * modulus INTEGER, -- n + * publicExponent INTEGER, -- e + * privateExponent INTEGER, -- d + * prime1 INTEGER, -- p + * prime2 INTEGER, -- q + * exponent1 INTEGER, -- d mod (p-1) + * exponent2 INTEGER, -- d mod (q-1) + * coefficient INTEGER, -- (inverse of q) mod p + * otherPrimeInfos OtherPrimeInfos OPTIONAL + * } + */ + (void)encoder.write_constructed(Crypto::ASN1::Class::Universal, Crypto::ASN1::Kind::Sequence, [&] { + MUST(encoder.write(0u)); // version + MUST(encoder.write(0x1234u)); // modulus + MUST(encoder.write(0x10001u)); // publicExponent + MUST(encoder.write(0x5678u)); // privateExponent + MUST(encoder.write(0x9abcu)); // prime1 + MUST(encoder.write(0xdef0u)); // prime2 + MUST(encoder.write(0x1234u)); // exponent1 + MUST(encoder.write(0x5678u)); // exponent2 + MUST(encoder.write(0x9abcu)); // coefficient + }); + auto encoded = encoder.finish(); + Crypto::ASN1::Decoder decoder(encoded); + MUST(decoder.enter()); // Sequence + EXPECT_EQ(MUST(decoder.read()), 0u); // version + EXPECT_EQ(MUST(decoder.read()), 0x1234u); // modulus + EXPECT_EQ(MUST(decoder.read()), 0x10001u); // publicExponent + EXPECT_EQ(MUST(decoder.read()), 0x5678u); // privateExponent + EXPECT_EQ(MUST(decoder.read()), 0x9abcu); // prime1 + EXPECT_EQ(MUST(decoder.read()), 0xdef0u); // prime2 + EXPECT_EQ(MUST(decoder.read()), 0x1234u); // exponent1 + EXPECT_EQ(MUST(decoder.read()), 0x5678u); // exponent2 + EXPECT_EQ(MUST(decoder.read()), 0x9abcu); // coefficient + EXPECT(decoder.eof()); // no otherPrimeInfos + MUST(decoder.leave()); // Sequence + EXPECT(decoder.eof()); // no other data +} diff --git a/Userland/Libraries/LibCrypto/ASN1/DER.cpp b/Userland/Libraries/LibCrypto/ASN1/DER.cpp index 262a12454b..a8717b6f60 100644 --- a/Userland/Libraries/LibCrypto/ASN1/DER.cpp +++ b/Userland/Libraries/LibCrypto/ASN1/DER.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -234,6 +235,206 @@ ErrorOr Decoder::leave() return {}; } +ErrorOr Encoder::write_tag(Class class_, Type type, Kind kind) +{ + auto class_byte = to_underlying(class_); + auto type_byte = to_underlying(type); + auto kind_byte = to_underlying(kind); + + auto byte = class_byte | type_byte | kind_byte; + if (kind_byte > 0x1f) { + auto high = kind_byte >> 7; + byte = class_byte | type_byte | 0x1f; + TRY(write_byte(byte)); + byte = (kind_byte & 0x7f) | high; + } + + return write_byte(byte); +} + +ErrorOr Encoder::write_byte(u8 byte) +{ + return write_bytes({ &byte, 1 }); +} + +ErrorOr Encoder::write_length(size_t value) +{ + if (value < 0x80) + return write_byte(value); + + size_t size = ceil_div(AK::ceil_log2(value), 3ul); + TRY(write_byte(0x80 | size)); + + for (size_t i = 0; i < size; i++) { + auto shift = (size - i - 1) * 8; + auto byte = (value >> shift) & 0xff; + TRY(write_byte(byte)); + } + + return {}; +} + +ErrorOr Encoder::write_bytes(ReadonlyBytes bytes) +{ + auto output = TRY(m_buffer_stack.last().get_bytes_for_writing(bytes.size())); + bytes.copy_to(output); + return {}; +} + +ErrorOr Encoder::write_boolean(bool value, Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::Boolean); + + TRY(write_tag(class_, type, kind)); + TRY(write_length(1)); + return write_byte(value ? 0xff : 0x00); +} + +ErrorOr Encoder::write_arbitrary_sized_integer(UnsignedBigInteger const& value, Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::Integer); + TRY(write_tag(class_, type, kind)); + + auto max_byte_size = max(1ull, value.length() * UnsignedBigInteger::BITS_IN_WORD / 8); // At minimum, we need one byte to encode 0. + ByteBuffer buffer; + auto output = TRY(buffer.get_bytes_for_writing(max_byte_size)); + auto size = value.export_data(output); + // DER does not allow empty integers, encode a zero if the exported size is zero. + if (size == 0) { + output[0] = 0; + size = 1; + } + + // Chop off the leading zeros + if constexpr (AK::HostIsLittleEndian) { + while (size > 1 && output[0] == 0) { + size--; + output = output.slice(1); + } + } else { + while (size > 1 && output[size - 1] == 0) + size--; + } + + // If the MSB is set, we need to add a leading zero to indicate a positive number. + if ((output[0] & 0x80) != 0) { + TRY(write_length(size + 1)); + TRY(write_byte(0)); + } else { + TRY(write_length(size)); + } + return write_bytes(output.slice(0, size)); +} + +ErrorOr Encoder::write_printable_string(StringView string, Optional class_override, Optional kind_override) +{ + Utf8View view { string }; + if (!view.validate()) + return Error::from_string_literal("ASN1::Encoder: Invalid UTF-8 in printable string"); + + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::PrintableString); + + TRY(write_tag(class_, type, kind)); + TRY(write_length(string.length())); + return write_bytes(string.bytes()); +} + +ErrorOr Encoder::write_octet_string(ReadonlyBytes bytes, Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::OctetString); + + TRY(write_tag(class_, type, kind)); + TRY(write_length(bytes.size())); + return write_bytes(bytes); +} + +ErrorOr Encoder::write_null(Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::Null); + + TRY(write_tag(class_, type, kind)); + TRY(write_length(0)); + return {}; +} + +ErrorOr Encoder::write_object_identifier(Span segments, Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::ObjectIdentifier); + + if (segments.size() < 2) + return Error::from_string_literal("ASN1::Encoder: Object identifier must have at least two segments"); + + TRY(write_tag(class_, type, kind)); + size_t length = 1; + for (size_t i = 2; i < segments.size(); i++) { + auto segment = segments[i]; + if (segment < 0) + return Error::from_string_literal("ASN1::Encoder: Object identifier segments must be non-negative"); + + if (segment < 0x80) + length += 1; + else if (segment < 0x4000) + length += 2; + else if (segment < 0x200000) + length += 3; + else + length += 4; + } + + TRY(write_length(length)); + + auto first_byte = (segments[0] * 40) + segments[1]; + TRY(write_byte(first_byte)); + + for (size_t i = 2; i < segments.size(); i++) { + auto segment = segments[i]; + if (segment < 0x80) { + TRY(write_byte(segment)); + } else if (segment < 0x4000) { + TRY(write_byte((segment >> 7) | 0x80)); + TRY(write_byte(segment & 0x7f)); + } else if (segment < 0x200000) { + TRY(write_byte((segment >> 14) | 0x80)); + TRY(write_byte(((segment >> 7) & 0x7f) | 0x80)); + TRY(write_byte(segment & 0x7f)); + } else { + TRY(write_byte((segment >> 21) | 0x80)); + TRY(write_byte(((segment >> 14) & 0x7f) | 0x80)); + TRY(write_byte(((segment >> 7) & 0x7f) | 0x80)); + TRY(write_byte(segment & 0x7f)); + } + } + + return {}; +} + +ErrorOr Encoder::write_bit_string(BitStringView view, Optional class_override, Optional kind_override) +{ + auto class_ = class_override.value_or(Class::Universal); + auto type = Type::Primitive; + auto kind = kind_override.value_or(Kind::BitString); + + auto unused_bits = view.unused_bits(); + auto total_size_in_bits = view.byte_length() * 8 - unused_bits; + + TRY(write_tag(class_, type, kind)); + TRY(write_length(ceil_div(total_size_in_bits, 8ul) + 1)); + TRY(write_byte(unused_bits)); + return write_bytes(view.underlying_bytes()); +} + ErrorOr pretty_print(Decoder& decoder, Stream& stream, int indent) { while (!decoder.eof()) { @@ -244,7 +445,7 @@ ErrorOr pretty_print(Decoder& decoder, Stream& stream, int indent) builder.append(' '); builder.appendff("<{}> ", class_name(tag.class_)); if (tag.type == Type::Constructed) { - builder.appendff("[{}] {} ({})", type_name(tag.type), static_cast(tag.kind), kind_name(tag.kind)); + builder.appendff("[{}] {} ({})", type_name(tag.type), to_underlying(tag.kind), kind_name(tag.kind)); TRY(decoder.enter()); builder.append('\n'); @@ -257,7 +458,7 @@ ErrorOr pretty_print(Decoder& decoder, Stream& stream, int indent) continue; } else { if (tag.class_ != Class::Universal) - builder.appendff("[{}] {} {}", type_name(tag.type), static_cast(tag.kind), kind_name(tag.kind)); + builder.appendff("[{}] {} {}", type_name(tag.type), to_underlying(tag.kind), kind_name(tag.kind)); else builder.appendff("[{}] {}", type_name(tag.type), kind_name(tag.kind)); switch (tag.kind) { @@ -323,7 +524,7 @@ ErrorOr pretty_print(Decoder& decoder, Stream& stream, int indent) case Kind::Set: return Error::from_string_literal("ASN1::Decoder: Unexpected Primitive"); default: { - dbgln("PrettyPrint error: Unhandled kind {}", static_cast(tag.kind)); + dbgln("PrettyPrint error: Unhandled kind {}", to_underlying(tag.kind)); } } } diff --git a/Userland/Libraries/LibCrypto/ASN1/DER.h b/Userland/Libraries/LibCrypto/ASN1/DER.h index 56473207f7..aa2e5de325 100644 --- a/Userland/Libraries/LibCrypto/ASN1/DER.h +++ b/Userland/Libraries/LibCrypto/ASN1/DER.h @@ -29,13 +29,28 @@ public: return m_data; } - bool get(size_t index) + bool get(size_t index) const { if (index >= 8 * m_data.size() - m_unused_bits) return false; return 0 != (m_data[index / 8] & (1u << (7 - (index % 8)))); } + size_t unused_bits() const { return m_unused_bits; } + size_t byte_length() const { return m_data.size(); } + + ReadonlyBytes underlying_bytes() const { return m_data; } + + // FIXME: Improve me! I am naive! + bool operator==(BitStringView const& other) const + { + for (size_t bit_index = 0; bit_index < 8 * m_data.size() - m_unused_bits; ++bit_index) { + if (get(bit_index) != other.get(bit_index)) + return false; + } + return true; + } + private: ReadonlyBytes m_data; size_t m_unused_bits; @@ -186,34 +201,38 @@ private: { auto data = TRY(read_bytes(length)); - if (klass != Class::Universal) + if constexpr (IsSame) { + return data; + } else { + if (klass != Class::Universal) + return with_type_check(data); + + if (kind == Kind::Boolean) + return with_type_check(decode_boolean(data)); + + if (kind == Kind::Integer) + return with_type_check(decode_arbitrary_sized_integer(data)); + + if (kind == Kind::OctetString) + return with_type_check(decode_octet_string(data)); + + if (kind == Kind::Null) + return with_type_check(decode_null(data)); + + if (kind == Kind::ObjectIdentifier) + return with_type_check(decode_object_identifier(data)); + + if (kind == Kind::PrintableString || kind == Kind::IA5String || kind == Kind::UTCTime) + return with_type_check(decode_printable_string(data)); + + if (kind == Kind::Utf8String) + return with_type_check(StringView { data.data(), data.size() }); + + if (kind == Kind::BitString) + return with_type_check(decode_bit_string(data)); + return with_type_check(data); - - if (kind == Kind::Boolean) - return with_type_check(decode_boolean(data)); - - if (kind == Kind::Integer) - return with_type_check(decode_arbitrary_sized_integer(data)); - - if (kind == Kind::OctetString) - return with_type_check(decode_octet_string(data)); - - if (kind == Kind::Null) - return with_type_check(decode_null(data)); - - if (kind == Kind::ObjectIdentifier) - return with_type_check(decode_object_identifier(data)); - - if (kind == Kind::PrintableString || kind == Kind::IA5String || kind == Kind::UTCTime) - return with_type_check(decode_printable_string(data)); - - if (kind == Kind::Utf8String) - return with_type_check(StringView { data.data(), data.size() }); - - if (kind == Kind::BitString) - return with_type_check(decode_bit_string(data)); - - return with_type_check(data); + } } ErrorOr read_tag(); @@ -235,4 +254,83 @@ private: ErrorOr pretty_print(Decoder&, Stream&, int indent = 0); +class Encoder { +public: + Encoder() + { + m_buffer_stack.empend(); + } + + ReadonlyBytes active_bytes() const { return m_buffer_stack.last().bytes(); } + ByteBuffer finish() + { + VERIFY(m_buffer_stack.size() == 1); + return m_buffer_stack.take_last(); + } + + template + ErrorOr write(ValueType const& value, Optional class_override = {}, Optional kind_override = {}) + { + if constexpr (IsSame) { + return write_boolean(value, class_override, kind_override); + } else if constexpr (IsSame || (IsIntegral && IsUnsigned)) { + return write_arbitrary_sized_integer(value, class_override, kind_override); + } else if constexpr (IsOneOf) { + return write_printable_string(value, class_override, kind_override); + } else if constexpr (IsOneOf) { + return write_octet_string(value, class_override, kind_override); + } else if constexpr (IsSame) { + return write_null(class_override, kind_override); + } else if constexpr (IsOneOf, Span, Span>) { + return write_object_identifier(value, class_override, kind_override); + } else if constexpr (IsSame) { + return write_bit_string(value, class_override, kind_override); + } else { + dbgln("Unsupported type: {}", __PRETTY_FUNCTION__); + return Error::from_string_literal("ASN1::Encoder: Trying to encode a value of an unsupported type"); + } + } + + template + ErrorOr write_constructed(Class class_, Kind kind, Fn&& fn) + { + return write_constructed(bit_cast(class_), bit_cast(kind), forward(fn)); + } + + template + ErrorOr write_constructed(u8 class_, u8 kind, Fn&& fn) + { + m_buffer_stack.empend(); + using ResultType = decltype(fn()); + if constexpr (IsSpecializationOf) { + TRY(fn()); + } else { + fn(); + } + auto buffer = m_buffer_stack.take_last(); + + TRY(write_tag(bit_cast(class_), Type::Constructed, bit_cast(kind))); + TRY(write_length(buffer.size())); + TRY(write_bytes(buffer.bytes())); + + return {}; + } + +private: + ErrorOr write_tag(Class, Type, Kind); + ErrorOr write_length(size_t); + ErrorOr write_bytes(ReadonlyBytes); + ErrorOr write_byte(u8); + + ErrorOr write_boolean(bool, Optional, Optional); + ErrorOr write_arbitrary_sized_integer(UnsignedBigInteger const&, Optional, Optional); + ErrorOr write_printable_string(StringView, Optional, Optional); + ErrorOr write_octet_string(ReadonlyBytes, Optional, Optional); + ErrorOr write_null(Optional, Optional); + ErrorOr write_object_identifier(Span, Optional, Optional); + ErrorOr write_bit_string(BitStringView, Optional, Optional); + + Vector m_buffer_stack; +}; + }