1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-07-25 23:17:45 +00:00

Kernel: Migrate TCP socket tables locking to ProtectedValue

Note: TCPSocket::create_client() has a dubious locking process where
the sockets-by-tuple table is first shared-locked to check whether the
socket exists (bailing out if it does), then unlocked, then exclusively
locked to add the tuple. There could be a race condition where two
client creation requests for the same tuple happen at the same time and
both pass the shared-lock check. When in doubt, lock exclusively the
whole time.
This commit is contained in:
Jean-Baptiste Boric 2021-07-18 12:14:43 +02:00 committed by Andreas Kling
parent 583abc27d8
commit 9216c72bfe
3 changed files with 90 additions and 85 deletions

View file

@ -6,6 +6,7 @@
#include <Kernel/Debug.h> #include <Kernel/Debug.h>
#include <Kernel/Locking/Mutex.h> #include <Kernel/Locking/Mutex.h>
#include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/ARP.h> #include <Kernel/Net/ARP.h>
#include <Kernel/Net/EtherType.h> #include <Kernel/Net/EtherType.h>
#include <Kernel/Net/EthernetFrameHeader.h> #include <Kernel/Net/EthernetFrameHeader.h>
@ -639,11 +640,9 @@ void retransmit_tcp_packets()
// We must keep the sockets alive until after we've unlocked the hash table // We must keep the sockets alive until after we've unlocked the hash table
// in case retransmit_packets() realizes that it wants to close the socket. // in case retransmit_packets() realizes that it wants to close the socket.
NonnullRefPtrVector<TCPSocket, 16> sockets; NonnullRefPtrVector<TCPSocket, 16> sockets;
{ TCPSocket::sockets_for_retransmit().for_each_shared([&](const auto& socket) {
MutexLocker locker(TCPSocket::sockets_for_retransmit().lock(), LockMode::Shared); sockets.append(*socket);
for (auto& socket : TCPSocket::sockets_for_retransmit().resource()) });
sockets.append(*socket);
}
for (auto& socket : sockets) { for (auto& socket : sockets) {
MutexLocker socket_locker(socket.lock()); MutexLocker socket_locker(socket.lock());

View file

@ -9,7 +9,7 @@
#include <Kernel/Debug.h> #include <Kernel/Debug.h>
#include <Kernel/Devices/RandomDevice.h> #include <Kernel/Devices/RandomDevice.h>
#include <Kernel/FileSystem/FileDescription.h> #include <Kernel/FileSystem/FileDescription.h>
#include <Kernel/Locking/Lockable.h> #include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/EthernetFrameHeader.h> #include <Kernel/Net/EthernetFrameHeader.h>
#include <Kernel/Net/IPv4.h> #include <Kernel/Net/IPv4.h>
#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/NetworkAdapter.h>
@ -24,9 +24,9 @@ namespace Kernel {
void TCPSocket::for_each(Function<void(const TCPSocket&)> callback) void TCPSocket::for_each(Function<void(const TCPSocket&)> callback)
{ {
MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); sockets_by_tuple().for_each_shared([&](const auto& it) {
for (auto& it : sockets_by_tuple().resource())
callback(*it.value); callback(*it.value);
});
} }
void TCPSocket::set_state(State new_state) void TCPSocket::set_state(State new_state)
@ -42,8 +42,9 @@ void TCPSocket::set_state(State new_state)
m_role = Role::Connected; m_role = Role::Connected;
if (new_state == State::Closed) { if (new_state == State::Closed) {
MutexLocker locker(closing_sockets().lock()); closing_sockets().with_exclusive([&](auto& table) {
closing_sockets().resource().remove(tuple()); table.remove(tuple());
});
if (m_originator) if (m_originator)
release_to_originator(); release_to_originator();
@ -53,71 +54,68 @@ void TCPSocket::set_state(State new_state)
evaluate_block_conditions(); evaluate_block_conditions();
} }
static AK::Singleton<Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>> s_socket_closing; static AK::Singleton<ProtectedValue<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>> s_socket_closing;
Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& TCPSocket::closing_sockets() ProtectedValue<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& TCPSocket::closing_sockets()
{ {
return *s_socket_closing; return *s_socket_closing;
} }
static AK::Singleton<Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>> s_socket_tuples; static AK::Singleton<ProtectedValue<HashMap<IPv4SocketTuple, TCPSocket*>>> s_socket_tuples;
Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple() ProtectedValue<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple()
{ {
return *s_socket_tuples; return *s_socket_tuples;
} }
RefPtr<TCPSocket> TCPSocket::from_tuple(const IPv4SocketTuple& tuple) RefPtr<TCPSocket> TCPSocket::from_tuple(const IPv4SocketTuple& tuple)
{ {
MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); return sockets_by_tuple().with_shared([&](const auto& table) -> RefPtr<TCPSocket> {
auto exact_match = table.get(tuple);
if (exact_match.has_value())
return { *exact_match.value() };
auto exact_match = sockets_by_tuple().resource().get(tuple); auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0);
if (exact_match.has_value()) auto address_match = table.get(address_tuple);
return { *exact_match.value() }; if (address_match.has_value())
return { *address_match.value() };
auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0); auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0);
auto address_match = sockets_by_tuple().resource().get(address_tuple); auto wildcard_match = table.get(wildcard_tuple);
if (address_match.has_value()) if (wildcard_match.has_value())
return { *address_match.value() }; return { *wildcard_match.value() };
auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0); return {};
auto wildcard_match = sockets_by_tuple().resource().get(wildcard_tuple); });
if (wildcard_match.has_value())
return { *wildcard_match.value() };
return {};
} }
RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port) RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port)
{ {
auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port); auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port);
return sockets_by_tuple().with_exclusive([&](auto& table) -> RefPtr<TCPSocket> {
{ if (table.contains(tuple))
MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared);
if (sockets_by_tuple().resource().contains(tuple))
return {}; return {};
}
auto receive_buffer = create_receive_buffer(); auto receive_buffer = create_receive_buffer();
if (!receive_buffer) if (!receive_buffer)
return {}; return {};
auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull()); auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull());
if (result.is_error()) if (result.is_error())
return {}; return {};
auto client = result.release_value(); auto client = result.release_value();
client->set_setup_state(SetupState::InProgress); client->set_setup_state(SetupState::InProgress);
client->set_local_address(new_local_address); client->set_local_address(new_local_address);
client->set_local_port(new_local_port); client->set_local_port(new_local_port);
client->set_peer_address(new_peer_address); client->set_peer_address(new_peer_address);
client->set_peer_port(new_peer_port); client->set_peer_port(new_peer_port);
client->set_direction(Direction::Incoming); client->set_direction(Direction::Incoming);
client->set_originator(*this); client->set_originator(*this);
MutexLocker locker(sockets_by_tuple().lock()); m_pending_release_for_accept.set(tuple, client);
m_pending_release_for_accept.set(tuple, client); table.set(tuple, client);
sockets_by_tuple().resource().set(tuple, client);
return client; return { move(client) };
});
} }
void TCPSocket::release_to_originator() void TCPSocket::release_to_originator()
@ -143,8 +141,9 @@ TCPSocket::TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, O
TCPSocket::~TCPSocket() TCPSocket::~TCPSocket()
{ {
MutexLocker locker(sockets_by_tuple().lock()); sockets_by_tuple().with_exclusive([&](auto& table) {
sockets_by_tuple().resource().remove(tuple()); table.remove(tuple());
});
dequeue_for_retransmit(); dequeue_for_retransmit();
@ -378,10 +377,14 @@ KResult TCPSocket::protocol_bind()
KResult TCPSocket::protocol_listen(bool did_allocate_port) KResult TCPSocket::protocol_listen(bool did_allocate_port)
{ {
if (!did_allocate_port) { if (!did_allocate_port) {
MutexLocker socket_locker(sockets_by_tuple().lock()); bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool {
if (sockets_by_tuple().resource().contains(tuple())) if (table.contains(tuple()))
return false;
table.set(tuple(), this);
return true;
});
if (!ok)
return EADDRINUSE; return EADDRINUSE;
sockets_by_tuple().resource().set(tuple(), this);
} }
set_direction(Direction::Passive); set_direction(Direction::Passive);
@ -443,23 +446,24 @@ KResultOr<u16> TCPSocket::protocol_allocate_local_port()
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size; u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
MutexLocker locker(sockets_by_tuple().lock()); return sockets_by_tuple().with_exclusive([&](auto& table) -> KResultOr<u16> {
for (u16 port = first_scan_port;;) { for (u16 port = first_scan_port;;) {
IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
auto it = sockets_by_tuple().resource().find(proposed_tuple); auto it = table.find(proposed_tuple);
if (it == sockets_by_tuple().resource().end()) { if (it == table.end()) {
set_local_port(port); set_local_port(port);
sockets_by_tuple().resource().set(proposed_tuple, this); table.set(proposed_tuple, this);
return port; return port;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
} }
++port; return EADDRINUSE;
if (port > last_ephemeral_port) });
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return EADDRINUSE;
} }
bool TCPSocket::protocol_is_disconnected() const bool TCPSocket::protocol_is_disconnected() const
@ -499,30 +503,32 @@ KResult TCPSocket::close()
set_state(State::LastAck); set_state(State::LastAck);
} }
if (state() != State::Closed && state() != State::Listen) { if (state() != State::Closed && state() != State::Listen)
MutexLocker locker(closing_sockets().lock()); closing_sockets().with_exclusive([&](auto& table) {
closing_sockets().resource().set(tuple(), *this); table.set(tuple(), *this);
} });
return result; return result;
} }
static AK::Singleton<Lockable<HashTable<TCPSocket*>>> s_sockets_for_retransmit; static AK::Singleton<ProtectedValue<HashTable<TCPSocket*>>> s_sockets_for_retransmit;
Lockable<HashTable<TCPSocket*>>& TCPSocket::sockets_for_retransmit() ProtectedValue<HashTable<TCPSocket*>>& TCPSocket::sockets_for_retransmit()
{ {
return *s_sockets_for_retransmit; return *s_sockets_for_retransmit;
} }
void TCPSocket::enqueue_for_retransmit() void TCPSocket::enqueue_for_retransmit()
{ {
MutexLocker locker(sockets_for_retransmit().lock()); sockets_for_retransmit().with_exclusive([&](auto& table) {
sockets_for_retransmit().resource().set(this); table.set(this);
});
} }
void TCPSocket::dequeue_for_retransmit() void TCPSocket::dequeue_for_retransmit()
{ {
MutexLocker locker(sockets_for_retransmit().lock()); sockets_for_retransmit().with_exclusive([&](auto& table) {
sockets_for_retransmit().resource().remove(this); table.remove(this);
});
} }
void TCPSocket::retransmit_packets() void TCPSocket::retransmit_packets()

View file

@ -11,7 +11,7 @@
#include <AK/SinglyLinkedList.h> #include <AK/SinglyLinkedList.h>
#include <AK/WeakPtr.h> #include <AK/WeakPtr.h>
#include <Kernel/KResult.h> #include <Kernel/KResult.h>
#include <Kernel/Locking/Lockable.h> #include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/IPv4Socket.h> #include <Kernel/Net/IPv4Socket.h>
namespace Kernel { namespace Kernel {
@ -142,10 +142,10 @@ public:
bool should_delay_next_ack() const; bool should_delay_next_ack() const;
static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple(); static ProtectedValue<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple();
static RefPtr<TCPSocket> from_tuple(const IPv4SocketTuple& tuple); static RefPtr<TCPSocket> from_tuple(const IPv4SocketTuple& tuple);
static Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& closing_sockets(); static ProtectedValue<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& closing_sockets();
RefPtr<TCPSocket> create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); RefPtr<TCPSocket> create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
void set_originator(TCPSocket& originator) { m_originator = originator; } void set_originator(TCPSocket& originator) { m_originator = originator; }
@ -153,7 +153,7 @@ public:
void release_to_originator(); void release_to_originator();
void release_for_accept(RefPtr<TCPSocket>); void release_for_accept(RefPtr<TCPSocket>);
static Lockable<HashTable<TCPSocket*>>& sockets_for_retransmit(); static ProtectedValue<HashTable<TCPSocket*>>& sockets_for_retransmit();
void retransmit_packets(); void retransmit_packets();
virtual KResult close() override; virtual KResult close() override;