diff --git a/DevTools/IPCCompiler/main.cpp b/DevTools/IPCCompiler/main.cpp index 87d140bba3..27cc6e3395 100644 --- a/DevTools/IPCCompiler/main.cpp +++ b/DevTools/IPCCompiler/main.cpp @@ -27,6 +27,7 @@ struct Message { struct Endpoint { String name; + int magic; Vector messages; }; @@ -177,6 +178,13 @@ int main(int argc, char** argv) consume_whitespace(); endpoints.last().name = extract_while([](char ch) { return !isspace(ch); }); consume_whitespace(); + consume_specific('='); + consume_whitespace(); + auto magic_string = extract_while([](char ch) { return !isspace(ch) && ch != '{'; }); + bool ok; + endpoints.last().magic = magic_string.to_int(ok); + ASSERT(ok); + consume_whitespace(); consume_specific('{'); parse_messages(); consume_specific('}'); @@ -244,17 +252,20 @@ int main(int argc, char** argv) return builder.to_string(); }; - auto do_message = [&](const String& name, const Vector& parameters, String response_type = {}) { + auto do_message = [&](const String& name, const Vector& parameters, const String& response_type = {}) { dbg() << "class " << name << " final : public IMessage {"; dbg() << "public:"; if (!response_type.is_null()) dbg() << " typedef class " << response_type << " ResponseType;"; dbg() << " " << constructor_for_message(name, parameters); dbg() << " virtual ~" << name << "() override {}"; + dbg() << " virtual i32 endpoint_magic() const override { return " << endpoint.magic << "; }"; + dbg() << " static i32 static_endpoint_magic() { return " << endpoint.magic << "; }"; dbg() << " virtual i32 id() const override { return (int)MessageID::" << name << "; }"; dbg() << " static i32 static_message_id() { return (int)MessageID::" << name << "; }"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "::" << name << "\"; }"; - dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream)"; + dbg() << " static String static_name() { return \"" << endpoint.name << "::" << name << "\"; }"; + dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream, size_t& size_in_bytes)"; dbg() << " {"; if (parameters.is_empty()) @@ -278,6 +289,7 @@ int main(int argc, char** argv) if (i != parameters.size() - 1) builder.append(", "); } + dbg() << " size_in_bytes = stream.offset();"; dbg() << " return make<" << name << ">(" << builder.to_string() << ");"; dbg() << " }"; dbg() << " virtual ByteBuffer encode() const override"; @@ -285,6 +297,7 @@ int main(int argc, char** argv) // FIXME: Support longer messages: dbg() << " auto buffer = ByteBuffer::create_uninitialized(1024);"; dbg() << " BufferStream stream(buffer);"; + dbg() << " stream << endpoint_magic();"; dbg() << " stream << (int)MessageID::" << name << ";"; for (auto& parameter : parameters) { dbg() << " stream << m_" << parameter.name << ";"; @@ -317,17 +330,24 @@ int main(int argc, char** argv) dbg() << "public:"; dbg() << " " << endpoint.name << "Endpoint() {}"; dbg() << " virtual ~" << endpoint.name << "Endpoint() override {}"; + dbg() << " static int static_magic() { return " << endpoint.magic << "; }"; + dbg() << " virtual int magic() const override { return " << endpoint.magic << "; }"; + dbg() << " static String static_name() { return \"" << endpoint.name << "\"; };"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "\"; };"; - dbg() << " static OwnPtr decode_message(const ByteBuffer& buffer)"; + dbg() << " static OwnPtr decode_message(const ByteBuffer& buffer, size_t& size_in_bytes)"; dbg() << " {"; dbg() << " BufferStream stream(const_cast(buffer));"; + dbg() << " i32 message_endpoint_magic = 0;"; + dbg() << " stream >> message_endpoint_magic;"; + dbg() << " if (message_endpoint_magic != " << endpoint.magic << ")"; + dbg() << " return nullptr;"; dbg() << " i32 message_id = 0;"; dbg() << " stream >> message_id;"; dbg() << " switch (message_id) {"; for (auto& message : endpoint.messages) { auto do_decode_message = [&](const String& name) { dbg() << " case (int)" << endpoint.name << "::MessageID::" << name << ":"; - dbg() << " return " << endpoint.name << "::" << name << "::decode(stream);"; + dbg() << " return " << endpoint.name << "::" << name << "::decode(stream, size_in_bytes);"; }; do_decode_message(message.name); if (message.is_synchronous) @@ -383,7 +403,7 @@ int main(int argc, char** argv) #ifdef DEBUG for (auto& endpoint : endpoints) { - dbg() << "Endpoint: '" << endpoint.name << "'"; + dbg() << "Endpoint: '" << endpoint.name << "' (magic: " << endpoint.magic << ")"; for (auto& message : endpoint.messages) { dbg() << " Message: '" << message.name << "'"; dbg() << " Sync: " << message.is_synchronous; diff --git a/Libraries/LibAudio/AClientConnection.cpp b/Libraries/LibAudio/AClientConnection.cpp index 6ca9c2e2c0..ea18bd8115 100644 --- a/Libraries/LibAudio/AClientConnection.cpp +++ b/Libraries/LibAudio/AClientConnection.cpp @@ -3,7 +3,7 @@ #include AClientConnection::AClientConnection() - : ConnectionNG("/tmp/asportal") + : ConnectionNG(*this, "/tmp/asportal") { } @@ -76,3 +76,7 @@ int AClientConnection::get_playing_buffer() { return send_sync()->buffer_id(); } + +void AClientConnection::handle(const AudioClient::FinishedPlayingBuffer&) +{ +} diff --git a/Libraries/LibAudio/AClientConnection.h b/Libraries/LibAudio/AClientConnection.h index 3d2089a57d..87ec1b4eaa 100644 --- a/Libraries/LibAudio/AClientConnection.h +++ b/Libraries/LibAudio/AClientConnection.h @@ -1,11 +1,13 @@ #pragma once +#include #include #include class ABuffer; -class AClientConnection : public IPC::Client::ConnectionNG { +class AClientConnection : public IPC::Client::ConnectionNG + , public AudioClientEndpoint { C_OBJECT(AClientConnection) public: AClientConnection(); @@ -26,4 +28,7 @@ public: void set_paused(bool paused); void clear_buffer(bool paused = false); + +private: + virtual void handle(const AudioClient::FinishedPlayingBuffer&) override; }; diff --git a/Libraries/LibCore/CoreIPCClient.h b/Libraries/LibCore/CoreIPCClient.h index d21cc9ba9e..0aea1195ae 100644 --- a/Libraries/LibCore/CoreIPCClient.h +++ b/Libraries/LibCore/CoreIPCClient.h @@ -246,11 +246,12 @@ namespace Client { int m_my_client_id { -1 }; }; - template + template class ConnectionNG : public CObject { public: - ConnectionNG(const StringView& address) - : m_connection(CLocalSocket::construct(this)) + ConnectionNG(LocalEndpoint& local_endpoint, const StringView& address) + : m_local_endpoint(local_endpoint) + , m_connection(CLocalSocket::construct(this)) , m_notifier(CNotifier::construct(m_connection->fd(), CNotifier::Read, this)) { // We want to rate-limit our clients @@ -312,8 +313,7 @@ namespace Client { } ASSERT(rc > 0); ASSERT(FD_ISSET(m_connection->fd(), &rfds)); - bool success = drain_messages_from_server(); - if (!success) + if (!drain_messages_from_server()) return nullptr; for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) { if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) { @@ -358,30 +358,42 @@ namespace Client { private: bool drain_messages_from_server() { + Vector bytes; for (;;) { u8 buffer[4096]; ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); if (nread < 0) { - if (errno == EAGAIN) { - return true; - } + if (errno == EAGAIN) + break; perror("read"); exit(1); return false; } if (nread == 0) { dbg() << "EOF on IPC fd"; + // FIXME: Dying is definitely not always appropriate! exit(1); return false; } - - auto message = Endpoint::decode_message(ByteBuffer::wrap(buffer, sizeof(buffer))); - ASSERT(message); - - m_unprocessed_messages.append(move(message)); + bytes.append(buffer, nread); } + + size_t decoded_bytes = 0; + for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) { + auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); + if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + m_local_endpoint.handle(*message); + } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) { + m_unprocessed_messages.append(move(message)); + } else { + ASSERT_NOT_REACHED(); + } + ASSERT(decoded_bytes); + } + return true; } + LocalEndpoint& m_local_endpoint; RefPtr m_connection; RefPtr m_notifier; Vector> m_unprocessed_messages; diff --git a/Libraries/LibCore/CoreIPCServer.h b/Libraries/LibCore/CoreIPCServer.h index e4b7a0ae5b..af65613fe4 100644 --- a/Libraries/LibCore/CoreIPCServer.h +++ b/Libraries/LibCore/CoreIPCServer.h @@ -256,7 +256,7 @@ namespace Server { , m_client_id(client_id) { add_child(socket); - m_socket->on_ready_to_read = [this] { drain_client(); }; + m_socket->on_ready_to_read = [this] { drain_messages_from_client(); }; } virtual ~ConnectionNG() override @@ -287,15 +287,16 @@ namespace Server { ASSERT(nwritten == buffer.size()); } - void drain_client() + void drain_messages_from_client() { - unsigned messages_received = 0; + Vector bytes; for (;;) { u8 buffer[4096]; ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); if (nread == 0 || (nread == -1 && errno == EAGAIN)) { - if (!messages_received) { + if (bytes.is_empty()) { CEventLoop::current().post_event(*this, make(client_id())); + return; } break; } @@ -303,17 +304,21 @@ namespace Server { perror("recv"); ASSERT_NOT_REACHED(); } - auto message = m_endpoint.decode_message(ByteBuffer::wrap(buffer, nread)); + bytes.append(buffer, nread); + } + + size_t decoded_bytes = 0; + for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) { + auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); + auto message = Endpoint::decode_message(remaining_bytes, decoded_bytes); if (!message) { - dbg() << "drain_client: Endpoint didn't recognize message"; + dbg() << "drain_messages_from_client: Endpoint didn't recognize message"; did_misbehave(); return; } - ++messages_received; - - auto response = m_endpoint.handle(*message); - if (response) + if (auto response = m_endpoint.handle(*message)) post_message(*response); + ASSERT(decoded_bytes); } } diff --git a/Libraries/LibIPC/IEndpoint.h b/Libraries/LibIPC/IEndpoint.h index e0098f773d..1468f68c32 100644 --- a/Libraries/LibIPC/IEndpoint.h +++ b/Libraries/LibIPC/IEndpoint.h @@ -13,6 +13,7 @@ class IEndpoint { public: virtual ~IEndpoint(); + virtual int magic() const = 0; virtual String name() const = 0; virtual OwnPtr handle(const IMessage&) = 0; diff --git a/Libraries/LibIPC/IMessage.h b/Libraries/LibIPC/IMessage.h index 6a0a81b1e1..813b4f528e 100644 --- a/Libraries/LibIPC/IMessage.h +++ b/Libraries/LibIPC/IMessage.h @@ -7,6 +7,7 @@ class IMessage { public: virtual ~IMessage(); + virtual int endpoint_magic() const = 0; virtual int id() const = 0; virtual String name() const = 0; virtual ByteBuffer encode() const = 0; diff --git a/Servers/AudioServer/ASClientConnection.cpp b/Servers/AudioServer/ASClientConnection.cpp index 47595c458d..e6704db19c 100644 --- a/Servers/AudioServer/ASClientConnection.cpp +++ b/Servers/AudioServer/ASClientConnection.cpp @@ -1,10 +1,9 @@ #include "ASClientConnection.h" #include "ASMixer.h" - +#include "AudioClientEndpoint.h" #include #include #include - #include #include #include @@ -30,10 +29,9 @@ void ASClientConnection::die() s_connections.remove(client_id()); } -void ASClientConnection::did_finish_playing_buffer(Badge, int buffer_id) +void ASClientConnection::did_finish_playing_buffer(Badge, int buffer_id) { - (void)buffer_id; - //post_message(AudioClient::FinishedPlayingBuffer(buffer_id)); + post_message(AudioClient::FinishedPlayingBuffer(buffer_id)); } OwnPtr ASClientConnection::handle(const AudioServer::Greet& message) diff --git a/Servers/AudioServer/ASClientConnection.h b/Servers/AudioServer/ASClientConnection.h index b23886dcba..4147a231d0 100644 --- a/Servers/AudioServer/ASClientConnection.h +++ b/Servers/AudioServer/ASClientConnection.h @@ -13,7 +13,7 @@ class ASClientConnection final : public IPC::Server::ConnectionNG, int buffer_id); + void did_finish_playing_buffer(Badge, int buffer_id); virtual void die() override; diff --git a/Servers/AudioServer/ASMixer.h b/Servers/AudioServer/ASMixer.h index 960dce9996..5de1b624fc 100644 --- a/Servers/AudioServer/ASMixer.h +++ b/Servers/AudioServer/ASMixer.h @@ -1,5 +1,6 @@ #pragma once +#include "ASClientConnection.h" #include #include #include @@ -36,6 +37,7 @@ public: ++m_played_samples; if (m_position >= m_current->sample_count()) { + m_client->did_finish_playing_buffer({}, m_current->shared_buffer_id()); m_current = nullptr; m_position = 0; } @@ -61,8 +63,10 @@ public: int get_remaining_samples() const { return m_remaining_samples; } int get_played_samples() const { return m_played_samples; } - int get_playing_buffer() const { - if(m_current) return m_current->shared_buffer_id(); + int get_playing_buffer() const + { + if (m_current) + return m_current->shared_buffer_id(); return -1; } diff --git a/Servers/AudioServer/AudioClient.ipc b/Servers/AudioServer/AudioClient.ipc index 52a87281b0..ed9ba77b3a 100644 --- a/Servers/AudioServer/AudioClient.ipc +++ b/Servers/AudioServer/AudioClient.ipc @@ -1,4 +1,4 @@ -endpoint AudioClient +endpoint AudioClient = 82 { FinishedPlayingBuffer(i32 buffer_id) =| } diff --git a/Servers/AudioServer/AudioServer.ipc b/Servers/AudioServer/AudioServer.ipc index 1bdc7c8091..b48cbc6c06 100644 --- a/Servers/AudioServer/AudioServer.ipc +++ b/Servers/AudioServer/AudioServer.ipc @@ -1,4 +1,4 @@ -endpoint AudioServer +endpoint AudioServer = 85 { // Basic protocol Greet(i32 client_pid) => (i32 server_pid, i32 client_id) diff --git a/Servers/AudioServer/Makefile b/Servers/AudioServer/Makefile index 9cf71eeb0f..ebdf331a8b 100644 --- a/Servers/AudioServer/Makefile +++ b/Servers/AudioServer/Makefile @@ -13,11 +13,14 @@ DEFINES += -DUSERLAND all: $(APP) -*.cpp: AudioServerEndpoint.h +*.cpp: AudioServerEndpoint.h AudioClientEndpoint.h AudioServerEndpoint.h: AudioServer.ipc @echo "IPC $<"; $(IPCCOMPILER) $< > $@ +AudioClientEndpoint.h: AudioClient.ipc + @echo "IPC $<"; $(IPCCOMPILER) $< > $@ + $(APP): $(OBJS) $(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -lthread -lpthread @@ -27,5 +30,5 @@ $(APP): $(OBJS) -include $(OBJS:%.o=%.d) clean: - @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h + @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h AudioClientEndpoint.h