diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 610274c295..728b4b1c37 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -36,8 +36,12 @@ Lockable>& IPv4Socket::all_sockets() KResultOr> IPv4Socket::create(int type, int protocol) { - if (type == SOCK_STREAM) - return TCPSocket::create(protocol); + if (type == SOCK_STREAM) { + auto tcp_socket = TCPSocket::create(protocol); + if (tcp_socket.is_error()) + return tcp_socket.error(); + return tcp_socket.release_value(); + } if (type == SOCK_DGRAM) return UDPSocket::create(protocol); if (type == SOCK_RAW) diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 035e76d480..353a59fbc1 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -98,8 +98,11 @@ RefPtr TCPSocket::create_client(const IPv4Address& new_local_address, if (sockets_by_tuple().resource().contains(tuple)) return {}; - auto client = TCPSocket::create(protocol()); + auto result = TCPSocket::create(protocol()); + if (result.is_error()) + return {}; + auto client = result.release_value(); client->set_setup_state(SetupState::InProgress); client->set_local_address(new_local_address); client->set_local_port(new_local_port); @@ -142,9 +145,12 @@ TCPSocket::~TCPSocket() dbgln_if(TCP_SOCKET_DEBUG, "~TCPSocket in state {}", to_string(state())); } -NonnullRefPtr TCPSocket::create(int protocol) +KResultOr> TCPSocket::create(int protocol) { - return adopt_ref(*new TCPSocket(protocol)); + auto socket = adopt_ref_if_nonnull(new TCPSocket(protocol)); + if (socket) + return socket.release_nonnull(); + return ENOMEM; } KResultOr TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, [[maybe_unused]] int flags) diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 28621a98f0..ac3489ef96 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace Kernel { @@ -17,7 +18,7 @@ namespace Kernel { class TCPSocket final : public IPv4Socket { public: static void for_each(Function); - static NonnullRefPtr create(int protocol); + static KResultOr> create(int protocol); virtual ~TCPSocket() override; enum class Direction {