From cb5f83606ae6666d698fd111224c09a37e7ebdf5 Mon Sep 17 00:00:00 2001 From: Sam Atkins Date: Fri, 2 Dec 2022 15:11:26 +0000 Subject: [PATCH] LibCore: Optionally pass MSG_NOSIGNAL to socket read/writes When creating a `Core::Stream::Socket`, you can now choose to prevent SIGPIPE signals from firing and terminating your process. This is done by passing MSG_NOSIGNAL to the `System::recv()` or `System::send()` calls when you `read()` or `write()` to that Socket. --- Userland/Libraries/LibCore/Stream.cpp | 12 +++--- Userland/Libraries/LibCore/Stream.h | 56 ++++++++++++++++++++------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/Userland/Libraries/LibCore/Stream.cpp b/Userland/Libraries/LibCore/Stream.cpp index b70c9f1518..e023af0531 100644 --- a/Userland/Libraries/LibCore/Stream.cpp +++ b/Userland/Libraries/LibCore/Stream.cpp @@ -435,13 +435,13 @@ ErrorOr PosixSocketHelper::read(Bytes buffer, int flags) return buffer.trim(nread); } -ErrorOr PosixSocketHelper::write(ReadonlyBytes buffer) +ErrorOr PosixSocketHelper::write(ReadonlyBytes buffer, int flags) { if (!is_open()) { return Error::from_errno(ENOTCONN); } - return TRY(System::send(m_fd, buffer.data(), buffer.size(), 0)); + return TRY(System::send(m_fd, buffer.data(), buffer.size(), flags)); } void PosixSocketHelper::close() @@ -574,9 +574,9 @@ ErrorOr> UDPSocket::connect(SocketAddress const& addres return socket; } -ErrorOr> LocalSocket::connect(String const& path) +ErrorOr> LocalSocket::connect(String const& path, PreventSIGPIPE prevent_sigpipe) { - auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket())); + auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket(prevent_sigpipe))); auto fd = TRY(create_fd(SocketDomain::Local, SocketType::Stream)); socket->m_helper.set_fd(fd); @@ -587,13 +587,13 @@ ErrorOr> LocalSocket::connect(String const& path) return socket; } -ErrorOr> LocalSocket::adopt_fd(int fd) +ErrorOr> LocalSocket::adopt_fd(int fd, PreventSIGPIPE prevent_sigpipe) { if (fd < 0) { return Error::from_errno(EBADF); } - auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket())); + auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket(prevent_sigpipe))); socket->m_helper.set_fd(fd); socket->setup_notifier(); return socket; diff --git a/Userland/Libraries/LibCore/Stream.h b/Userland/Libraries/LibCore/Stream.h index 2a28bcae75..b2f05c5f39 100644 --- a/Userland/Libraries/LibCore/Stream.h +++ b/Userland/Libraries/LibCore/Stream.h @@ -106,6 +106,11 @@ public: virtual ErrorOr discard(size_t discarded_bytes) override; }; +enum class PreventSIGPIPE { + No, + Yes, +}; + /// The Socket class is the base class for all concrete BSD-style socket /// classes. Sockets are non-seekable streams which can be read byte-wise. class Socket : public Stream { @@ -149,7 +154,8 @@ protected: Datagram, }; - Socket() + Socket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No) + : m_prevent_sigpipe(prevent_sigpipe == PreventSIGPIPE::Yes) { } @@ -160,6 +166,17 @@ protected: static ErrorOr connect_local(int fd, String const& path); static ErrorOr connect_inet(int fd, SocketAddress const&); + + int default_flags() const + { + int flags = 0; + if (m_prevent_sigpipe) + flags |= MSG_NOSIGNAL; + return flags; + } + +private: + bool m_prevent_sigpipe { false }; }; /// A reusable socket maintains state about being connected in addition to @@ -262,7 +279,9 @@ class PosixSocketHelper { public: template - PosixSocketHelper(Badge) requires(IsBaseOf) { } + PosixSocketHelper(Badge) requires(IsBaseOf) + { + } PosixSocketHelper(PosixSocketHelper&& other) { @@ -280,8 +299,8 @@ public: int fd() const { return m_fd; } void set_fd(int fd) { m_fd = fd; } - ErrorOr read(Bytes, int flags = 0); - ErrorOr write(ReadonlyBytes); + ErrorOr read(Bytes, int flags); + ErrorOr write(ReadonlyBytes, int flags); bool is_eof() const { return !is_open() || m_last_read_was_eof; } bool is_open() const { return m_fd != -1; } @@ -329,8 +348,8 @@ public: virtual bool is_readable() const override { return is_open(); } virtual bool is_writable() const override { return is_open(); } - virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer); } - virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer, default_flags()); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); } virtual bool is_eof() const override { return m_helper.is_eof(); } virtual bool is_open() const override { return m_helper.is_open(); }; virtual void close() override { m_helper.close(); }; @@ -347,7 +366,8 @@ public: virtual ~TCPSocket() override { close(); } private: - TCPSocket() + TCPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No) + : Socket(prevent_sigpipe) { } @@ -400,12 +420,12 @@ public: return Error::from_errno(EMSGSIZE); } - return m_helper.read(buffer); + return m_helper.read(buffer, default_flags()); } virtual bool is_readable() const override { return is_open(); } virtual bool is_writable() const override { return is_open(); } - virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); } virtual bool is_eof() const override { return m_helper.is_eof(); } virtual bool is_open() const override { return m_helper.is_open(); } virtual void close() override { m_helper.close(); } @@ -422,7 +442,10 @@ public: virtual ~UDPSocket() override { close(); } private: - UDPSocket() = default; + UDPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No) + : Socket(prevent_sigpipe) + { + } void setup_notifier() { @@ -440,8 +463,8 @@ private: class LocalSocket final : public Socket { public: - static ErrorOr> connect(String const& path); - static ErrorOr> adopt_fd(int fd); + static ErrorOr> connect(String const& path, PreventSIGPIPE = PreventSIGPIPE::No); + static ErrorOr> adopt_fd(int fd, PreventSIGPIPE = PreventSIGPIPE::No); LocalSocket(LocalSocket&& other) : Socket(static_cast(other)) @@ -463,8 +486,8 @@ public: virtual bool is_readable() const override { return is_open(); } virtual bool is_writable() const override { return is_open(); } - virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer); } - virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer, default_flags()); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); } virtual bool is_eof() const override { return m_helper.is_eof(); } virtual bool is_open() const override { return m_helper.is_open(); } virtual void close() override { m_helper.close(); } @@ -495,7 +518,10 @@ public: virtual ~LocalSocket() { close(); } private: - LocalSocket() = default; + LocalSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No) + : Socket(prevent_sigpipe) + { + } void setup_notifier() {