From 9216c72bfe47d58ccbee524f3006cfbe5a7d77e0 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Boric Date: Sun, 18 Jul 2021 12:14:43 +0200 Subject: [PATCH] Kernel: Migrate TCP socket tables locking to ProtectedValue Note: TCPSocket::create_client() has a dubious locking process where the sockets by tuple table is first shared lock to check if the socket exists and bail out if it does, then unlocks, then exclusively locks to add the tuple. There could be a race condition where two client creation requests for the same tuple happen at the same time and both cleared the shared lock check. When in doubt, lock exclusively the whole time. --- Kernel/Net/NetworkTask.cpp | 9 +-- Kernel/Net/TCPSocket.cpp | 158 +++++++++++++++++++------------------ Kernel/Net/TCPSocket.h | 8 +- 3 files changed, 90 insertions(+), 85 deletions(-) diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index 2ea3007335..6adc79953d 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -639,11 +640,9 @@ void retransmit_tcp_packets() // We must keep the sockets alive until after we've unlocked the hash table // in case retransmit_packets() realizes that it wants to close the socket. NonnullRefPtrVector sockets; - { - MutexLocker locker(TCPSocket::sockets_for_retransmit().lock(), LockMode::Shared); - for (auto& socket : TCPSocket::sockets_for_retransmit().resource()) - sockets.append(*socket); - } + TCPSocket::sockets_for_retransmit().for_each_shared([&](const auto& socket) { + sockets.append(*socket); + }); for (auto& socket : sockets) { MutexLocker socket_locker(socket.lock()); diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index e555899169..faf3352a41 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -24,9 +24,9 @@ namespace Kernel { void TCPSocket::for_each(Function callback) { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); - for (auto& it : sockets_by_tuple().resource()) + sockets_by_tuple().for_each_shared([&](const auto& it) { callback(*it.value); + }); } void TCPSocket::set_state(State new_state) @@ -42,8 +42,9 @@ void TCPSocket::set_state(State new_state) m_role = Role::Connected; if (new_state == State::Closed) { - MutexLocker locker(closing_sockets().lock()); - closing_sockets().resource().remove(tuple()); + closing_sockets().with_exclusive([&](auto& table) { + table.remove(tuple()); + }); if (m_originator) release_to_originator(); @@ -53,71 +54,68 @@ void TCPSocket::set_state(State new_state) evaluate_block_conditions(); } -static AK::Singleton>>> s_socket_closing; +static AK::Singleton>>> s_socket_closing; -Lockable>>& TCPSocket::closing_sockets() +ProtectedValue>>& TCPSocket::closing_sockets() { return *s_socket_closing; } -static AK::Singleton>> s_socket_tuples; +static AK::Singleton>> s_socket_tuples; -Lockable>& TCPSocket::sockets_by_tuple() +ProtectedValue>& TCPSocket::sockets_by_tuple() { return *s_socket_tuples; } RefPtr TCPSocket::from_tuple(const IPv4SocketTuple& tuple) { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); + return sockets_by_tuple().with_shared([&](const auto& table) -> RefPtr { + auto exact_match = table.get(tuple); + if (exact_match.has_value()) + return { *exact_match.value() }; - auto exact_match = sockets_by_tuple().resource().get(tuple); - if (exact_match.has_value()) - return { *exact_match.value() }; + auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0); + auto address_match = table.get(address_tuple); + if (address_match.has_value()) + return { *address_match.value() }; - auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0); - auto address_match = sockets_by_tuple().resource().get(address_tuple); - if (address_match.has_value()) - return { *address_match.value() }; + auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0); + auto wildcard_match = table.get(wildcard_tuple); + if (wildcard_match.has_value()) + return { *wildcard_match.value() }; - auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0); - auto wildcard_match = sockets_by_tuple().resource().get(wildcard_tuple); - if (wildcard_match.has_value()) - return { *wildcard_match.value() }; - - return {}; + return {}; + }); } RefPtr TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port) { auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port); - - { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); - if (sockets_by_tuple().resource().contains(tuple)) + return sockets_by_tuple().with_exclusive([&](auto& table) -> RefPtr { + if (table.contains(tuple)) return {}; - } - auto receive_buffer = create_receive_buffer(); - if (!receive_buffer) - return {}; - auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull()); - if (result.is_error()) - return {}; + auto receive_buffer = create_receive_buffer(); + if (!receive_buffer) + return {}; + auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull()); + if (result.is_error()) + return {}; - auto client = result.release_value(); - client->set_setup_state(SetupState::InProgress); - client->set_local_address(new_local_address); - client->set_local_port(new_local_port); - client->set_peer_address(new_peer_address); - client->set_peer_port(new_peer_port); - client->set_direction(Direction::Incoming); - client->set_originator(*this); + auto client = result.release_value(); + client->set_setup_state(SetupState::InProgress); + client->set_local_address(new_local_address); + client->set_local_port(new_local_port); + client->set_peer_address(new_peer_address); + client->set_peer_port(new_peer_port); + client->set_direction(Direction::Incoming); + client->set_originator(*this); - MutexLocker locker(sockets_by_tuple().lock()); - m_pending_release_for_accept.set(tuple, client); - sockets_by_tuple().resource().set(tuple, client); + m_pending_release_for_accept.set(tuple, client); + table.set(tuple, client); - return client; + return { move(client) }; + }); } void TCPSocket::release_to_originator() @@ -143,8 +141,9 @@ TCPSocket::TCPSocket(int protocol, NonnullOwnPtr receive_buffer, O TCPSocket::~TCPSocket() { - MutexLocker locker(sockets_by_tuple().lock()); - sockets_by_tuple().resource().remove(tuple()); + sockets_by_tuple().with_exclusive([&](auto& table) { + table.remove(tuple()); + }); dequeue_for_retransmit(); @@ -378,10 +377,14 @@ KResult TCPSocket::protocol_bind() KResult TCPSocket::protocol_listen(bool did_allocate_port) { if (!did_allocate_port) { - MutexLocker socket_locker(sockets_by_tuple().lock()); - if (sockets_by_tuple().resource().contains(tuple())) + bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool { + if (table.contains(tuple())) + return false; + table.set(tuple(), this); + return true; + }); + if (!ok) return EADDRINUSE; - sockets_by_tuple().resource().set(tuple(), this); } set_direction(Direction::Passive); @@ -443,23 +446,24 @@ KResultOr TCPSocket::protocol_allocate_local_port() constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - MutexLocker locker(sockets_by_tuple().lock()); - for (u16 port = first_scan_port;;) { - IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); + return sockets_by_tuple().with_exclusive([&](auto& table) -> KResultOr { + for (u16 port = first_scan_port;;) { + IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); - auto it = sockets_by_tuple().resource().find(proposed_tuple); - if (it == sockets_by_tuple().resource().end()) { - set_local_port(port); - sockets_by_tuple().resource().set(proposed_tuple, this); - return port; + auto it = table.find(proposed_tuple); + if (it == table.end()) { + set_local_port(port); + table.set(proposed_tuple, this); + return port; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return EADDRINUSE; + return EADDRINUSE; + }); } bool TCPSocket::protocol_is_disconnected() const @@ -499,30 +503,32 @@ KResult TCPSocket::close() set_state(State::LastAck); } - if (state() != State::Closed && state() != State::Listen) { - MutexLocker locker(closing_sockets().lock()); - closing_sockets().resource().set(tuple(), *this); - } + if (state() != State::Closed && state() != State::Listen) + closing_sockets().with_exclusive([&](auto& table) { + table.set(tuple(), *this); + }); return result; } -static AK::Singleton>> s_sockets_for_retransmit; +static AK::Singleton>> s_sockets_for_retransmit; -Lockable>& TCPSocket::sockets_for_retransmit() +ProtectedValue>& TCPSocket::sockets_for_retransmit() { return *s_sockets_for_retransmit; } void TCPSocket::enqueue_for_retransmit() { - MutexLocker locker(sockets_for_retransmit().lock()); - sockets_for_retransmit().resource().set(this); + sockets_for_retransmit().with_exclusive([&](auto& table) { + table.set(this); + }); } void TCPSocket::dequeue_for_retransmit() { - MutexLocker locker(sockets_for_retransmit().lock()); - sockets_for_retransmit().resource().remove(this); + sockets_for_retransmit().with_exclusive([&](auto& table) { + table.remove(this); + }); } void TCPSocket::retransmit_packets() diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 9da7b83ca9..9422f78598 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include namespace Kernel { @@ -142,10 +142,10 @@ public: bool should_delay_next_ack() const; - static Lockable>& sockets_by_tuple(); + static ProtectedValue>& sockets_by_tuple(); static RefPtr from_tuple(const IPv4SocketTuple& tuple); - static Lockable>>& closing_sockets(); + static ProtectedValue>>& closing_sockets(); RefPtr create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); void set_originator(TCPSocket& originator) { m_originator = originator; } @@ -153,7 +153,7 @@ public: void release_to_originator(); void release_for_accept(RefPtr); - static Lockable>& sockets_for_retransmit(); + static ProtectedValue>& sockets_for_retransmit(); void retransmit_packets(); virtual KResult close() override;