diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 1a1aa8750d..00acbd8eb2 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -246,7 +246,7 @@ ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons return nsent_or_error; } -ErrorOr IPv4Socket::receive_byte_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace) +ErrorOr IPv4Socket::receive_byte_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace, bool blocking) { MutexLocker locker(mutex()); @@ -255,7 +255,7 @@ ErrorOr IPv4Socket::receive_byte_buffered(OpenFileDescription& descripti if (m_receive_buffer->is_empty()) { if (protocol_is_disconnected()) return 0; - if (!description.is_blocking()) + if (!blocking) return set_so_error(EAGAIN); locker.unlock(); @@ -285,7 +285,7 @@ ErrorOr IPv4Socket::receive_byte_buffered(OpenFileDescription& descripti return nreceived_or_error; } -ErrorOr IPv4Socket::receive_packet_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace addr, Userspace addr_length, Time& packet_timestamp) +ErrorOr IPv4Socket::receive_packet_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace addr, Userspace addr_length, Time& packet_timestamp, bool blocking) { MutexLocker locker(mutex()); ReceivedPacket taken_packet; @@ -296,7 +296,7 @@ ErrorOr IPv4Socket::receive_packet_buffered(OpenFileDescription& descrip // But if so, we still need to deliver at least one EOF read to userspace.. right? if (protocol_is_disconnected()) return 0; - if (!description.is_blocking()) + if (!blocking) return set_so_error(EAGAIN); } @@ -380,7 +380,7 @@ ErrorOr IPv4Socket::receive_packet_buffered(OpenFileDescription& descrip return protocol_receive(packet->data->bytes(), buffer, buffer_length, flags); } -ErrorOr IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace user_addr, Userspace user_addr_length, Time& packet_timestamp) +ErrorOr IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace user_addr, Userspace user_addr_length, Time& packet_timestamp, bool blocking) { if (user_addr_length) { socklen_t addr_length; @@ -398,9 +398,9 @@ ErrorOr IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKer ErrorOr nreceived = 0; if (buffer_mode() == BufferMode::Bytes) - nreceived = receive_byte_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length); + nreceived = receive_byte_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, blocking); else - nreceived = receive_packet_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, packet_timestamp); + nreceived = receive_packet_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, packet_timestamp, blocking); if (nreceived.is_error()) total_nreceived = nreceived; diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index fa3a7c9e59..fb4784d3a4 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -40,7 +40,7 @@ public: virtual bool can_read(OpenFileDescription const&, u64) const override; virtual bool can_write(OpenFileDescription const&, u64) const override; virtual ErrorOr sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace, socklen_t) override; - virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&) override; + virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&, bool blocking) override; virtual ErrorOr setsockopt(int level, int option, Userspace, socklen_t) override; virtual ErrorOr getsockopt(OpenFileDescription&, int level, int option, Userspace, Userspace) override; @@ -98,8 +98,8 @@ protected: private: virtual bool is_ipv4() const override { return true; } - ErrorOr receive_byte_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace); - ErrorOr receive_packet_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace, Time&); + ErrorOr receive_byte_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace, bool blocking); + ErrorOr receive_packet_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace, Userspace, Time&, bool blocking); void set_can_read(bool); diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index d5820aad41..0ad680b654 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -334,12 +334,12 @@ DoubleBuffer* LocalSocket::send_buffer_for(OpenFileDescription& description) return nullptr; } -ErrorOr LocalSocket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_size, int, Userspace, Userspace, Time&) +ErrorOr LocalSocket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_size, int, Userspace, Userspace, Time&, bool blocking) { auto* socket_buffer = receive_buffer_for(description); if (!socket_buffer) return set_so_error(EINVAL); - if (!description.is_blocking()) { + if (!blocking) { if (socket_buffer->is_empty()) { if (!has_attached_peer(description)) return 0; diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 8174b235fd..91722d72cd 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -46,7 +46,7 @@ public: virtual bool can_read(OpenFileDescription const&, u64) const override; virtual bool can_write(OpenFileDescription const&, u64) const override; virtual ErrorOr sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace, socklen_t) override; - virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&) override; + virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&, bool blocking) override; virtual ErrorOr getsockopt(OpenFileDescription&, int level, int option, Userspace, Userspace) override; virtual ErrorOr ioctl(OpenFileDescription&, unsigned request, Userspace arg) override; virtual ErrorOr chown(Credentials const&, OpenFileDescription&, UserID, GroupID) override; diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index b38be8a6e0..2832603db3 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -242,7 +242,7 @@ ErrorOr Socket::read(OpenFileDescription& description, u64, UserOrKernel if (is_shut_down_for_reading()) return 0; Time t {}; - return recvfrom(description, buffer, size, 0, {}, 0, t); + return recvfrom(description, buffer, size, 0, {}, 0, t, description.is_blocking()); } ErrorOr Socket::write(OpenFileDescription& description, u64, UserOrKernelBuffer const& data, size_t size) diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index 4e0de6ec11..63f92aa270 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -80,7 +80,7 @@ public: virtual bool is_local() const { return false; } virtual bool is_ipv4() const { return false; } virtual ErrorOr sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int flags, Userspace, socklen_t) = 0; - virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&) = 0; + virtual ErrorOr recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace, Userspace, Time&, bool blocking) = 0; virtual ErrorOr setsockopt(int level, int option, Userspace, socklen_t); virtual ErrorOr getsockopt(OpenFileDescription&, int level, int option, Userspace, Userspace); diff --git a/Kernel/Syscalls/socket.cpp b/Kernel/Syscalls/socket.cpp index 70eded6dc4..d7e292011f 100644 --- a/Kernel/Syscalls/socket.cpp +++ b/Kernel/Syscalls/socket.cpp @@ -241,15 +241,10 @@ ErrorOr Process::sys$recvmsg(int sockfd, Userspace user if (socket.is_shut_down_for_reading()) return 0; - bool original_blocking = description->is_blocking(); - if (flags & MSG_DONTWAIT) - description->set_blocking(false); - auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len)); Time timestamp {}; - auto result = socket.recvfrom(*description, data_buffer, iovs[0].iov_len, flags, user_addr, user_addr_length, timestamp); - if (flags & MSG_DONTWAIT) - description->set_blocking(original_blocking); + bool blocking = (flags & MSG_DONTWAIT) ? false : description->is_blocking(); + auto result = socket.recvfrom(*description, data_buffer, iovs[0].iov_len, flags, user_addr, user_addr_length, timestamp, blocking); if (result.is_error()) return result.release_error();