diff --git a/Kernel/FileDescriptor.h b/Kernel/FileDescriptor.h index 0dc15af07d..88066ba7fc 100644 --- a/Kernel/FileDescriptor.h +++ b/Kernel/FileDescriptor.h @@ -78,6 +78,8 @@ public: void set_original_inode(Badge, RetainPtr&& inode) { m_inode = move(inode); } + void set_socket_role(SocketRole role) { m_socket_role = role; } + private: friend class VFS; FileDescriptor(RetainPtr&&, SocketRole); diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index 6a0eabc435..3b7f292912 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -31,7 +31,7 @@ bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size) bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error) { - ASSERT(!m_connected); + ASSERT(!is_connected()); if (address_size != sizeof(sockaddr_un)) { error = -EINVAL; return false; @@ -45,7 +45,7 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err char safe_address[sizeof(local_address.sun_path) + 1]; memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path)); - kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address); + kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), this, safe_address); m_file = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode()); if (!m_file) { @@ -62,37 +62,48 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err return true; } -RetainPtr LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error) +bool LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error) { ASSERT(!m_bound); if (address_size != sizeof(sockaddr_un)) { error = -EINVAL; - return nullptr; + return false; } if (address->sa_family != AF_LOCAL) { error = -EINVAL; - return nullptr; + return false; } const sockaddr_un& local_address = *reinterpret_cast(address); char safe_address[sizeof(local_address.sun_path) + 1]; memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path)); - kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), safe_address); + kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), this, safe_address); m_file = VFS::the().open(safe_address, error, 0, 0, *current->cwd_inode()); if (!m_file) { error = -ECONNREFUSED; - return nullptr; + return false; + } + ASSERT(m_file->inode()); + if (!m_file->inode()->socket()) { + error = -ECONNREFUSED; + return false; } - ASSERT(m_file->inode()); - ASSERT(m_file->inode()->socket()); - - m_peer = m_file->inode()->socket(); m_address = local_address; - m_connected = true; - return m_peer; + + auto peer = m_file->inode()->socket(); + kprintf("Queueing up connection\n"); + if (!peer->queue_connection_from(*this, error)) + return false; + + kprintf("Waiting for connect...\n"); + if (!current->wait_for_connect(*this, error)) + return false; + + kprintf("CONNECTED!\n"); + return true; } bool LocalSocket::can_read(SocketRole role) const @@ -125,7 +136,7 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size) bool LocalSocket::can_write(SocketRole role) const { if (role == SocketRole::Accepted) - return !m_for_client.is_empty(); + return m_for_client.bytes_in_write_buffer() < 4096; else - return !m_for_server.is_empty(); + return m_for_server.bytes_in_write_buffer() < 4096; } diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index 76e58c92ad..f4581a67a0 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -11,7 +11,7 @@ public: virtual ~LocalSocket() override; virtual bool bind(const sockaddr*, socklen_t, int& error) override; - virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) override; + virtual bool connect(const sockaddr*, socklen_t, int& error) override; virtual bool get_address(sockaddr*, socklen_t*) override; virtual bool can_read(SocketRole) const override; @@ -27,7 +27,6 @@ private: RetainPtr m_peer; bool m_bound { false }; - bool m_connected { false }; sockaddr_un m_address; DoubleBuffer m_for_client; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index 234a92de27..39ce65120a 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -1784,6 +1784,13 @@ int Process::sys$ioctl(int fd, unsigned request, unsigned arg) auto* descriptor = file_descriptor(fd); if (!descriptor) return -EBADF; + if (descriptor->is_socket() && request == 413) { + auto* pid = (pid_t*)arg; + if (!validate_write_typed(pid)) + return -EFAULT; + *pid = descriptor->socket()->origin_pid(); + return 0; + } if (!descriptor->is_character_device()) return -ENOTTY; return descriptor->character_device()->ioctl(*this, request, arg); @@ -2347,10 +2354,23 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_ return -ENOTSOCK; auto& socket = *descriptor->socket(); int error; - auto server = socket.connect(address, address_size, error); - if (!server) + if (!socket.connect(address, address_size, error)) return error; - auto server_descriptor = FileDescriptor::create(move(server), SocketRole::Connected); - m_fds[fd].set(move(server_descriptor)); - return fd; + descriptor->set_socket_role(SocketRole::Connected); + return 0; +} + +bool Process::wait_for_connect(Socket& socket, int& error) +{ + if (socket.is_connected()) + return true; + m_blocked_connecting_socket = socket; + block(BlockedConnect); + Scheduler::yield(); + m_blocked_connecting_socket = nullptr; + if (!socket.is_connected()) { + error = -ECONNREFUSED; + return false; + } + return true; } diff --git a/Kernel/Process.h b/Kernel/Process.h index 4ea660fd53..3894c78869 100644 --- a/Kernel/Process.h +++ b/Kernel/Process.h @@ -78,6 +78,7 @@ public: BlockedWrite, BlockedSignal, BlockedSelect, + BlockedConnect, }; enum Priority { @@ -226,6 +227,8 @@ public: DisplayInfo set_video_resolution(int width, int height); + bool wait_for_connect(Socket&, int& error); + static void initialize(); void crash() NORETURN; @@ -356,6 +359,7 @@ private: SignalActionData m_signal_action_data[32]; dword m_pending_signals { 0 }; dword m_signal_mask { 0xffffffff }; + RetainPtr m_blocked_connecting_socket; byte m_termination_status { 0 }; byte m_termination_signal { 0 }; @@ -456,6 +460,7 @@ static inline const char* to_string(Process::State state) case Process::BlockedSignal: return "Signal"; case Process::BlockedSelect: return "Select"; case Process::BlockedLurking: return "Lurking"; + case Process::BlockedConnect: return "Connect"; case Process::BeingInspected: return "Inspect"; } ASSERT_NOT_REACHED(); diff --git a/Kernel/Scheduler.cpp b/Kernel/Scheduler.cpp index 22b6deddd5..07e199cd8f 100644 --- a/Kernel/Scheduler.cpp +++ b/Kernel/Scheduler.cpp @@ -91,6 +91,13 @@ bool Scheduler::pick_next() return true; } + if (process.state() == Process::BlockedConnect) { + ASSERT(process.m_blocked_connecting_socket); + if (process.m_blocked_connecting_socket->is_connected()) + process.unblock(); + return true; + } + if (process.state() == Process::BlockedSelect) { if (process.wakeup_requested()) { process.m_wakeup_requested = false; diff --git a/Kernel/Socket.cpp b/Kernel/Socket.cpp index 18fb351779..129b44fdeb 100644 --- a/Kernel/Socket.cpp +++ b/Kernel/Socket.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include RetainPtr Socket::create(int domain, int type, int protocol, int& error) @@ -20,6 +21,7 @@ Socket::Socket(int domain, int type, int protocol) , m_type(type) , m_protocol(protocol) { + m_origin_pid = current->pid(); } Socket::~Socket() @@ -45,6 +47,19 @@ RetainPtr Socket::accept() if (m_pending.is_empty()) return nullptr; auto client = m_pending.take_first(); + ASSERT(!client->is_connected()); + client->m_connected = true; m_clients.append(client.copy_ref()); return client; } + +bool Socket::queue_connection_from(Socket& peer, int& error) +{ + LOCKER(m_lock); + if (m_pending.size() >= m_backlog) { + error = -ECONNREFUSED; + return false; + } + m_pending.append(peer); + return true; +} diff --git a/Kernel/Socket.h b/Kernel/Socket.h index 56e91a64b5..f042661c91 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -19,13 +19,13 @@ public: int type() const { return m_type; } int protocol() const { return m_protocol; } - bool can_accept() const { return m_pending.is_empty(); } + bool can_accept() const { return !m_pending.is_empty(); } RetainPtr accept(); - + bool is_connected() const { return m_connected; } bool listen(int backlog, int& error); virtual bool bind(const sockaddr*, socklen_t, int& error) = 0; - virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) = 0; + virtual bool connect(const sockaddr*, socklen_t, int& error) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0; virtual bool is_local() const { return false; } @@ -34,16 +34,22 @@ public: virtual ssize_t write(SocketRole, const byte*, size_t) = 0; virtual bool can_write(SocketRole) const = 0; + pid_t origin_pid() const { return m_origin_pid; } + protected: Socket(int domain, int type, int protocol); + bool queue_connection_from(Socket&, int& error); + private: Lock m_lock; + pid_t m_origin_pid { 0 }; int m_domain { 0 }; int m_type { 0 }; int m_protocol { 0 }; int m_backlog { 0 }; bool m_listening { false }; + bool m_connected { false }; Vector> m_pending; Vector> m_clients; diff --git a/LibGUI/GEventLoop.cpp b/LibGUI/GEventLoop.cpp index f7dfe73555..9ff3490ac2 100644 --- a/LibGUI/GEventLoop.cpp +++ b/LibGUI/GEventLoop.cpp @@ -11,8 +11,10 @@ #include #include #include -#include -#include +#include +#include +#include +#include //#define GEVENTLOOP_DEBUG @@ -28,10 +30,20 @@ GEventLoop::GEventLoop() if (!s_mainGEventLoop) s_mainGEventLoop = this; - m_event_fd = open("/dev/gui_events", O_RDWR | O_NONBLOCK | O_CLOEXEC); + m_event_fd = socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); if (m_event_fd < 0) { - perror("GEventLoop(): open"); - exit(1); + perror("socket"); + ASSERT_NOT_REACHED(); + } + + sockaddr_un address; + address.sun_family = AF_LOCAL; + strcpy(address.sun_path, "/wsportal"); + int rc = connect(m_event_fd, (const sockaddr*)&address, sizeof(address)); + if (rc < 0) { + dbgprintf("connect failed: %d, %s\n", errno, strerror(errno)); + perror("connect"); + ASSERT_NOT_REACHED(); } } diff --git a/WindowServer/WSClientConnection.cpp b/WindowServer/WSClientConnection.cpp index 2fa49bffbf..5a47544408 100644 --- a/WindowServer/WSClientConnection.cpp +++ b/WindowServer/WSClientConnection.cpp @@ -7,9 +7,19 @@ #include #include #include +#include Lockable>* s_connections; +void WSClientConnection::for_each_client(Function callback) +{ + if (!s_connections) + return; + LOCKER(s_connections->lock()); + for (auto& it : s_connections->resource()) { + callback(*it.value); + } +} WSClientConnection* WSClientConnection::from_client_id(int client_id) { @@ -29,15 +39,25 @@ WSClientConnection* WSClientConnection::ensure_for_client_id(int client_id) return new WSClientConnection(client_id); } -WSClientConnection::WSClientConnection(int client_id) - : m_client_id(client_id) +WSClientConnection::WSClientConnection(int fd) + : m_fd(fd) { + pid_t pid; + int rc = WSMessageLoop::the().server_process().sys$ioctl(m_fd, 413, (int)&pid); + ASSERT(rc == 0); + + { + InterruptDisabler disabler; + auto* process = Process::from_pid(pid); + ASSERT(process); + m_process = process->make_weak_ptr(); + m_client_id = (int)process; + } + if (!s_connections) s_connections = new Lockable>; LOCKER(s_connections->lock()); - s_connections->resource().set(client_id, this); - - m_process = ((Process*)m_client_id)->make_weak_ptr(); + s_connections->resource().set(m_client_id, this); } WSClientConnection::~WSClientConnection() @@ -57,12 +77,10 @@ void WSClientConnection::post_error(const String& error_message) WSMessageLoop::the().post_message_to_client(m_client_id, message); } -void WSClientConnection::post_message(GUI_ServerMessage&& message) +void WSClientConnection::post_message(const GUI_ServerMessage& message) { - if (!m_process) - return; - LOCKER(m_process->gui_events_lock()); - m_process->gui_events().append(move(message)); + int nwritten = WSMessageLoop::the().server_process().sys$write(m_fd, &message, sizeof(message)); + ASSERT(nwritten == sizeof(message)); } RetainPtr WSClientConnection::create_bitmap(const Size& size) diff --git a/WindowServer/WSClientConnection.h b/WindowServer/WSClientConnection.h index e40e6b2090..34e4f884e4 100644 --- a/WindowServer/WSClientConnection.h +++ b/WindowServer/WSClientConnection.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -17,18 +18,21 @@ class Process; class WSClientConnection final : public WSMessageReceiver { public: - explicit WSClientConnection(int client_id); + explicit WSClientConnection(int fd); virtual ~WSClientConnection() override; static WSClientConnection* from_client_id(int client_id); static WSClientConnection* ensure_for_client_id(int client_id); + static void for_each_client(Function); - void post_message(GUI_ServerMessage&&); + void post_message(const GUI_ServerMessage&); RetainPtr create_bitmap(const Size&); int client_id() const { return m_client_id; } WSMenuBar* app_menubar() { return m_app_menubar.ptr(); } + int fd() const { return m_fd; } + private: virtual void on_message(WSMessage&) override; @@ -56,6 +60,7 @@ private: void post_error(const String&); int m_client_id { 0 }; + int m_fd { -1 }; HashMap> m_windows; HashMap> m_menubars; diff --git a/WindowServer/WSMenu.cpp b/WindowServer/WSMenu.cpp index dfe841e5a3..4cef867ae0 100644 --- a/WindowServer/WSMenu.cpp +++ b/WindowServer/WSMenu.cpp @@ -144,7 +144,7 @@ void WSMenu::did_activate(WSMenuItem& item) message.menu.identifier = item.identifier(); if (auto* client = WSClientConnection::from_client_id(m_client_id)) - client->post_message(move(message)); + client->post_message(message); } WSMenuItem* WSMenu::item_at(const Point& position) diff --git a/WindowServer/WSMessageLoop.cpp b/WindowServer/WSMessageLoop.cpp index 1898c55d8d..c554220186 100644 --- a/WindowServer/WSMessageLoop.cpp +++ b/WindowServer/WSMessageLoop.cpp @@ -37,6 +37,18 @@ int WSMessageLoop::exec() m_keyboard_fd = m_server_process->sys$open("/dev/keyboard", O_RDONLY); m_mouse_fd = m_server_process->sys$open("/dev/psaux", O_RDONLY); + m_server_process->sys$unlink("/wsportal"); + + m_server_fd = m_server_process->sys$socket(AF_LOCAL, SOCK_STREAM, 0); + ASSERT(m_server_fd >= 0); + sockaddr_un address; + address.sun_family = AF_LOCAL; + strcpy(address.sun_path, "/wsportal"); + int rc = m_server_process->sys$bind(m_server_fd, (const sockaddr*)&address, sizeof(address)); + ASSERT(rc == 0); + rc = m_server_process->sys$listen(m_server_fd, 5); + ASSERT(rc == 0); + ASSERT(m_keyboard_fd >= 0); ASSERT(m_mouse_fd >= 0); @@ -76,11 +88,10 @@ Process* WSMessageLoop::process_from_client_id(int client_id) void WSMessageLoop::post_message_to_client(int client_id, const GUI_ServerMessage& message) { - auto* process = process_from_client_id(client_id); - if (!process) + auto* client = WSClientConnection::from_client_id(client_id); + if (!client) return; - LOCKER(process->gui_events_lock()); - process->gui_events().append(move(message)); + client->post_message(message); } void WSMessageLoop::post_message(WSMessageReceiver* receiver, OwnPtr&& message) @@ -164,10 +175,23 @@ void WSMessageLoop::wait_for_message() fd_set rfds; memset(&rfds, 0, sizeof(rfds)); auto bitmap = Bitmap::wrap((byte*)&rfds, FD_SETSIZE); - bitmap.set(m_keyboard_fd, true); - bitmap.set(m_mouse_fd, true); + int max_fd = 0; + auto add_fd_to_set = [&max_fd] (int fd, auto& bitmap) { + bitmap.set(fd, true); + if (fd > max_fd) + max_fd = fd; + }; + + add_fd_to_set(m_keyboard_fd, bitmap); + add_fd_to_set(m_mouse_fd, bitmap); + add_fd_to_set(m_server_fd, bitmap); + + WSClientConnection::for_each_client([&] (WSClientConnection& client) { + add_fd_to_set(client.fd(), bitmap); + }); + Syscall::SC_select_params params; - params.nfds = max(m_keyboard_fd, m_mouse_fd) + 1; + params.nfds = max_fd + 1; params.readfds = &rfds; params.writefds = nullptr; params.exceptfds = nullptr; @@ -210,6 +234,20 @@ void WSMessageLoop::wait_for_message() drain_keyboard(); if (bitmap.get(m_mouse_fd)) drain_mouse(); + if (bitmap.get(m_server_fd)) { + sockaddr_un address; + socklen_t address_size = sizeof(address); + int client_fd = m_server_process->sys$accept(m_server_fd, (sockaddr*)&address, &address_size); + kprintf("accept() returned fd=%d, address=%s\n", client_fd, address.sun_path); + new WSClientConnection(client_fd); + } + WSClientConnection::for_each_client([&] (WSClientConnection& client) { + if (bitmap.get(client.fd())) { + byte buffer[4096]; + ssize_t nread = m_server_process->sys$read(client.fd(), buffer, sizeof(GUI_ClientMessage)); + on_receive_from_client(client.client_id(), buffer, nread); + } + }); } void WSMessageLoop::drain_mouse() diff --git a/WindowServer/WSMessageLoop.h b/WindowServer/WSMessageLoop.h index b79cf853bf..c0bd7abf83 100644 --- a/WindowServer/WSMessageLoop.h +++ b/WindowServer/WSMessageLoop.h @@ -54,6 +54,7 @@ private: int m_keyboard_fd { -1 }; int m_mouse_fd { -1 }; + int m_server_fd { -1 }; struct Timer { void reload(); diff --git a/WindowServer/WSWindow.cpp b/WindowServer/WSWindow.cpp index 9ca715e391..58a13d1ad3 100644 --- a/WindowServer/WSWindow.cpp +++ b/WindowServer/WSWindow.cpp @@ -126,7 +126,7 @@ void WSWindow::on_message(WSMessage& message) return; if (auto* client = WSClientConnection::from_client_id(m_client_id)) - client->post_message(move(server_message)); + client->post_message(server_message); } void WSWindow::set_global_cursor_tracking_enabled(bool enabled)