From 308773ffda81601ef1a915f03eb4c71b7a28ec6f Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Tue, 7 Sep 2021 15:05:51 +0200 Subject: [PATCH] Kernel/Net: Add a special SOCKET_TRY() and use it in socket code Sockets remember their last error code in the SO_ERROR field, so we need to take special care to remember this when returning an error. This patch adds a SOCKET_TRY() that works like TRY() but also calls set_so_error() on the failure path. There's probably a lot more code that should be using this, but that's outside the scope of this patch. --- Kernel/Net/IPv4Socket.cpp | 30 +++++++++++------------------- Kernel/Net/LocalSocket.cpp | 26 ++++++-------------------- Kernel/Net/Socket.h | 9 +++++++++ Kernel/Net/TCPSocket.cpp | 3 +-- Kernel/Net/UDPSocket.cpp | 7 ++----- 5 files changed, 29 insertions(+), 46 deletions(-) diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 88cdf95cc3..efec7e4ae5 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -100,9 +100,8 @@ KResult IPv4Socket::bind(Userspace user_address, socklen_t addr if (address_size != sizeof(sockaddr_in)) return set_so_error(EINVAL); - sockaddr_in address; - if (copy_from_user(&address, user_address, sizeof(sockaddr_in)).is_error()) - return set_so_error(EFAULT); + sockaddr_in address {}; + SOCKET_TRY(copy_from_user(&address, user_address, sizeof(sockaddr_in))); if (address.sin_family != AF_INET) return set_so_error(EINVAL); @@ -145,16 +144,14 @@ KResult IPv4Socket::connect(OpenFileDescription& description, Userspace(address.unsafe_userspace_ptr()); - if (copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16)).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16))); if (sa_family_copy != AF_INET) return set_so_error(EINVAL); if (m_role == Role::Connected) return set_so_error(EISCONN); - sockaddr_in safe_address; - if (copy_from_user(&safe_address, (const sockaddr_in*)user_address, sizeof(sockaddr_in)).is_error()) - return set_so_error(EFAULT); + sockaddr_in safe_address {}; + SOCKET_TRY(copy_from_user(&safe_address, (sockaddr_in const*)user_address, sizeof(sockaddr_in))); m_peer_address = IPv4Address((const u8*)&safe_address.sin_addr.s_addr); if (m_peer_address == IPv4Address { 0, 0, 0, 0 }) @@ -198,9 +195,8 @@ KResultOr IPv4Socket::sendto(OpenFileDescription&, const UserOrKernelBuf return set_so_error(EINVAL); if (addr) { - sockaddr_in ia; - if (copy_from_user(&ia, Userspace(addr.ptr())).is_error()) - return set_so_error(EFAULT); + sockaddr_in ia {}; + SOCKET_TRY(copy_from_user(&ia, Userspace(addr.ptr()))); if (ia.sin_family != AF_INET) { dmesgln("sendto: Bad address family: {} is not AF_INET", ia.sin_family); @@ -358,19 +354,16 @@ KResultOr IPv4Socket::receive_packet_buffered(OpenFileDescription& descr out_addr.sin_port = htons(packet.peer_port); out_addr.sin_family = AF_INET; Userspace dest_addr = addr.ptr(); - if (copy_to_user(dest_addr, &out_addr).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_to_user(dest_addr, &out_addr)); socklen_t out_length = sizeof(sockaddr_in); VERIFY(addr_length); - if (copy_to_user(addr_length, &out_length).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_to_user(addr_length, &out_length)); } if (type() == SOCK_RAW) { size_t bytes_written = min(packet.data.value().size(), buffer_length); - if (auto result = buffer.write(packet.data.value().data(), bytes_written); result.is_error()) - return set_so_error(result); + SOCKET_TRY(buffer.write(packet.data.value().data(), bytes_written)); return bytes_written; } @@ -381,8 +374,7 @@ KResultOr IPv4Socket::recvfrom(OpenFileDescription& description, UserOrK { if (user_addr_length) { socklen_t addr_length; - if (copy_from_user(&addr_length, user_addr_length.unsafe_userspace_ptr()).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_from_user(&addr_length, user_addr_length.unsafe_userspace_ptr())); if (addr_length < sizeof(sockaddr_in)) return set_so_error(EINVAL); } diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index 9deaa786d2..a6c7fb976e 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -113,17 +113,12 @@ KResult LocalSocket::bind(Userspace user_address, socklen_t add return set_so_error(EINVAL); sockaddr_un address = {}; - if (copy_from_user(&address, user_address, sizeof(sockaddr_un)).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_from_user(&address, user_address, sizeof(sockaddr_un))); if (address.sun_family != AF_LOCAL) return set_so_error(EINVAL); - auto path_kstring_or_error = KString::try_create(StringView { address.sun_path, strnlen(address.sun_path, sizeof(address.sun_path)) }); - if (path_kstring_or_error.is_error()) - return set_so_error(path_kstring_or_error.error()); - auto path = path_kstring_or_error.release_value(); - + auto path = SOCKET_TRY(KString::try_create(StringView { address.sun_path, strnlen(address.sun_path, sizeof(address.sun_path)) })); dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) bind({})", this, path); mode_t mode = S_IFSOCK | (m_prebind_mode & 0777); @@ -155,8 +150,7 @@ KResult LocalSocket::connect(OpenFileDescription& description, Userspace(address.unsafe_userspace_ptr()); - if (copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16)).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16))); if (sa_family_copy != AF_LOCAL) return set_so_error(EINVAL); if (is_connected()) @@ -166,23 +160,15 @@ KResult LocalSocket::connect(OpenFileDescription& description, Userspace(user_address); char safe_address[sizeof(local_address.sun_path) + 1] = { 0 }; - if (copy_from_user(&safe_address[0], &local_address.sun_path[0], sizeof(safe_address) - 1).is_error()) - return set_so_error(EFAULT); + SOCKET_TRY(copy_from_user(&safe_address[0], &local_address.sun_path[0], sizeof(safe_address) - 1)); safe_address[sizeof(safe_address) - 1] = '\0'; - auto path_kstring_or_error = KString::try_create(safe_address); - if (path_kstring_or_error.is_error()) - return set_so_error(path_kstring_or_error.error()); - maybe_path = path_kstring_or_error.release_value(); + maybe_path = SOCKET_TRY(KString::try_create(safe_address)); } auto path = maybe_path.release_nonnull(); dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) connect({})", this, *path); - auto description_or_error = VirtualFileSystem::the().open(path->view(), O_RDWR, 0, Process::current().current_directory()); - if (description_or_error.is_error()) - return set_so_error(ECONNREFUSED); - - m_file = move(description_or_error.value()); + m_file = SOCKET_TRY(VirtualFileSystem::the().open(path->view(), O_RDWR, 0, Process::current().current_directory())); VERIFY(m_file->inode()); if (!m_file->inode()->socket()) diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index f93c14b466..24556ade88 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -209,4 +209,13 @@ private: RefPtr m_socket; }; +// This is a special variant of TRY() that also updates the socket's SO_ERROR field on error. +#define SOCKET_TRY(expression) \ + ({ \ + auto result = (expression); \ + if (result.is_error()) \ + return set_so_error(result.release_error()); \ + result.release_value(); \ + }) + } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 05e6554a9d..9bc5d57fb8 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -164,8 +164,7 @@ KResultOr TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, Use size_t payload_size = raw_ipv4_packet.size() - sizeof(IPv4Packet) - tcp_packet.header_size(); dbgln_if(TCP_SOCKET_DEBUG, "payload_size {}, will it fit in {}?", payload_size, buffer_size); VERIFY(buffer_size >= payload_size); - if (auto result = buffer.write(tcp_packet.payload(), payload_size); result.is_error()) - return set_so_error(result); + SOCKET_TRY(buffer.write(tcp_packet.payload(), payload_size)); return payload_size; } diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 8d2bd44db0..36cdc92bab 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -65,8 +65,7 @@ KResultOr UDPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, Use auto& udp_packet = *static_cast(ipv4_packet.payload()); VERIFY(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier. size_t read_size = min(buffer_size, udp_packet.length() - sizeof(UDPPacket)); - if (auto result = buffer.write(udp_packet.payload(), read_size); result.is_error()) - return set_so_error(result); + SOCKET_TRY(buffer.write(udp_packet.payload(), read_size)); return read_size; } @@ -86,9 +85,7 @@ KResultOr UDPSocket::protocol_send(const UserOrKernelBuffer& data, size_ udp_packet.set_source_port(local_port()); udp_packet.set_destination_port(peer_port()); udp_packet.set_length(udp_buffer_size); - if (auto result = data.read(udp_packet.payload(), data_length); result.is_error()) - return set_so_error(result); - + SOCKET_TRY(data.read(udp_packet.payload(), data_length)); routing_decision.adapter->fill_in_ipv4_header(*packet, local_address(), routing_decision.next_hop, peer_address(), IPv4Protocol::UDP, udp_buffer_size, ttl()); routing_decision.adapter->send_packet(packet->bytes());