From 87c64834caf31347723ce6eea8135e60e3727dfa Mon Sep 17 00:00:00 2001 From: Tim Schumacher Date: Mon, 9 Jan 2023 18:57:41 +0100 Subject: [PATCH] LibDNS: Use `AllocatingMemoryStream` in DNS package construction --- Userland/Libraries/LibDNS/CMakeLists.txt | 2 +- Userland/Libraries/LibDNS/Name.cpp | 12 +++---- Userland/Libraries/LibDNS/Name.h | 4 +-- Userland/Libraries/LibDNS/Packet.cpp | 33 ++++++++++--------- Userland/Libraries/LibDNS/Packet.h | 2 +- Userland/Services/LookupServer/DNSServer.cpp | 2 +- .../Services/LookupServer/LookupServer.cpp | 2 +- .../Services/LookupServer/MulticastDNS.cpp | 2 +- 8 files changed, 31 insertions(+), 28 deletions(-) diff --git a/Userland/Libraries/LibDNS/CMakeLists.txt b/Userland/Libraries/LibDNS/CMakeLists.txt index 327558725b..b449de5fd6 100644 --- a/Userland/Libraries/LibDNS/CMakeLists.txt +++ b/Userland/Libraries/LibDNS/CMakeLists.txt @@ -5,4 +5,4 @@ set(SOURCES ) serenity_lib(LibDNS dns) -target_link_libraries(LibDNS PRIVATE LibIPC) +target_link_libraries(LibDNS PRIVATE LibCore LibIPC) diff --git a/Userland/Libraries/LibDNS/Name.cpp b/Userland/Libraries/LibDNS/Name.cpp index fcc0c41ff0..3d351a28c3 100644 --- a/Userland/Libraries/LibDNS/Name.cpp +++ b/Userland/Libraries/LibDNS/Name.cpp @@ -75,15 +75,15 @@ void Name::randomize_case() m_name = builder.to_deprecated_string(); } -OutputStream& operator<<(OutputStream& stream, Name const& name) +ErrorOr Name::write_to_stream(Core::Stream::Stream& stream) const { - auto parts = name.as_string().split_view('.'); + auto parts = as_string().split_view('.'); for (auto& part : parts) { - stream << (u8)part.length(); - stream << part.bytes(); + TRY(stream.write_trivial_value(part.length())); + TRY(stream.write_entire_buffer(part.bytes())); } - stream << '\0'; - return stream; + TRY(stream.write_trivial_value('\0')); + return {}; } unsigned Name::Traits::hash(Name const& name) diff --git a/Userland/Libraries/LibDNS/Name.h b/Userland/Libraries/LibDNS/Name.h index 5dfd97862b..0ad3c3a54f 100644 --- a/Userland/Libraries/LibDNS/Name.h +++ b/Userland/Libraries/LibDNS/Name.h @@ -9,6 +9,7 @@ #include #include +#include namespace DNS { @@ -21,6 +22,7 @@ public: size_t serialized_size() const; DeprecatedString const& as_string() const { return m_name; } + ErrorOr write_to_stream(Core::Stream::Stream&) const; void randomize_case(); @@ -36,8 +38,6 @@ private: DeprecatedString m_name; }; -OutputStream& operator<<(OutputStream& stream, Name const&); - } template<> diff --git a/Userland/Libraries/LibDNS/Packet.cpp b/Userland/Libraries/LibDNS/Packet.cpp index 9d0c241441..58a8f81fef 100644 --- a/Userland/Libraries/LibDNS/Packet.cpp +++ b/Userland/Libraries/LibDNS/Packet.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace DNS { @@ -29,7 +30,7 @@ void Packet::add_answer(Answer const& answer) VERIFY(m_answers.size() <= UINT16_MAX); } -ByteBuffer Packet::to_byte_buffer() const +ErrorOr Packet::to_byte_buffer() const { PacketHeader header; header.set_id(m_id); @@ -48,30 +49,32 @@ ByteBuffer Packet::to_byte_buffer() const header.set_question_count(m_questions.size()); header.set_answer_count(m_answers.size()); - DuplexMemoryStream stream; + Core::Stream::AllocatingMemoryStream stream; - stream << ReadonlyBytes { &header, sizeof(header) }; + TRY(stream.write_trivial_value(header)); for (auto& question : m_questions) { - stream << question.name(); - stream << htons((u16)question.record_type()); - stream << htons(question.raw_class_code()); + TRY(question.name().write_to_stream(stream)); + TRY(stream.write_trivial_value(htons((u16)question.record_type()))); + TRY(stream.write_trivial_value(htons(question.raw_class_code()))); } for (auto& answer : m_answers) { - stream << answer.name(); - stream << htons((u16)answer.type()); - stream << htons(answer.raw_class_code()); - stream << htonl(answer.ttl()); + TRY(answer.name().write_to_stream(stream)); + TRY(stream.write_trivial_value(htons((u16)answer.type()))); + TRY(stream.write_trivial_value(htons(answer.raw_class_code()))); + TRY(stream.write_trivial_value(htonl(answer.ttl()))); if (answer.type() == RecordType::PTR) { Name name { answer.record_data() }; - stream << htons(name.serialized_size()); - stream << name; + TRY(stream.write_trivial_value(htons(name.serialized_size()))); + TRY(name.write_to_stream(stream)); } else { - stream << htons(answer.record_data().length()); - stream << answer.record_data().bytes(); + TRY(stream.write_trivial_value(htons(answer.record_data().length()))); + TRY(stream.write_entire_buffer(answer.record_data().bytes())); } } - return stream.copy_into_contiguous_buffer(); + auto buffer = TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size())); + TRY(stream.read_entire_buffer(buffer)); + return buffer; } class [[gnu::packed]] DNSRecordWithoutName { diff --git a/Userland/Libraries/LibDNS/Packet.h b/Userland/Libraries/LibDNS/Packet.h index 871c60173d..7a1ab8262b 100644 --- a/Userland/Libraries/LibDNS/Packet.h +++ b/Userland/Libraries/LibDNS/Packet.h @@ -25,7 +25,7 @@ public: Packet() = default; static Optional from_raw_packet(u8 const*, size_t); - ByteBuffer to_byte_buffer() const; + ErrorOr to_byte_buffer() const; bool is_query() const { return !m_query_or_response; } bool is_response() const { return m_query_or_response; } diff --git a/Userland/Services/LookupServer/DNSServer.cpp b/Userland/Services/LookupServer/DNSServer.cpp index 0ef59a97b9..970d339387 100644 --- a/Userland/Services/LookupServer/DNSServer.cpp +++ b/Userland/Services/LookupServer/DNSServer.cpp @@ -62,7 +62,7 @@ ErrorOr DNSServer::handle_client() else response.set_code(Packet::Code::NOERROR); - buffer = response.to_byte_buffer(); + buffer = TRY(response.to_byte_buffer()); TRY(send(buffer, client_address)); return {}; diff --git a/Userland/Services/LookupServer/LookupServer.cpp b/Userland/Services/LookupServer/LookupServer.cpp index 5b77ca408a..42b770ff67 100644 --- a/Userland/Services/LookupServer/LookupServer.cpp +++ b/Userland/Services/LookupServer/LookupServer.cpp @@ -234,7 +234,7 @@ ErrorOr> LookupServer::lookup(Name const& name, DeprecatedString name_in_question.randomize_case(); request.add_question({ name_in_question, record_type, RecordClass::IN, false }); - auto buffer = request.to_byte_buffer(); + auto buffer = TRY(request.to_byte_buffer()); auto udp_socket = TRY(Core::Stream::UDPSocket::connect(nameserver, 53, Time::from_seconds(1))); TRY(udp_socket->set_blocking(true)); diff --git a/Userland/Services/LookupServer/MulticastDNS.cpp b/Userland/Services/LookupServer/MulticastDNS.cpp index adc912a140..3061edea5b 100644 --- a/Userland/Services/LookupServer/MulticastDNS.cpp +++ b/Userland/Services/LookupServer/MulticastDNS.cpp @@ -110,7 +110,7 @@ void MulticastDNS::announce() ErrorOr MulticastDNS::emit_packet(Packet const& packet, sockaddr_in const* destination) { - auto buffer = packet.to_byte_buffer(); + auto buffer = TRY(packet.to_byte_buffer()); if (!destination) destination = &mdns_addr;