From 630d5b3ffdc452c0c0aaa3a9a46fee892b4745ba Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Sat, 23 Nov 2019 16:43:21 +0100 Subject: [PATCH] LibIPC+AudioServer: Allow unsolicited server-to-client IPC messages Client-side connection objects must now provide both client and server endpoint types. When a message is received from the server side, we try to decode it using both endpoint types and then send it to the right place for handling. This now makes it possible for AudioServer to send unsolicited messages to its clients. This opens up a ton of possibilities :^) --- DevTools/IPCCompiler/main.cpp | 30 ++++++++++++++--- Libraries/LibAudio/AClientConnection.cpp | 6 +++- Libraries/LibAudio/AClientConnection.h | 7 +++- Libraries/LibCore/CoreIPCClient.h | 38 ++++++++++++++-------- Libraries/LibCore/CoreIPCServer.h | 25 ++++++++------ Libraries/LibIPC/IEndpoint.h | 1 + Libraries/LibIPC/IMessage.h | 1 + Servers/AudioServer/ASClientConnection.cpp | 8 ++--- Servers/AudioServer/ASClientConnection.h | 2 +- Servers/AudioServer/ASMixer.h | 8 +++-- Servers/AudioServer/AudioClient.ipc | 2 +- Servers/AudioServer/AudioServer.ipc | 2 +- Servers/AudioServer/Makefile | 7 ++-- 13 files changed, 95 insertions(+), 42 deletions(-) diff --git a/DevTools/IPCCompiler/main.cpp b/DevTools/IPCCompiler/main.cpp index 87d140bba3..27cc6e3395 100644 --- a/DevTools/IPCCompiler/main.cpp +++ b/DevTools/IPCCompiler/main.cpp @@ -27,6 +27,7 @@ struct Message { struct Endpoint { String name; + int magic; Vector messages; }; @@ -177,6 +178,13 @@ int main(int argc, char** argv) consume_whitespace(); endpoints.last().name = extract_while([](char ch) { return !isspace(ch); }); consume_whitespace(); + consume_specific('='); + consume_whitespace(); + auto magic_string = extract_while([](char ch) { return !isspace(ch) && ch != '{'; }); + bool ok; + endpoints.last().magic = magic_string.to_int(ok); + ASSERT(ok); + consume_whitespace(); consume_specific('{'); parse_messages(); consume_specific('}'); @@ -244,17 +252,20 @@ int main(int argc, char** argv) return builder.to_string(); }; - auto do_message = [&](const String& name, const Vector& parameters, String response_type = {}) { + auto do_message = [&](const String& name, const Vector& parameters, const String& response_type = {}) { dbg() << "class " << name << " final : public IMessage {"; dbg() << "public:"; if (!response_type.is_null()) dbg() << " typedef class " << response_type << " ResponseType;"; dbg() << " " << constructor_for_message(name, parameters); dbg() << " virtual ~" << name << "() override {}"; + dbg() << " virtual i32 endpoint_magic() const override { return " << endpoint.magic << "; }"; + dbg() << " static i32 static_endpoint_magic() { return " << endpoint.magic << "; }"; dbg() << " virtual i32 id() const override { return (int)MessageID::" << name << "; }"; dbg() << " static i32 static_message_id() { return (int)MessageID::" << name << "; }"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "::" << name << "\"; }"; - dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream)"; + dbg() << " static String static_name() { return \"" << endpoint.name << "::" << name << "\"; }"; + dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream, size_t& size_in_bytes)"; dbg() << " {"; if (parameters.is_empty()) @@ -278,6 +289,7 @@ int main(int argc, char** argv) if (i != parameters.size() - 1) builder.append(", "); } + dbg() << " size_in_bytes = stream.offset();"; dbg() << " return make<" << name << ">(" << builder.to_string() << ");"; dbg() << " }"; dbg() << " virtual ByteBuffer encode() const override"; @@ -285,6 +297,7 @@ int main(int argc, char** argv) // FIXME: Support longer messages: dbg() << " auto buffer = ByteBuffer::create_uninitialized(1024);"; dbg() << " BufferStream stream(buffer);"; + dbg() << " stream << endpoint_magic();"; dbg() << " stream << (int)MessageID::" << name << ";"; for (auto& parameter : parameters) { dbg() << " stream << m_" << parameter.name << ";"; @@ -317,17 +330,24 @@ int main(int argc, char** argv) dbg() << "public:"; dbg() << " " << endpoint.name << "Endpoint() {}"; dbg() << " virtual ~" << endpoint.name << "Endpoint() override {}"; + dbg() << " static int static_magic() { return " << endpoint.magic << "; }"; + dbg() << " virtual int magic() const override { return " << endpoint.magic << "; }"; + dbg() << " static String static_name() { return \"" << endpoint.name << "\"; };"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "\"; };"; - dbg() << " static OwnPtr decode_message(const ByteBuffer& buffer)"; + dbg() << " static OwnPtr decode_message(const ByteBuffer& buffer, size_t& size_in_bytes)"; dbg() << " {"; dbg() << " BufferStream stream(const_cast(buffer));"; + dbg() << " i32 message_endpoint_magic = 0;"; + dbg() << " stream >> message_endpoint_magic;"; + dbg() << " if (message_endpoint_magic != " << endpoint.magic << ")"; + dbg() << " return nullptr;"; dbg() << " i32 message_id = 0;"; dbg() << " stream >> message_id;"; dbg() << " switch (message_id) {"; for (auto& message : endpoint.messages) { auto do_decode_message = [&](const String& name) { dbg() << " case (int)" << endpoint.name << "::MessageID::" << name << ":"; - dbg() << " return " << endpoint.name << "::" << name << "::decode(stream);"; + dbg() << " return " << endpoint.name << "::" << name << "::decode(stream, size_in_bytes);"; }; do_decode_message(message.name); if (message.is_synchronous) @@ -383,7 +403,7 @@ int main(int argc, char** argv) #ifdef DEBUG for (auto& endpoint : endpoints) { - dbg() << "Endpoint: '" << endpoint.name << "'"; + dbg() << "Endpoint: '" << endpoint.name << "' (magic: " << endpoint.magic << ")"; for (auto& message : endpoint.messages) { dbg() << " Message: '" << message.name << "'"; dbg() << " Sync: " << message.is_synchronous; diff --git a/Libraries/LibAudio/AClientConnection.cpp b/Libraries/LibAudio/AClientConnection.cpp index 6ca9c2e2c0..ea18bd8115 100644 --- a/Libraries/LibAudio/AClientConnection.cpp +++ b/Libraries/LibAudio/AClientConnection.cpp @@ -3,7 +3,7 @@ #include AClientConnection::AClientConnection() - : ConnectionNG("/tmp/asportal") + : ConnectionNG(*this, "/tmp/asportal") { } @@ -76,3 +76,7 @@ int AClientConnection::get_playing_buffer() { return send_sync()->buffer_id(); } + +void AClientConnection::handle(const AudioClient::FinishedPlayingBuffer&) +{ +} diff --git a/Libraries/LibAudio/AClientConnection.h b/Libraries/LibAudio/AClientConnection.h index 3d2089a57d..87ec1b4eaa 100644 --- a/Libraries/LibAudio/AClientConnection.h +++ b/Libraries/LibAudio/AClientConnection.h @@ -1,11 +1,13 @@ #pragma once +#include #include #include class ABuffer; -class AClientConnection : public IPC::Client::ConnectionNG { +class AClientConnection : public IPC::Client::ConnectionNG + , public AudioClientEndpoint { C_OBJECT(AClientConnection) public: AClientConnection(); @@ -26,4 +28,7 @@ public: void set_paused(bool paused); void clear_buffer(bool paused = false); + +private: + virtual void handle(const AudioClient::FinishedPlayingBuffer&) override; }; diff --git a/Libraries/LibCore/CoreIPCClient.h b/Libraries/LibCore/CoreIPCClient.h index d21cc9ba9e..0aea1195ae 100644 --- a/Libraries/LibCore/CoreIPCClient.h +++ b/Libraries/LibCore/CoreIPCClient.h @@ -246,11 +246,12 @@ namespace Client { int m_my_client_id { -1 }; }; - template + template class ConnectionNG : public CObject { public: - ConnectionNG(const StringView& address) - : m_connection(CLocalSocket::construct(this)) + ConnectionNG(LocalEndpoint& local_endpoint, const StringView& address) + : m_local_endpoint(local_endpoint) + , m_connection(CLocalSocket::construct(this)) , m_notifier(CNotifier::construct(m_connection->fd(), CNotifier::Read, this)) { // We want to rate-limit our clients @@ -312,8 +313,7 @@ namespace Client { } ASSERT(rc > 0); ASSERT(FD_ISSET(m_connection->fd(), &rfds)); - bool success = drain_messages_from_server(); - if (!success) + if (!drain_messages_from_server()) return nullptr; for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) { if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) { @@ -358,30 +358,42 @@ namespace Client { private: bool drain_messages_from_server() { + Vector bytes; for (;;) { u8 buffer[4096]; ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); if (nread < 0) { - if (errno == EAGAIN) { - return true; - } + if (errno == EAGAIN) + break; perror("read"); exit(1); return false; } if (nread == 0) { dbg() << "EOF on IPC fd"; + // FIXME: Dying is definitely not always appropriate! exit(1); return false; } - - auto message = Endpoint::decode_message(ByteBuffer::wrap(buffer, sizeof(buffer))); - ASSERT(message); - - m_unprocessed_messages.append(move(message)); + bytes.append(buffer, nread); } + + size_t decoded_bytes = 0; + for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) { + auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); + if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + m_local_endpoint.handle(*message); + } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + m_unprocessed_messages.append(move(message)); + } else { + ASSERT_NOT_REACHED(); + } + ASSERT(decoded_bytes); + } + return true; } + LocalEndpoint& m_local_endpoint; RefPtr m_connection; RefPtr m_notifier; Vector> m_unprocessed_messages; diff --git a/Libraries/LibCore/CoreIPCServer.h b/Libraries/LibCore/CoreIPCServer.h index e4b7a0ae5b..af65613fe4 100644 --- a/Libraries/LibCore/CoreIPCServer.h +++ b/Libraries/LibCore/CoreIPCServer.h @@ -256,7 +256,7 @@ namespace Server { , m_client_id(client_id) { add_child(socket); - m_socket->on_ready_to_read = [this] { drain_client(); }; + m_socket->on_ready_to_read = [this] { drain_messages_from_client(); }; } virtual ~ConnectionNG() override @@ -287,15 +287,16 @@ namespace Server { ASSERT(nwritten == buffer.size()); } - void drain_client() + void drain_messages_from_client() { - unsigned messages_received = 0; + Vector bytes; for (;;) { u8 buffer[4096]; ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); if (nread == 0 || (nread == -1 && errno == EAGAIN)) { - if (!messages_received) { + if (bytes.is_empty()) { CEventLoop::current().post_event(*this, make(client_id())); + return; } break; } @@ -303,17 +304,21 @@ namespace Server { perror("recv"); ASSERT_NOT_REACHED(); } - auto message = m_endpoint.decode_message(ByteBuffer::wrap(buffer, nread)); + bytes.append(buffer, nread); + } + + size_t decoded_bytes = 0; + for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) { + auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); + auto message = Endpoint::decode_message(remaining_bytes, decoded_bytes); if (!message) { - dbg() << "drain_client: Endpoint didn't recognize message"; + dbg() << "drain_messages_from_client: Endpoint didn't recognize message"; did_misbehave(); return; } - ++messages_received; - - auto response = m_endpoint.handle(*message); - if (response) + if (auto response = m_endpoint.handle(*message)) post_message(*response); + ASSERT(decoded_bytes); } } diff --git a/Libraries/LibIPC/IEndpoint.h b/Libraries/LibIPC/IEndpoint.h index e0098f773d..1468f68c32 100644 --- a/Libraries/LibIPC/IEndpoint.h +++ b/Libraries/LibIPC/IEndpoint.h @@ -13,6 +13,7 @@ class IEndpoint { public: virtual ~IEndpoint(); + virtual int magic() const = 0; virtual String name() const = 0; virtual OwnPtr handle(const IMessage&) = 0; diff --git a/Libraries/LibIPC/IMessage.h b/Libraries/LibIPC/IMessage.h index 6a0a81b1e1..813b4f528e 100644 --- a/Libraries/LibIPC/IMessage.h +++ b/Libraries/LibIPC/IMessage.h @@ -7,6 +7,7 @@ class IMessage { public: virtual ~IMessage(); + virtual int endpoint_magic() const = 0; virtual int id() const = 0; virtual String name() const = 0; virtual ByteBuffer encode() const = 0; diff --git a/Servers/AudioServer/ASClientConnection.cpp b/Servers/AudioServer/ASClientConnection.cpp index 47595c458d..e6704db19c 100644 --- a/Servers/AudioServer/ASClientConnection.cpp +++ b/Servers/AudioServer/ASClientConnection.cpp @@ -1,10 +1,9 @@ #include "ASClientConnection.h" #include "ASMixer.h" - +#include "AudioClientEndpoint.h" #include #include #include - #include #include #include @@ -30,10 +29,9 @@ void ASClientConnection::die() s_connections.remove(client_id()); } -void ASClientConnection::did_finish_playing_buffer(Badge, int buffer_id) +void ASClientConnection::did_finish_playing_buffer(Badge, int buffer_id) { - (void)buffer_id; - //post_message(AudioClient::FinishedPlayingBuffer(buffer_id)); + post_message(AudioClient::FinishedPlayingBuffer(buffer_id)); } OwnPtr ASClientConnection::handle(const AudioServer::Greet& message) diff --git a/Servers/AudioServer/ASClientConnection.h b/Servers/AudioServer/ASClientConnection.h index b23886dcba..4147a231d0 100644 --- a/Servers/AudioServer/ASClientConnection.h +++ b/Servers/AudioServer/ASClientConnection.h @@ -13,7 +13,7 @@ class ASClientConnection final : public IPC::Server::ConnectionNG, int buffer_id); + void did_finish_playing_buffer(Badge, int buffer_id); virtual void die() override; diff --git a/Servers/AudioServer/ASMixer.h b/Servers/AudioServer/ASMixer.h index 960dce9996..5de1b624fc 100644 --- a/Servers/AudioServer/ASMixer.h +++ b/Servers/AudioServer/ASMixer.h @@ -1,5 +1,6 @@ #pragma once +#include "ASClientConnection.h" #include #include #include @@ -36,6 +37,7 @@ public: ++m_played_samples; if (m_position >= m_current->sample_count()) { + m_client->did_finish_playing_buffer({}, m_current->shared_buffer_id()); m_current = nullptr; m_position = 0; } @@ -61,8 +63,10 @@ public: int get_remaining_samples() const { return m_remaining_samples; } int get_played_samples() const { return m_played_samples; } - int get_playing_buffer() const { - if(m_current) return m_current->shared_buffer_id(); + int get_playing_buffer() const + { + if (m_current) + return m_current->shared_buffer_id(); return -1; } diff --git a/Servers/AudioServer/AudioClient.ipc b/Servers/AudioServer/AudioClient.ipc index 52a87281b0..ed9ba77b3a 100644 --- a/Servers/AudioServer/AudioClient.ipc +++ b/Servers/AudioServer/AudioClient.ipc @@ -1,4 +1,4 @@ -endpoint AudioClient +endpoint AudioClient = 82 { FinishedPlayingBuffer(i32 buffer_id) =| } diff --git a/Servers/AudioServer/AudioServer.ipc b/Servers/AudioServer/AudioServer.ipc index 1bdc7c8091..b48cbc6c06 100644 --- a/Servers/AudioServer/AudioServer.ipc +++ b/Servers/AudioServer/AudioServer.ipc @@ -1,4 +1,4 @@ -endpoint AudioServer +endpoint AudioServer = 85 { // Basic protocol Greet(i32 client_pid) => (i32 server_pid, i32 client_id) diff --git a/Servers/AudioServer/Makefile b/Servers/AudioServer/Makefile index 9cf71eeb0f..ebdf331a8b 100644 --- a/Servers/AudioServer/Makefile +++ b/Servers/AudioServer/Makefile @@ -13,11 +13,14 @@ DEFINES += -DUSERLAND all: $(APP) -*.cpp: AudioServerEndpoint.h +*.cpp: AudioServerEndpoint.h AudioClientEndpoint.h AudioServerEndpoint.h: AudioServer.ipc @echo "IPC $<"; $(IPCCOMPILER) $< > $@ +AudioClientEndpoint.h: AudioClient.ipc + @echo "IPC $<"; $(IPCCOMPILER) $< > $@ + $(APP): $(OBJS) $(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -lthread -lpthread @@ -27,5 +30,5 @@ $(APP): $(OBJS) -include $(OBJS:%.o=%.d) clean: - @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h + @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h AudioClientEndpoint.h