diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 1c719bb516..ef5783137f 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include #include @@ -18,30 +17,29 @@ namespace Kernel { void UDPSocket::for_each(Function callback) { - MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); - for (auto it : sockets_by_port().resource()) - callback(*it.value); + sockets_by_port().for_each_shared([&](const auto& socket) { + callback(*socket.value); + }); } -static AK::Singleton>> s_map; +static AK::Singleton>> s_map; -Lockable>& UDPSocket::sockets_by_port() +ProtectedValue>& UDPSocket::sockets_by_port() { return *s_map; } SocketHandle UDPSocket::from_port(u16 port) { - RefPtr socket; - { - MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) + return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle { + RefPtr socket; + auto it = table.find(port); + if (it == table.end()) return {}; socket = (*it).value; VERIFY(socket); - } - return { *socket }; + return { *socket }; + }); } UDPSocket::UDPSocket(int protocol, NonnullOwnPtr receive_buffer) @@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr receive_buffer) UDPSocket::~UDPSocket() { - MutexLocker locker(sockets_by_port().lock()); - sockets_by_port().resource().remove(local_port()); + sockets_by_port().with_exclusive([&](auto& table) { + table.remove(local_port()); + }); } KResultOr> UDPSocket::create(int protocol, NonnullOwnPtr receive_buffer) @@ -113,30 +112,32 @@ KResultOr UDPSocket::protocol_allocate_local_port() constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - MutexLocker locker(sockets_by_port().lock()); - for (u16 port = first_scan_port;;) { - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) { - set_local_port(port); - sockets_by_port().resource().set(port, this); - return port; + return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr { + for (u16 port = first_scan_port;;) { + auto it = table.find(port); + if (it == table.end()) { + set_local_port(port); + table.set(port, this); + return port; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return EADDRINUSE; + return EADDRINUSE; + }); } KResult UDPSocket::protocol_bind() { - MutexLocker locker(sockets_by_port().lock()); - if (sockets_by_port().resource().contains(local_port())) - return EADDRINUSE; - sockets_by_port().resource().set(local_port(), this); - return KSuccess; + return sockets_by_port().with_exclusive([&](auto& table) -> KResult { + if (table.contains(local_port())) + return EADDRINUSE; + table.set(local_port(), this); + return KSuccess; + }); } } diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index b13b07ec91..d57ff39445 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -7,7 +7,7 @@ #pragma once #include -#include +#include #include namespace Kernel { @@ -23,7 +23,7 @@ public: private: explicit UDPSocket(int protocol, NonnullOwnPtr receive_buffer); virtual StringView class_name() const override { return "UDPSocket"; } - static Lockable>& sockets_by_port(); + static ProtectedValue>& sockets_by_port(); virtual KResultOr protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual KResultOr protocol_send(const UserOrKernelBuffer&, size_t) override;