From bf58241c11abf25ef4c7ed6be559089ed1fe1dc4 Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Thu, 14 Feb 2019 17:18:35 +0100 Subject: [PATCH] Port the WindowServer and LibGUI to communicate through local sockets. This is really cool! :^) Apps currently refuse to start if the WindowServer isn't listening on the socket in /wsportal. This makes sense, but I guess it would also be nice to have some sort of "wait for server on startup" mode. This has performance issues, and I'll work on those, but this stuff seems to actually work and I'm very happy with that. --- Kernel/FileDescriptor.h | 2 ++ Kernel/LocalSocket.cpp | 41 ++++++++++++++--------- Kernel/LocalSocket.h | 3 +- Kernel/Process.cpp | 30 ++++++++++++++--- Kernel/Process.h | 5 +++ Kernel/Scheduler.cpp | 7 ++++ Kernel/Socket.cpp | 15 +++++++++ Kernel/Socket.h | 12 +++++-- LibGUI/GEventLoop.cpp | 22 +++++++++--- WindowServer/WSClientConnection.cpp | 38 +++++++++++++++------ WindowServer/WSClientConnection.h | 9 +++-- WindowServer/WSMenu.cpp | 2 +- WindowServer/WSMessageLoop.cpp | 52 +++++++++++++++++++++++++---- WindowServer/WSMessageLoop.h | 1 + WindowServer/WSWindow.cpp | 2 +- 15 files changed, 190 insertions(+), 51 deletions(-) diff --git a/Kernel/FileDescriptor.h b/Kernel/FileDescriptor.h index 0dc15af07d..88066ba7fc 100644 --- a/Kernel/FileDescriptor.h +++ b/Kernel/FileDescriptor.h @@ -78,6 +78,8 @@ public: void set_original_inode(Badge, RetainPtr&& inode) { m_inode = move(inode); } + void set_socket_role(SocketRole role) { m_socket_role = role; } + private: friend class VFS; FileDescriptor(RetainPtr&&, SocketRole); diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index 6a0eabc435..3b7f292912 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -31,7 +31,7 @@ bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size) bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error) { - ASSERT(!m_connected); + ASSERT(!is_connected()); if (address_size != sizeof(sockaddr_un)) { error = -EINVAL; return false; @@ -45,7 +45,7 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err char safe_address[sizeof(local_address.sun_path) + 1]; memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path)); - kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address); + kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), this, safe_address); m_file = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode()); if (!m_file) { @@ -62,37 +62,48 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err return true; } -RetainPtr LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error) +bool LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error) { ASSERT(!m_bound); if (address_size != sizeof(sockaddr_un)) { error = -EINVAL; - return nullptr; + return false; } if (address->sa_family != AF_LOCAL) { error = -EINVAL; - return nullptr; + return false; } const sockaddr_un& local_address = *reinterpret_cast(address); char safe_address[sizeof(local_address.sun_path) + 1]; memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path)); - kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), safe_address); + kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), this, safe_address); m_file = VFS::the().open(safe_address, error, 0, 0, *current->cwd_inode()); if (!m_file) { error = -ECONNREFUSED; - return nullptr; + return false; + } + ASSERT(m_file->inode()); + if (!m_file->inode()->socket()) { + error = -ECONNREFUSED; + return false; } - ASSERT(m_file->inode()); - ASSERT(m_file->inode()->socket()); - - m_peer = m_file->inode()->socket(); m_address = local_address; - m_connected = true; - return m_peer; + + auto peer = m_file->inode()->socket(); + kprintf("Queueing up connection\n"); + if (!peer->queue_connection_from(*this, error)) + return false; + + kprintf("Waiting for connect...\n"); + if (!current->wait_for_connect(*this, error)) + return false; + + kprintf("CONNECTED!\n"); + return true; } bool LocalSocket::can_read(SocketRole role) const @@ -125,7 +136,7 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size) bool LocalSocket::can_write(SocketRole role) const { if (role == SocketRole::Accepted) - return !m_for_client.is_empty(); + return m_for_client.bytes_in_write_buffer() < 4096; else - return !m_for_server.is_empty(); + return m_for_server.bytes_in_write_buffer() < 4096; } diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index 76e58c92ad..f4581a67a0 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -11,7 +11,7 @@ public: virtual ~LocalSocket() override; virtual bool bind(const sockaddr*, socklen_t, int& error) override; - virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) override; + virtual bool connect(const sockaddr*, socklen_t, int& error) override; virtual bool get_address(sockaddr*, socklen_t*) override; virtual bool can_read(SocketRole) const override; @@ -27,7 +27,6 @@ private: RetainPtr m_peer; bool m_bound { false }; - bool m_connected { false }; sockaddr_un m_address; DoubleBuffer m_for_client; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index 234a92de27..39ce65120a 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -1784,6 +1784,13 @@ int Process::sys$ioctl(int fd, unsigned request, unsigned arg) auto* descriptor = file_descriptor(fd); if (!descriptor) return -EBADF; + if (descriptor->is_socket() && request == 413) { + auto* pid = (pid_t*)arg; + if (!validate_write_typed(pid)) + return -EFAULT; + *pid = descriptor->socket()->origin_pid(); + return 0; + } if (!descriptor->is_character_device()) return -ENOTTY; return descriptor->character_device()->ioctl(*this, request, arg); @@ -2347,10 +2354,23 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_ return -ENOTSOCK; auto& socket = *descriptor->socket(); int error; - auto server = socket.connect(address, address_size, error); - if (!server) + if (!socket.connect(address, address_size, error)) return error; - auto server_descriptor = FileDescriptor::create(move(server), SocketRole::Connected); - m_fds[fd].set(move(server_descriptor)); - return fd; + descriptor->set_socket_role(SocketRole::Connected); + return 0; +} + +bool Process::wait_for_connect(Socket& socket, int& error) +{ + if (socket.is_connected()) + return true; + m_blocked_connecting_socket = socket; + block(BlockedConnect); + Scheduler::yield(); + m_blocked_connecting_socket = nullptr; + if (!socket.is_connected()) { + error = -ECONNREFUSED; + return false; + } + return true; } diff --git a/Kernel/Process.h b/Kernel/Process.h index 4ea660fd53..3894c78869 100644 --- a/Kernel/Process.h +++ b/Kernel/Process.h @@ -78,6 +78,7 @@ public: BlockedWrite, BlockedSignal, BlockedSelect, + BlockedConnect, }; enum Priority { @@ -226,6 +227,8 @@ public: DisplayInfo set_video_resolution(int width, int height); + bool wait_for_connect(Socket&, int& error); + static void initialize(); void crash() NORETURN; @@ -356,6 +359,7 @@ private: SignalActionData m_signal_action_data[32]; dword m_pending_signals { 0 }; dword m_signal_mask { 0xffffffff }; + RetainPtr m_blocked_connecting_socket; byte m_termination_status { 0 }; byte m_termination_signal { 0 }; @@ -456,6 +460,7 @@ static inline const char* to_string(Process::State state) case Process::BlockedSignal: return "Signal"; case Process::BlockedSelect: return "Select"; case Process::BlockedLurking: return "Lurking"; + case Process::BlockedConnect: return "Connect"; case Process::BeingInspected: return "Inspect"; } ASSERT_NOT_REACHED(); diff --git a/Kernel/Scheduler.cpp b/Kernel/Scheduler.cpp index 22b6deddd5..07e199cd8f 100644 --- a/Kernel/Scheduler.cpp +++ b/Kernel/Scheduler.cpp @@ -91,6 +91,13 @@ bool Scheduler::pick_next() return true; } + if (process.state() == Process::BlockedConnect) { + ASSERT(process.m_blocked_connecting_socket); + if (process.m_blocked_connecting_socket->is_connected()) + process.unblock(); + return true; + } + if (process.state() == Process::BlockedSelect) { if (process.wakeup_requested()) { process.m_wakeup_requested = false; diff --git a/Kernel/Socket.cpp b/Kernel/Socket.cpp index 18fb351779..129b44fdeb 100644 --- a/Kernel/Socket.cpp +++ b/Kernel/Socket.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include RetainPtr Socket::create(int domain, int type, int protocol, int& error) @@ -20,6 +21,7 @@ Socket::Socket(int domain, int type, int protocol) , m_type(type) , m_protocol(protocol) { + m_origin_pid = current->pid(); } Socket::~Socket() @@ -45,6 +47,19 @@ RetainPtr Socket::accept() if (m_pending.is_empty()) return nullptr; auto client = m_pending.take_first(); + ASSERT(!client->is_connected()); + client->m_connected = true; m_clients.append(client.copy_ref()); return client; } + +bool Socket::queue_connection_from(Socket& peer, int& error) +{ + LOCKER(m_lock); + if (m_pending.size() >= m_backlog) { + error = -ECONNREFUSED; + return false; + } + m_pending.append(peer); + return true; +} diff --git a/Kernel/Socket.h b/Kernel/Socket.h index 56e91a64b5..f042661c91 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -19,13 +19,13 @@ public: int type() const { return m_type; } int protocol() const { return m_protocol; } - bool can_accept() const { return m_pending.is_empty(); } + bool can_accept() const { return !m_pending.is_empty(); } RetainPtr accept(); - + bool is_connected() const { return m_connected; } bool listen(int backlog, int& error); virtual bool bind(const sockaddr*, socklen_t, int& error) = 0; - virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) = 0; + virtual bool connect(const sockaddr*, socklen_t, int& error) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0; virtual bool is_local() const { return false; } @@ -34,16 +34,22 @@ public: virtual ssize_t write(SocketRole, const byte*, size_t) = 0; virtual bool can_write(SocketRole) const = 0; + pid_t origin_pid() const { return m_origin_pid; } + protected: Socket(int domain, int type, int protocol); + bool queue_connection_from(Socket&, int& error); + private: Lock m_lock; + pid_t m_origin_pid { 0 }; int m_domain { 0 }; int m_type { 0 }; int m_protocol { 0 }; int m_backlog { 0 }; bool m_listening { false }; + bool m_connected { false }; Vector> m_pending; Vector> m_clients; diff --git a/LibGUI/GEventLoop.cpp b/LibGUI/GEventLoop.cpp index f7dfe73555..9ff3490ac2 100644 --- a/LibGUI/GEventLoop.cpp +++ b/LibGUI/GEventLoop.cpp @@ -11,8 +11,10 @@ #include #include #include -#include -#include +#include +#include +#include +#include //#define GEVENTLOOP_DEBUG @@ -28,10 +30,20 @@ GEventLoop::GEventLoop() if (!s_mainGEventLoop) s_mainGEventLoop = this; - m_event_fd = open("/dev/gui_events", O_RDWR | O_NONBLOCK | O_CLOEXEC); + m_event_fd = socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); if (m_event_fd < 0) { - perror("GEventLoop(): open"); - exit(1); + perror("socket"); + ASSERT_NOT_REACHED(); + } + + sockaddr_un address; + address.sun_family = AF_LOCAL; + strcpy(address.sun_path, "/wsportal"); + int rc = connect(m_event_fd, (const sockaddr*)&address, sizeof(address)); + if (rc < 0) { + dbgprintf("connect failed: %d, %s\n", errno, strerror(errno)); + perror("connect"); + ASSERT_NOT_REACHED(); } } diff --git a/WindowServer/WSClientConnection.cpp b/WindowServer/WSClientConnection.cpp index 2fa49bffbf..5a47544408 100644 --- a/WindowServer/WSClientConnection.cpp +++ b/WindowServer/WSClientConnection.cpp @@ -7,9 +7,19 @@ #include #include #include +#include Lockable>* s_connections; +void WSClientConnection::for_each_client(Function callback) +{ + if (!s_connections) + return; + LOCKER(s_connections->lock()); + for (auto& it : s_connections->resource()) { + callback(*it.value); + } +} WSClientConnection* WSClientConnection::from_client_id(int client_id) { @@ -29,15 +39,25 @@ WSClientConnection* WSClientConnection::ensure_for_client_id(int client_id) return new WSClientConnection(client_id); } -WSClientConnection::WSClientConnection(int client_id) - : m_client_id(client_id) +WSClientConnection::WSClientConnection(int fd) + : m_fd(fd) { + pid_t pid; + int rc = WSMessageLoop::the().server_process().sys$ioctl(m_fd, 413, (int)&pid); + ASSERT(rc == 0); + + { + InterruptDisabler disabler; + auto* process = Process::from_pid(pid); + ASSERT(process); + m_process = process->make_weak_ptr(); + m_client_id = (int)process; + } + if (!s_connections) s_connections = new Lockable>; LOCKER(s_connections->lock()); - s_connections->resource().set(client_id, this); - - m_process = ((Process*)m_client_id)->make_weak_ptr(); + s_connections->resource().set(m_client_id, this); } WSClientConnection::~WSClientConnection() @@ -57,12 +77,10 @@ void WSClientConnection::post_error(const String& error_message) WSMessageLoop::the().post_message_to_client(m_client_id, message); } -void WSClientConnection::post_message(GUI_ServerMessage&& message) +void WSClientConnection::post_message(const GUI_ServerMessage& message) { - if (!m_process) - return; - LOCKER(m_process->gui_events_lock()); - m_process->gui_events().append(move(message)); + int nwritten = WSMessageLoop::the().server_process().sys$write(m_fd, &message, sizeof(message)); + ASSERT(nwritten == sizeof(message)); } RetainPtr WSClientConnection::create_bitmap(const Size& size) diff --git a/WindowServer/WSClientConnection.h b/WindowServer/WSClientConnection.h index e40e6b2090..34e4f884e4 100644 --- a/WindowServer/WSClientConnection.h +++ b/WindowServer/WSClientConnection.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -17,18 +18,21 @@ class Process; class WSClientConnection final : public WSMessageReceiver { public: - explicit WSClientConnection(int client_id); + explicit WSClientConnection(int fd); virtual ~WSClientConnection() override; static WSClientConnection* from_client_id(int client_id); static WSClientConnection* ensure_for_client_id(int client_id); + static void for_each_client(Function); - void post_message(GUI_ServerMessage&&); + void post_message(const GUI_ServerMessage&); RetainPtr create_bitmap(const Size&); int client_id() const { return m_client_id; } WSMenuBar* app_menubar() { return m_app_menubar.ptr(); } + int fd() const { return m_fd; } + private: virtual void on_message(WSMessage&) override; @@ -56,6 +60,7 @@ private: void post_error(const String&); int m_client_id { 0 }; + int m_fd { -1 }; HashMap> m_windows; HashMap> m_menubars; diff --git a/WindowServer/WSMenu.cpp b/WindowServer/WSMenu.cpp index dfe841e5a3..4cef867ae0 100644 --- a/WindowServer/WSMenu.cpp +++ b/WindowServer/WSMenu.cpp @@ -144,7 +144,7 @@ void WSMenu::did_activate(WSMenuItem& item) message.menu.identifier = item.identifier(); if (auto* client = WSClientConnection::from_client_id(m_client_id)) - client->post_message(move(message)); + client->post_message(message); } WSMenuItem* WSMenu::item_at(const Point& position) diff --git a/WindowServer/WSMessageLoop.cpp b/WindowServer/WSMessageLoop.cpp index 1898c55d8d..c554220186 100644 --- a/WindowServer/WSMessageLoop.cpp +++ b/WindowServer/WSMessageLoop.cpp @@ -37,6 +37,18 @@ int WSMessageLoop::exec() m_keyboard_fd = m_server_process->sys$open("/dev/keyboard", O_RDONLY); m_mouse_fd = m_server_process->sys$open("/dev/psaux", O_RDONLY); + m_server_process->sys$unlink("/wsportal"); + + m_server_fd = m_server_process->sys$socket(AF_LOCAL, SOCK_STREAM, 0); + ASSERT(m_server_fd >= 0); + sockaddr_un address; + address.sun_family = AF_LOCAL; + strcpy(address.sun_path, "/wsportal"); + int rc = m_server_process->sys$bind(m_server_fd, (const sockaddr*)&address, sizeof(address)); + ASSERT(rc == 0); + rc = m_server_process->sys$listen(m_server_fd, 5); + ASSERT(rc == 0); + ASSERT(m_keyboard_fd >= 0); ASSERT(m_mouse_fd >= 0); @@ -76,11 +88,10 @@ Process* WSMessageLoop::process_from_client_id(int client_id) void WSMessageLoop::post_message_to_client(int client_id, const GUI_ServerMessage& message) { - auto* process = process_from_client_id(client_id); - if (!process) + auto* client = WSClientConnection::from_client_id(client_id); + if (!client) return; - LOCKER(process->gui_events_lock()); - process->gui_events().append(move(message)); + client->post_message(message); } void WSMessageLoop::post_message(WSMessageReceiver* receiver, OwnPtr&& message) @@ -164,10 +175,23 @@ void WSMessageLoop::wait_for_message() fd_set rfds; memset(&rfds, 0, sizeof(rfds)); auto bitmap = Bitmap::wrap((byte*)&rfds, FD_SETSIZE); - bitmap.set(m_keyboard_fd, true); - bitmap.set(m_mouse_fd, true); + int max_fd = 0; + auto add_fd_to_set = [&max_fd] (int fd, auto& bitmap) { + bitmap.set(fd, true); + if (fd > max_fd) + max_fd = fd; + }; + + add_fd_to_set(m_keyboard_fd, bitmap); + add_fd_to_set(m_mouse_fd, bitmap); + add_fd_to_set(m_server_fd, bitmap); + + WSClientConnection::for_each_client([&] (WSClientConnection& client) { + add_fd_to_set(client.fd(), bitmap); + }); + Syscall::SC_select_params params; - params.nfds = max(m_keyboard_fd, m_mouse_fd) + 1; + params.nfds = max_fd + 1; params.readfds = &rfds; params.writefds = nullptr; params.exceptfds = nullptr; @@ -210,6 +234,20 @@ void WSMessageLoop::wait_for_message() drain_keyboard(); if (bitmap.get(m_mouse_fd)) drain_mouse(); + if (bitmap.get(m_server_fd)) { + sockaddr_un address; + socklen_t address_size = sizeof(address); + int client_fd = m_server_process->sys$accept(m_server_fd, (sockaddr*)&address, &address_size); + kprintf("accept() returned fd=%d, address=%s\n", client_fd, address.sun_path); + new WSClientConnection(client_fd); + } + WSClientConnection::for_each_client([&] (WSClientConnection& client) { + if (bitmap.get(client.fd())) { + byte buffer[4096]; + ssize_t nread = m_server_process->sys$read(client.fd(), buffer, sizeof(GUI_ClientMessage)); + on_receive_from_client(client.client_id(), buffer, nread); + } + }); } void WSMessageLoop::drain_mouse() diff --git a/WindowServer/WSMessageLoop.h b/WindowServer/WSMessageLoop.h index b79cf853bf..c0bd7abf83 100644 --- a/WindowServer/WSMessageLoop.h +++ b/WindowServer/WSMessageLoop.h @@ -54,6 +54,7 @@ private: int m_keyboard_fd { -1 }; int m_mouse_fd { -1 }; + int m_server_fd { -1 }; struct Timer { void reload(); diff --git a/WindowServer/WSWindow.cpp b/WindowServer/WSWindow.cpp index 9ca715e391..58a13d1ad3 100644 --- a/WindowServer/WSWindow.cpp +++ b/WindowServer/WSWindow.cpp @@ -126,7 +126,7 @@ void WSWindow::on_message(WSMessage& message) return; if (auto* client = WSClientConnection::from_client_id(m_client_id)) - client->post_message(move(server_message)); + client->post_message(server_message); } void WSWindow::set_global_cursor_tracking_enabled(bool enabled)