diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index 2ea3007335..6adc79953d 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -639,11 +640,9 @@ void retransmit_tcp_packets() // We must keep the sockets alive until after we've unlocked the hash table // in case retransmit_packets() realizes that it wants to close the socket. NonnullRefPtrVector sockets; - { - MutexLocker locker(TCPSocket::sockets_for_retransmit().lock(), LockMode::Shared); - for (auto& socket : TCPSocket::sockets_for_retransmit().resource()) - sockets.append(*socket); - } + TCPSocket::sockets_for_retransmit().for_each_shared([&](const auto& socket) { + sockets.append(*socket); + }); for (auto& socket : sockets) { MutexLocker socket_locker(socket.lock()); diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index e555899169..faf3352a41 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -24,9 +24,9 @@ namespace Kernel { void TCPSocket::for_each(Function callback) { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); - for (auto& it : sockets_by_tuple().resource()) + sockets_by_tuple().for_each_shared([&](const auto& it) { callback(*it.value); + }); } void TCPSocket::set_state(State new_state) @@ -42,8 +42,9 @@ void TCPSocket::set_state(State new_state) m_role = Role::Connected; if (new_state == State::Closed) { - MutexLocker locker(closing_sockets().lock()); - closing_sockets().resource().remove(tuple()); + closing_sockets().with_exclusive([&](auto& table) { + table.remove(tuple()); + }); if (m_originator) release_to_originator(); @@ -53,71 +54,68 @@ void TCPSocket::set_state(State new_state) evaluate_block_conditions(); } -static AK::Singleton>>> s_socket_closing; +static AK::Singleton>>> s_socket_closing; -Lockable>>& TCPSocket::closing_sockets() +ProtectedValue>>& TCPSocket::closing_sockets() { return *s_socket_closing; } -static AK::Singleton>> s_socket_tuples; +static AK::Singleton>> s_socket_tuples; -Lockable>& TCPSocket::sockets_by_tuple() +ProtectedValue>& TCPSocket::sockets_by_tuple() { return *s_socket_tuples; } RefPtr TCPSocket::from_tuple(const IPv4SocketTuple& tuple) { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); + return sockets_by_tuple().with_shared([&](const auto& table) -> RefPtr { + auto exact_match = table.get(tuple); + if (exact_match.has_value()) + return { *exact_match.value() }; - auto exact_match = sockets_by_tuple().resource().get(tuple); - if (exact_match.has_value()) - return { *exact_match.value() }; + auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0); + auto address_match = table.get(address_tuple); + if (address_match.has_value()) + return { *address_match.value() }; - auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0); - auto address_match = sockets_by_tuple().resource().get(address_tuple); - if (address_match.has_value()) - return { *address_match.value() }; + auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0); + auto wildcard_match = table.get(wildcard_tuple); + if (wildcard_match.has_value()) + return { *wildcard_match.value() }; - auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0); - auto wildcard_match = sockets_by_tuple().resource().get(wildcard_tuple); - if (wildcard_match.has_value()) - return { *wildcard_match.value() }; - - return {}; + return {}; + }); } RefPtr TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port) { auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port); - - { - MutexLocker locker(sockets_by_tuple().lock(), Mutex::Mode::Shared); - if (sockets_by_tuple().resource().contains(tuple)) + return sockets_by_tuple().with_exclusive([&](auto& table) -> RefPtr { + if (table.contains(tuple)) return {}; - } - auto receive_buffer = create_receive_buffer(); - if (!receive_buffer) - return {}; - auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull()); - if (result.is_error()) - return {}; + auto receive_buffer = create_receive_buffer(); + if (!receive_buffer) + return {}; + auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull()); + if (result.is_error()) + return {}; - auto client = result.release_value(); - client->set_setup_state(SetupState::InProgress); - client->set_local_address(new_local_address); - client->set_local_port(new_local_port); - client->set_peer_address(new_peer_address); - client->set_peer_port(new_peer_port); - client->set_direction(Direction::Incoming); - client->set_originator(*this); + auto client = result.release_value(); + client->set_setup_state(SetupState::InProgress); + client->set_local_address(new_local_address); + client->set_local_port(new_local_port); + client->set_peer_address(new_peer_address); + client->set_peer_port(new_peer_port); + client->set_direction(Direction::Incoming); + client->set_originator(*this); - MutexLocker locker(sockets_by_tuple().lock()); - m_pending_release_for_accept.set(tuple, client); - sockets_by_tuple().resource().set(tuple, client); + m_pending_release_for_accept.set(tuple, client); + table.set(tuple, client); - return client; + return { move(client) }; + }); } void TCPSocket::release_to_originator() @@ -143,8 +141,9 @@ TCPSocket::TCPSocket(int protocol, NonnullOwnPtr receive_buffer, O TCPSocket::~TCPSocket() { - MutexLocker locker(sockets_by_tuple().lock()); - sockets_by_tuple().resource().remove(tuple()); + sockets_by_tuple().with_exclusive([&](auto& table) { + table.remove(tuple()); + }); dequeue_for_retransmit(); @@ -378,10 +377,14 @@ KResult TCPSocket::protocol_bind() KResult TCPSocket::protocol_listen(bool did_allocate_port) { if (!did_allocate_port) { - MutexLocker socket_locker(sockets_by_tuple().lock()); - if (sockets_by_tuple().resource().contains(tuple())) + bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool { + if (table.contains(tuple())) + return false; + table.set(tuple(), this); + return true; + }); + if (!ok) return EADDRINUSE; - sockets_by_tuple().resource().set(tuple(), this); } set_direction(Direction::Passive); @@ -443,23 +446,24 @@ KResultOr TCPSocket::protocol_allocate_local_port() constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - MutexLocker locker(sockets_by_tuple().lock()); - for (u16 port = first_scan_port;;) { - IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); + return sockets_by_tuple().with_exclusive([&](auto& table) -> KResultOr { + for (u16 port = first_scan_port;;) { + IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); - auto it = sockets_by_tuple().resource().find(proposed_tuple); - if (it == sockets_by_tuple().resource().end()) { - set_local_port(port); - sockets_by_tuple().resource().set(proposed_tuple, this); - return port; + auto it = table.find(proposed_tuple); + if (it == table.end()) { + set_local_port(port); + table.set(proposed_tuple, this); + return port; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return EADDRINUSE; + return EADDRINUSE; + }); } bool TCPSocket::protocol_is_disconnected() const @@ -499,30 +503,32 @@ KResult TCPSocket::close() set_state(State::LastAck); } - if (state() != State::Closed && state() != State::Listen) { - MutexLocker locker(closing_sockets().lock()); - closing_sockets().resource().set(tuple(), *this); - } + if (state() != State::Closed && state() != State::Listen) + closing_sockets().with_exclusive([&](auto& table) { + table.set(tuple(), *this); + }); return result; } -static AK::Singleton>> s_sockets_for_retransmit; +static AK::Singleton>> s_sockets_for_retransmit; -Lockable>& TCPSocket::sockets_for_retransmit() +ProtectedValue>& TCPSocket::sockets_for_retransmit() { return *s_sockets_for_retransmit; } void TCPSocket::enqueue_for_retransmit() { - MutexLocker locker(sockets_for_retransmit().lock()); - sockets_for_retransmit().resource().set(this); + sockets_for_retransmit().with_exclusive([&](auto& table) { + table.set(this); + }); } void TCPSocket::dequeue_for_retransmit() { - MutexLocker locker(sockets_for_retransmit().lock()); - sockets_for_retransmit().resource().remove(this); + sockets_for_retransmit().with_exclusive([&](auto& table) { + table.remove(this); + }); } void TCPSocket::retransmit_packets() diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 9da7b83ca9..9422f78598 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include namespace Kernel { @@ -142,10 +142,10 @@ public: bool should_delay_next_ack() const; - static Lockable>& sockets_by_tuple(); + static ProtectedValue>& sockets_by_tuple(); static RefPtr from_tuple(const IPv4SocketTuple& tuple); - static Lockable>>& closing_sockets(); + static ProtectedValue>>& closing_sockets(); RefPtr create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); void set_originator(TCPSocket& originator) { m_originator = originator; } @@ -153,7 +153,7 @@ public: void release_to_originator(); void release_for_accept(RefPtr); - static Lockable>& sockets_for_retransmit(); + static ProtectedValue>& sockets_for_retransmit(); void retransmit_packets(); virtual KResult close() override;