// NOTE(review): This file is a unified `git diff`, not a C++ translation unit. In this copy the
// original line breaks appear to have been collapsed into three long lines; they must be restored
// before the patch can be applied. Several tokens also look truncated (e.g. `ErrorOr>`,
// `Optional m_registered_socket_tuple`, bare `#include` lines) -- presumably angle-bracketed text
// was dropped; TODO confirm against the tree.
//
// What the patch changes, as far as the text below shows:
//   * Kernel/Net/TCPSocket.h: adds the member `m_registered_socket_tuple` (the hunk heading
//     suggests it lands in a `private:` section, just before `public:`).
//   * Kernel/Net/TCPSocket.cpp: assigns `m_registered_socket_tuple` next to every visible
//     insertion into the `sockets_by_tuple()` table (one in try_create_client, two in
//     protocol_bind). In protocol_connect, after `TRY(ensure_bound())`, if a tuple was recorded
//     and it differs from the current `tuple()`, the old table entry is removed (VERIFY'd),
//     `set_so_error(EADDRINUSE)` is returned when the new tuple is already present, otherwise the
//     socket is inserted under the new tuple; afterwards the recorded tuple is updated.
//   * Tests/Kernel/TestTCPSocket.cpp: adds TEST_CASE(tcp_bind_connect). It binds a client socket
//     to `port - 1` on loopback, connects to `port`, sends one byte, closes the fd, joins the
//     server thread, then repeatedly reads /sys/kernel/net/tcp until the parsed JSON array is empty.
//
// NOTE(review): on the EADDRINUSE path the old entry has already been removed, but the
// `m_registered_socket_tuple = tuple();` assignment after the TRY is skipped. Confirm that a later
// connect() on the same socket cannot reach `VERIFY(removed)` with a stale recorded tuple.
// NOTE(review): the polling loop at the end of the test has no upper bound (its own comment says
// so); consider a timeout so a regression fails instead of hanging the test run.
diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index ccb0df0b34..fbf69fae05 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -143,6 +143,7 @@ ErrorOr> TCPSocket::try_create_client(IPv4Address const client->set_originator(*this); m_pending_release_for_accept.set(tuple, client); + client->m_registered_socket_tuple = tuple; table.set(tuple, client); return { move(client) }; @@ -504,6 +505,7 @@ ErrorOr TCPSocket::protocol_bind() auto it = table.find(proposed_tuple); if (it == table.end()) { set_local_port(port); + m_registered_socket_tuple = proposed_tuple; table.set(proposed_tuple, this); dbgln_if(TCP_SOCKET_DEBUG, "...allocated port {}, tuple {}", port, proposed_tuple.to_string()); return {}; @@ -521,7 +523,9 @@ ErrorOr TCPSocket::protocol_bind() bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool { if (table.contains(tuple())) return false; - table.set(tuple(), this); + auto socket_tuple = tuple(); + m_registered_socket_tuple = socket_tuple; + table.set(socket_tuple, this); return true; }); if (!ok) @@ -549,6 +553,21 @@ ErrorOr TCPSocket::protocol_connect(OpenFileDescription& description) set_local_address(routing_decision.adapter->ipv4_address()); TRY(ensure_bound()); + if (m_registered_socket_tuple.has_value() && m_registered_socket_tuple != tuple()) { + // If the socket was manually bound (using bind(2)) instead of implicitly using connect, + // it will already be registered in the TCPSocket sockets_by_tuple table, under the previous + // socket tuple. We replace the entry in the table to ensure it is also properly removed on + // socket deletion, to prevent a dangling reference. 
+ TRY(sockets_by_tuple().with_exclusive([this](auto& table) -> ErrorOr { + auto removed = table.remove(*m_registered_socket_tuple); + VERIFY(removed); + if (table.contains(tuple())) + return set_so_error(EADDRINUSE); + table.set(tuple(), this); + return {}; + })); + m_registered_socket_tuple = tuple(); + } m_sequence_number = get_good_random(); m_ack_number = 0; diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 51e630c9fc..466e23da34 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -234,6 +234,8 @@ private: IntrusiveListNode m_retransmit_list_node; + Optional m_registered_socket_tuple; + public: using RetransmitList = IntrusiveList<&TCPSocket::m_retransmit_list_node>; static MutexProtected& sockets_for_retransmit(); diff --git a/Tests/Kernel/TestTCPSocket.cpp b/Tests/Kernel/TestTCPSocket.cpp index 34e3f2a2d1..fcabcd1504 100644 --- a/Tests/Kernel/TestTCPSocket.cpp +++ b/Tests/Kernel/TestTCPSocket.cpp @@ -4,6 +4,8 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include +#include #include #include #include @@ -80,3 +82,49 @@ TEST_CASE(tcp_sendto) rc = pthread_join(server, nullptr); EXPECT_EQ(rc, 0); } + +TEST_CASE(tcp_bind_connect) +{ + pthread_t server = start_tcp_server(); + + int client_fd = socket(AF_INET, SOCK_STREAM, 0); + EXPECT(client_fd >= 0); + + sockaddr_in sin {}; + sin.sin_family = AF_INET; + sin.sin_port = htons(port - 1); + sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + int rc = bind(client_fd, (sockaddr*)(&sin), sizeof(sin)); + EXPECT_EQ(rc, 0); + + sockaddr_in dst {}; + dst.sin_family = AF_INET; + dst.sin_port = htons(port); + dst.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + rc = connect(client_fd, (sockaddr*)(&dst), sizeof(dst)); + EXPECT_EQ(rc, 0); + + u8 data = 'A'; + int nwritten = send(client_fd, &data, sizeof(data), 0); + EXPECT_EQ(nwritten, 1); + + rc = close(client_fd); + EXPECT_EQ(rc, 0); + + rc = pthread_join(server, nullptr); + EXPECT_EQ(rc, 0); + + // Hacky check to make sure there are no registered TCP 
sockets, if the sockets were closed properly, there should + // be none left, but if the early-bind caused a desync in sockets_by_tuple a UAF'd socket will be left in there. + // NOTE: We have to loop since the TimedWait stage during socket close means the socket might not close immediately + // after our close(2) call. This also means that on failure we will loop here forever. + while (true) { + auto file = MUST(Core::File::open("/sys/kernel/net/tcp"sv, Core::File::OpenMode::Read)); + auto file_contents = MUST(file->read_until_eof()); + auto json = MUST(JsonValue::from_string(file_contents)); + EXPECT(json.is_array()); + if (json.as_array().size() == 0) + return; + sched_yield(); + } +}