1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-07-24 18:57:42 +00:00

Kernel: Migrate UDP socket table locking to ProtectedValue

This commit is contained in:
Jean-Baptiste Boric 2021-07-18 12:24:34 +02:00 committed by Andreas Kling
parent 9216c72bfe
commit 9517100672
2 changed files with 37 additions and 36 deletions

View file

@ -6,7 +6,6 @@
#include <AK/Singleton.h> #include <AK/Singleton.h>
#include <Kernel/Devices/RandomDevice.h> #include <Kernel/Devices/RandomDevice.h>
#include <Kernel/Locking/Mutex.h>
#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Net/Routing.h> #include <Kernel/Net/Routing.h>
#include <Kernel/Net/UDP.h> #include <Kernel/Net/UDP.h>
@ -18,30 +17,29 @@ namespace Kernel {
void UDPSocket::for_each(Function<void(const UDPSocket&)> callback) void UDPSocket::for_each(Function<void(const UDPSocket&)> callback)
{ {
MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); sockets_by_port().for_each_shared([&](const auto& socket) {
for (auto it : sockets_by_port().resource()) callback(*socket.value);
callback(*it.value); });
} }
static AK::Singleton<Lockable<HashMap<u16, UDPSocket*>>> s_map; static AK::Singleton<ProtectedValue<HashMap<u16, UDPSocket*>>> s_map;
Lockable<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port() ProtectedValue<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
{ {
return *s_map; return *s_map;
} }
SocketHandle<UDPSocket> UDPSocket::from_port(u16 port) SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
{ {
RefPtr<UDPSocket> socket; return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle<UDPSocket> {
{ RefPtr<UDPSocket> socket;
MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); auto it = table.find(port);
auto it = sockets_by_port().resource().find(port); if (it == table.end())
if (it == sockets_by_port().resource().end())
return {}; return {};
socket = (*it).value; socket = (*it).value;
VERIFY(socket); VERIFY(socket);
} return { *socket };
return { *socket }; });
} }
UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer) UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
UDPSocket::~UDPSocket() UDPSocket::~UDPSocket()
{ {
MutexLocker locker(sockets_by_port().lock()); sockets_by_port().with_exclusive([&](auto& table) {
sockets_by_port().resource().remove(local_port()); table.remove(local_port());
});
} }
KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer) KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@ -113,30 +112,32 @@ KResultOr<u16> UDPSocket::protocol_allocate_local_port()
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size; u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
MutexLocker locker(sockets_by_port().lock()); return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr<u16> {
for (u16 port = first_scan_port;;) { for (u16 port = first_scan_port;;) {
auto it = sockets_by_port().resource().find(port); auto it = table.find(port);
if (it == sockets_by_port().resource().end()) { if (it == table.end()) {
set_local_port(port); set_local_port(port);
sockets_by_port().resource().set(port, this); table.set(port, this);
return port; return port;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
} }
++port; return EADDRINUSE;
if (port > last_ephemeral_port) });
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return EADDRINUSE;
} }
KResult UDPSocket::protocol_bind() KResult UDPSocket::protocol_bind()
{ {
MutexLocker locker(sockets_by_port().lock()); return sockets_by_port().with_exclusive([&](auto& table) -> KResult {
if (sockets_by_port().resource().contains(local_port())) if (table.contains(local_port()))
return EADDRINUSE; return EADDRINUSE;
sockets_by_port().resource().set(local_port(), this); table.set(local_port(), this);
return KSuccess; return KSuccess;
});
} }
} }

View file

@ -7,7 +7,7 @@
#pragma once #pragma once
#include <Kernel/KResult.h> #include <Kernel/KResult.h>
#include <Kernel/Locking/Lockable.h> #include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/IPv4Socket.h> #include <Kernel/Net/IPv4Socket.h>
namespace Kernel { namespace Kernel {
@ -23,7 +23,7 @@ public:
private: private:
explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer); explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
virtual StringView class_name() const override { return "UDPSocket"; } virtual StringView class_name() const override { return "UDPSocket"; }
static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port(); static ProtectedValue<HashMap<u16, UDPSocket*>>& sockets_by_port();
virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override; virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;