diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index e14ba323bb..d22ef9992d 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -26,11 +26,11 @@ namespace Kernel { -static AK::Singleton>> s_table; +static AK::Singleton>> s_table; using BlockFlags = Thread::FileDescriptionBlocker::BlockFlags; -Lockable>& IPv4Socket::all_sockets() +ProtectedValue>& IPv4Socket::all_sockets() { return *s_table; } @@ -77,14 +77,17 @@ IPv4Socket::IPv4Socket(int type, int protocol, NonnullOwnPtr recei if (m_buffer_mode == BufferMode::Bytes) { VERIFY(m_scratch_buffer); } - MutexLocker locker(all_sockets().lock()); - all_sockets().resource().set(this); + + all_sockets().with_exclusive([&](auto& table) { + table.set(this); + }); } IPv4Socket::~IPv4Socket() { - MutexLocker locker(all_sockets().lock()); - all_sockets().resource().remove(this); + all_sockets().with_exclusive([&](auto& table) { + table.remove(this); + }); } void IPv4Socket::get_local_address(sockaddr* address, socklen_t* address_size) diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index 811b9be724..318a43f558 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -31,7 +31,7 @@ public: static KResultOr> create(int type, int protocol); virtual ~IPv4Socket() override; - static Lockable>& all_sockets(); + static ProtectedValue>& all_sockets(); virtual KResult close() override; virtual KResult bind(Userspace, socklen_t) override; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index 14b2001fbc..2ea3007335 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -223,14 +223,10 @@ void handle_icmp(EthernetFrameHeader const& eth, IPv4Packet const& ipv4_packet, { NonnullRefPtrVector icmp_sockets; - { - MutexLocker locker(IPv4Socket::all_sockets().lock(), Mutex::Mode::Shared); - for (auto* socket : IPv4Socket::all_sockets().resource()) { - if (socket->protocol() != (unsigned)IPv4Protocol::ICMP) - continue; + IPv4Socket::all_sockets().for_each_shared([&](const auto& socket) { + if (socket->protocol() == (unsigned)IPv4Protocol::ICMP) icmp_sockets.append(*socket); - } - } + }); for (auto& socket : icmp_sockets) socket.did_receive(ipv4_packet.source(), 0, { &ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size() }, packet_timestamp); } diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 9cd0a791a9..9da7b83ca9 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace Kernel {