From fa2e3e2be42277b3c565c308e4aaef0972301563 Mon Sep 17 00:00:00 2001 From: Sergey Bugaev Date: Sat, 21 Nov 2020 21:55:00 +0300 Subject: [PATCH] LibIPC: Prepend each message with its size This makes it much simpler to determine when we've read a complete message, and will make it possible to integrate recvfd() in the future commit. --- AK/SourceGenerator.h | 2 +- DevTools/IPCCompiler/main.cpp | 9 +++---- Libraries/LibIPC/Connection.h | 47 ++++++++++++++++++++++------------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/AK/SourceGenerator.h b/AK/SourceGenerator.h index b5d4b29b82..15c0aaae42 100644 --- a/AK/SourceGenerator.h +++ b/AK/SourceGenerator.h @@ -66,7 +66,7 @@ public: GenericLexer lexer { pattern }; while (!lexer.is_eof()) { - // FIXME: It is a bit inconvinient, that 'consume_until' also consumes the 'stop' character, this makes + // FIXME: It is a bit inconvenient, that 'consume_until' also consumes the 'stop' character, this makes // the method less generic because there is no way to check if the 'stop' character ever appeared. const auto consume_until_without_consuming_stop_character = [&](char stop) { return lexer.consume_while([&](char ch) { return ch != stop; }); diff --git a/DevTools/IPCCompiler/main.cpp b/DevTools/IPCCompiler/main.cpp index fa55e94106..805bd60a6b 100644 --- a/DevTools/IPCCompiler/main.cpp +++ b/DevTools/IPCCompiler/main.cpp @@ -318,9 +318,9 @@ public: static i32 static_message_id() { return (int)MessageID::@message.name@; } virtual const char* message_name() const override { return "@endpoint.name@::@message.name@"; } - static OwnPtr<@message.name@> decode(InputMemoryStream& stream, size_t& size_in_bytes) + static OwnPtr<@message.name@> decode(InputMemoryStream& stream) { - IPC::Decoder decoder {stream}; + IPC::Decoder decoder { stream }; )~~~"); for (auto& parameter : parameters) { @@ -359,7 +359,6 @@ public: message_generator.set("message.constructor_call_parameters", builder.build()); message_generator.append(R"~~~( - size_in_bytes = stream.offset(); return make<@message.name@>(@message.constructor_call_parameters@); } )~~~"); @@ -437,7 +436,7 @@ public: static String static_name() { return "@endpoint.name@"; } virtual String name() const override { return "@endpoint.name@"; } - static OwnPtr decode_message(const ByteBuffer& buffer, size_t& size_in_bytes) + static OwnPtr decode_message(const ByteBuffer& buffer) { InputMemoryStream stream { buffer }; i32 message_endpoint_magic = 0; @@ -489,7 +488,7 @@ public: message_generator.append(R"~~~( case (int)Messages::@endpoint.name@::MessageID::@message.name@: - message = Messages::@endpoint.name@::@message.name@::decode(stream, size_in_bytes); + message = Messages::@endpoint.name@::@message.name@::decode(stream); break; )~~~"); }; diff --git a/Libraries/LibIPC/Connection.h b/Libraries/LibIPC/Connection.h index 4db21a6dc7..781a2d4c4f 100644 --- a/Libraries/LibIPC/Connection.h +++ b/Libraries/LibIPC/Connection.h @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -75,10 +76,13 @@ public: return; auto buffer = message.encode(); + // Prepend the message size. + uint32_t message_size = buffer.size(); + buffer.prepend(reinterpret_cast(&message_size), sizeof(message_size)); - auto bytes_remaining = buffer.size(); - while (bytes_remaining) { - auto nwritten = write(m_socket->fd(), buffer.data(), buffer.size()); + size_t total_nwritten = 0; + while (total_nwritten < buffer.size()) { + auto nwritten = write(m_socket->fd(), buffer.data() + total_nwritten, buffer.size() - total_nwritten); if (nwritten < 0) { switch (errno) { case EPIPE: @@ -95,7 +99,7 @@ public: return; } } - bytes_remaining -= nwritten; + total_nwritten += nwritten; } m_responsiveness_timer->start(); @@ -190,25 +194,34 @@ protected: did_become_responsive(); } - size_t decoded_bytes = 0; - for (size_t index = 0; index < bytes.size(); index += decoded_bytes) { + size_t index = 0; + uint32_t message_size = 0; + for (; index + sizeof(message_size) < bytes.size(); index += message_size) { + message_size = *reinterpret_cast(bytes.data() + index); + if (message_size == 0 || bytes.size() - index - sizeof(uint32_t) < message_size) + break; + index += sizeof(message_size); auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); - if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + if (auto message = LocalEndpoint::decode_message(remaining_bytes)) { m_unprocessed_messages.append(message.release_nonnull()); - } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + } else if (auto message = PeerEndpoint::decode_message(remaining_bytes)) { m_unprocessed_messages.append(message.release_nonnull()); } else { - // Sometimes we might receive a partial message. That's okay, just stash away - // the unprocessed bytes and we'll prepend them to the next incoming message - // in the next run of this function. - if (!m_unprocessed_bytes.is_empty()) { - dbg() << *this << "::drain_messages_from_peer: Already have unprocessed bytes"; - shutdown(); - } - m_unprocessed_bytes = remaining_bytes.isolated_copy(); + dbgln("Failed to parse a message"); break; } - ASSERT(decoded_bytes); + } + + if (index < bytes.size()) { + // Sometimes we might receive a partial message. That's okay, just stash away + // the unprocessed bytes and we'll prepend them to the next incoming message + // in the next run of this function. + auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); + if (!m_unprocessed_bytes.is_empty()) { + dbg() << *this << "::drain_messages_from_peer: Already have unprocessed bytes"; + shutdown(); + } + m_unprocessed_bytes = remaining_bytes.isolated_copy(); } if (!m_unprocessed_messages.is_empty()) {