1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-28 17:45:09 +00:00

Kernel/Net: Iron out the locking mechanism across the subsystem

There is a big mix of LockRefPtrs all over the Networking subsystem, as
well as lots of room for improvements with our locking patterns, which
this commit will not pursue, but will give a good start for such work.

To deal with this situation, we change the following things:
- Creating instances of NetworkAdapter should always yield a non-locking
  NonnullRefPtr. Acquiring an instance from the NetworkingManagement
  should give a simple RefPtr,as giving LockRefPtr does not really
  protect from concurrency problems in such case.
- Since NetworkingManagement works with normal RefPtrs we should
  protect all instances of RefPtr<NetworkAdapter> with SpinlockProtected
  to ensure references are gone unexpectedly.
- Protect the so_error class member with a proper spinlock. This happens
  to be important because the clear_so_error() method lacked any proper
  locking measures. It also helps preventing a possible TOCTOU when we
  might do a more fine-grained locking in the Socket code, so this could
  be definitely a start for this.
- Change unnecessary LockRefPtr<PacketWithTimestamp> in the structure
  of OutgoingPacket to a simple RefPtr<PacketWithTimestamp> as the whole
  list should be MutexProtected.
This commit is contained in:
Liav A 2023-04-11 03:50:15 +03:00 committed by Linus Groh
parent bd7d4513bf
commit 7c1f645e27
13 changed files with 93 additions and 83 deletions

View file

@ -212,7 +212,8 @@ ErrorOr<size_t> IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons
return set_so_error(EPIPE); return set_so_error(EPIPE);
auto allow_using_gateway = ((flags & MSG_DONTROUTE) || m_routing_disabled) ? AllowUsingGateway::No : AllowUsingGateway::Yes; auto allow_using_gateway = ((flags & MSG_DONTROUTE) || m_routing_disabled) ? AllowUsingGateway::No : AllowUsingGateway::Yes;
auto routing_decision = route_to(m_peer_address, m_local_address, bound_interface(), allow_using_gateway); auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
auto routing_decision = route_to(m_peer_address, m_local_address, adapter, allow_using_gateway);
if (routing_decision.is_zero()) if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH); return set_so_error(EHOSTUNREACH);

View file

@ -111,9 +111,9 @@ size_t NetworkAdapter::dequeue_packet(u8* buffer, size_t buffer_size, Time& pack
return packet_size; return packet_size;
} }
LockRefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t size) RefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t size)
{ {
auto packet = m_unused_packets.with([size](auto& unused_packets) -> LockRefPtr<PacketWithTimestamp> { auto packet = m_unused_packets.with([size](auto& unused_packets) -> RefPtr<PacketWithTimestamp> {
if (unused_packets.is_empty()) if (unused_packets.is_empty())
return nullptr; return nullptr;
@ -135,7 +135,7 @@ LockRefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t siz
auto buffer_or_error = KBuffer::try_create_with_size("NetworkAdapter: Packet buffer"sv, size, Memory::Region::Access::ReadWrite, AllocationStrategy::AllocateNow); auto buffer_or_error = KBuffer::try_create_with_size("NetworkAdapter: Packet buffer"sv, size, Memory::Region::Access::ReadWrite, AllocationStrategy::AllocateNow);
if (buffer_or_error.is_error()) if (buffer_or_error.is_error())
return {}; return {};
packet = adopt_lock_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() }); packet = adopt_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() });
if (!packet) if (!packet)
return {}; return {};
packet->buffer->set_size(size); packet->buffer->set_size(size);

View file

@ -39,7 +39,7 @@ struct PacketWithTimestamp final : public AtomicRefCounted<PacketWithTimestamp>
NonnullOwnPtr<KBuffer> buffer; NonnullOwnPtr<KBuffer> buffer;
Time timestamp; Time timestamp;
IntrusiveListNode<PacketWithTimestamp, LockRefPtr<PacketWithTimestamp>> packet_node; IntrusiveListNode<PacketWithTimestamp, RefPtr<PacketWithTimestamp>> packet_node;
}; };
class NetworkingManagement; class NetworkingManagement;
@ -91,7 +91,7 @@ public:
u32 packets_out() const { return m_packets_out; } u32 packets_out() const { return m_packets_out; }
u32 bytes_out() const { return m_bytes_out; } u32 bytes_out() const { return m_bytes_out; }
LockRefPtr<PacketWithTimestamp> acquire_packet_buffer(size_t); RefPtr<PacketWithTimestamp> acquire_packet_buffer(size_t);
void release_packet_buffer(PacketWithTimestamp&); void release_packet_buffer(PacketWithTimestamp&);
constexpr size_t layer3_payload_offset() const { return sizeof(EthernetFrameHeader); } constexpr size_t layer3_payload_offset() const { return sizeof(EthernetFrameHeader); }

View file

@ -31,7 +31,7 @@ static void handle_icmp(EthernetFrameHeader const&, IPv4Packet const&, Time cons
static void handle_udp(IPv4Packet const&, Time const& packet_timestamp); static void handle_udp(IPv4Packet const&, Time const& packet_timestamp);
static void handle_tcp(IPv4Packet const&, Time const& packet_timestamp); static void handle_tcp(IPv4Packet const&, Time const& packet_timestamp);
static void send_delayed_tcp_ack(TCPSocket& socket); static void send_delayed_tcp_ack(TCPSocket& socket);
static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr<NetworkAdapter> adapter); static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr<NetworkAdapter> adapter);
static void flush_delayed_tcp_acks(); static void flush_delayed_tcp_acks();
static void retransmit_tcp_packets(); static void retransmit_tcp_packets();
@ -333,7 +333,7 @@ void flush_delayed_tcp_acks()
} }
} }
void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr<NetworkAdapter> adapter) void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr<NetworkAdapter> adapter)
{ {
auto routing_decision = route_to(ipv4_packet.source(), ipv4_packet.destination(), adapter); auto routing_decision = route_to(ipv4_packet.source(), ipv4_packet.destination(), adapter);
if (routing_decision.is_zero()) if (routing_decision.is_zero())

View file

@ -35,7 +35,7 @@ UNMAP_AFTER_INIT NetworkingManagement::NetworkingManagement()
{ {
} }
NonnullLockRefPtr<NetworkAdapter> NetworkingManagement::loopback_adapter() const NonnullRefPtr<NetworkAdapter> NetworkingManagement::loopback_adapter() const
{ {
return *m_loopback_adapter; return *m_loopback_adapter;
} }
@ -56,13 +56,13 @@ ErrorOr<void> NetworkingManagement::try_for_each(Function<ErrorOr<void>(NetworkA
}); });
} }
LockRefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address const& address) const RefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address const& address) const
{ {
if (address[0] == 0 && address[1] == 0 && address[2] == 0 && address[3] == 0) if (address[0] == 0 && address[1] == 0 && address[2] == 0 && address[3] == 0)
return m_loopback_adapter; return m_loopback_adapter;
if (address[0] == 127) if (address[0] == 127)
return m_loopback_adapter; return m_loopback_adapter;
return m_adapters.with([&](auto& adapters) -> LockRefPtr<NetworkAdapter> { return m_adapters.with([&](auto& adapters) -> RefPtr<NetworkAdapter> {
for (auto& adapter : adapters) { for (auto& adapter : adapters) {
if (adapter->ipv4_address() == address || adapter->ipv4_broadcast() == address) if (adapter->ipv4_address() == address || adapter->ipv4_broadcast() == address)
return adapter; return adapter;
@ -71,9 +71,9 @@ LockRefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address c
}); });
} }
LockRefPtr<NetworkAdapter> NetworkingManagement::lookup_by_name(StringView name) const RefPtr<NetworkAdapter> NetworkingManagement::lookup_by_name(StringView name) const
{ {
return m_adapters.with([&](auto& adapters) -> LockRefPtr<NetworkAdapter> { return m_adapters.with([&](auto& adapters) -> RefPtr<NetworkAdapter> {
for (auto& adapter : adapters) { for (auto& adapter : adapters) {
if (adapter->name() == name) if (adapter->name() == name)
return adapter; return adapter;

View file

@ -8,9 +8,9 @@
#include <AK/Function.h> #include <AK/Function.h>
#include <AK/NonnullOwnPtr.h> #include <AK/NonnullOwnPtr.h>
#include <AK/RefPtr.h>
#include <AK/Types.h> #include <AK/Types.h>
#include <Kernel/Bus/PCI/Definitions.h> #include <Kernel/Bus/PCI/Definitions.h>
#include <Kernel/Library/NonnullLockRefPtr.h>
#include <Kernel/Locking/SpinlockProtected.h> #include <Kernel/Locking/SpinlockProtected.h>
#include <Kernel/Memory/Region.h> #include <Kernel/Memory/Region.h>
#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/NetworkAdapter.h>
@ -33,16 +33,16 @@ public:
void for_each(Function<void(NetworkAdapter&)>); void for_each(Function<void(NetworkAdapter&)>);
ErrorOr<void> try_for_each(Function<ErrorOr<void>(NetworkAdapter&)>); ErrorOr<void> try_for_each(Function<ErrorOr<void>(NetworkAdapter&)>);
LockRefPtr<NetworkAdapter> from_ipv4_address(IPv4Address const&) const; RefPtr<NetworkAdapter> from_ipv4_address(IPv4Address const&) const;
LockRefPtr<NetworkAdapter> lookup_by_name(StringView) const; RefPtr<NetworkAdapter> lookup_by_name(StringView) const;
NonnullLockRefPtr<NetworkAdapter> loopback_adapter() const; NonnullRefPtr<NetworkAdapter> loopback_adapter() const;
private: private:
ErrorOr<NonnullRefPtr<NetworkAdapter>> determine_network_device(PCI::DeviceIdentifier const&) const; ErrorOr<NonnullRefPtr<NetworkAdapter>> determine_network_device(PCI::DeviceIdentifier const&) const;
SpinlockProtected<Vector<NonnullLockRefPtr<NetworkAdapter>>, LockRank::None> m_adapters {}; SpinlockProtected<Vector<NonnullRefPtr<NetworkAdapter>>, LockRank::None> m_adapters {};
LockRefPtr<NetworkAdapter> m_loopback_adapter; RefPtr<NetworkAdapter> m_loopback_adapter;
}; };
} }

View file

@ -135,11 +135,11 @@ SpinlockProtected<Route::RouteList, LockRank::None>& routing_table()
return *s_routing_table; return *s_routing_table;
} }
ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr<NetworkAdapter> adapter, UpdateTable update) ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr<NetworkAdapter> adapter, UpdateTable update)
{ {
dbgln_if(ROUTING_DEBUG, "update_routing_table {} {} {} {} {} {}", destination, gateway, netmask, flags, adapter, update == UpdateTable::Set ? "Set" : "Delete"); dbgln_if(ROUTING_DEBUG, "update_routing_table {} {} {} {} {} {}", destination, gateway, netmask, flags, adapter, update == UpdateTable::Set ? "Set" : "Delete");
auto route_entry = adopt_lock_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() }); auto route_entry = adopt_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() });
if (!route_entry) if (!route_entry)
return ENOMEM; return ENOMEM;
@ -178,7 +178,7 @@ static MACAddress multicast_ethernet_address(IPv4Address const& address)
return MACAddress { 0x01, 0x00, 0x5e, (u8)(address[1] & 0x7f), address[2], address[3] }; return MACAddress { 0x01, 0x00, 0x5e, (u8)(address[1] & 0x7f), address[2], address[3] };
} }
RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr<NetworkAdapter> const through, AllowUsingGateway allow_using_gateway) RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr<NetworkAdapter> const through, AllowUsingGateway allow_using_gateway)
{ {
auto matches = [&](auto& adapter) { auto matches = [&](auto& adapter) {
if (!through) if (!through)
@ -200,8 +200,8 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L
auto target_addr = target.to_u32(); auto target_addr = target.to_u32();
auto source_addr = source.to_u32(); auto source_addr = source.to_u32();
LockRefPtr<NetworkAdapter> local_adapter = nullptr; RefPtr<NetworkAdapter> local_adapter = nullptr;
LockRefPtr<Route> chosen_route = nullptr; RefPtr<Route> chosen_route = nullptr;
NetworkingManagement::the().for_each([source_addr, &target_addr, &local_adapter, &matches, &through](NetworkAdapter& adapter) { NetworkingManagement::the().for_each([source_addr, &target_addr, &local_adapter, &matches, &through](NetworkAdapter& adapter) {
auto adapter_addr = adapter.ipv4_address().to_u32(); auto adapter_addr = adapter.ipv4_address().to_u32();
@ -263,7 +263,7 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L
return { nullptr, {} }; return { nullptr, {} };
} }
LockRefPtr<NetworkAdapter> adapter = nullptr; RefPtr<NetworkAdapter> adapter = nullptr;
IPv4Address next_hop_ip; IPv4Address next_hop_ip;
if (local_adapter) { if (local_adapter) {

View file

@ -7,7 +7,7 @@
#pragma once #pragma once
#include <AK/IPv4Address.h> #include <AK/IPv4Address.h>
#include <Kernel/Library/NonnullLockRefPtr.h> #include <AK/RefPtr.h>
#include <Kernel/Locking/MutexProtected.h> #include <Kernel/Locking/MutexProtected.h>
#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Thread.h> #include <Kernel/Thread.h>
@ -15,7 +15,7 @@
namespace Kernel { namespace Kernel {
struct Route final : public AtomicRefCounted<Route> { struct Route final : public AtomicRefCounted<Route> {
Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullLockRefPtr<NetworkAdapter> adapter) Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullRefPtr<NetworkAdapter> adapter)
: destination(destination) : destination(destination)
, gateway(gateway) , gateway(gateway)
, netmask(netmask) , netmask(netmask)
@ -38,14 +38,14 @@ struct Route final : public AtomicRefCounted<Route> {
const IPv4Address gateway; const IPv4Address gateway;
const IPv4Address netmask; const IPv4Address netmask;
const u16 flags; const u16 flags;
NonnullLockRefPtr<NetworkAdapter> adapter; NonnullRefPtr<NetworkAdapter> const adapter;
IntrusiveListNode<Route, LockRefPtr<Route>> route_list_node {}; IntrusiveListNode<Route, RefPtr<Route>> route_list_node {};
using RouteList = IntrusiveList<&Route::route_list_node>; using RouteList = IntrusiveList<&Route::route_list_node>;
}; };
struct RoutingDecision { struct RoutingDecision {
LockRefPtr<NetworkAdapter> adapter; RefPtr<NetworkAdapter> adapter;
MACAddress next_hop; MACAddress next_hop;
bool is_zero() const; bool is_zero() const;
@ -57,14 +57,14 @@ enum class UpdateTable {
}; };
void update_arp_table(IPv4Address const&, MACAddress const&, UpdateTable update); void update_arp_table(IPv4Address const&, MACAddress const&, UpdateTable update);
ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr<NetworkAdapter> const adapter, UpdateTable update); ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr<NetworkAdapter> const adapter, UpdateTable update);
enum class AllowUsingGateway { enum class AllowUsingGateway {
Yes, Yes,
No, No,
}; };
RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr<NetworkAdapter> const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes); RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr<NetworkAdapter> const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes);
SpinlockProtected<HashMap<IPv4Address, MACAddress>, LockRank::None>& arp_table(); SpinlockProtected<HashMap<IPv4Address, MACAddress>, LockRank::None>& arp_table();
SpinlockProtected<Route::RouteList, LockRank::None>& routing_table(); SpinlockProtected<Route::RouteList, LockRank::None>& routing_table();

View file

@ -100,7 +100,9 @@ ErrorOr<void> Socket::setsockopt(int level, int option, Userspace<void const*> u
auto device = NetworkingManagement::the().lookup_by_name(ifname->view()); auto device = NetworkingManagement::the().lookup_by_name(ifname->view());
if (!device) if (!device)
return ENODEV; return ENODEV;
m_bound_interface = move(device); m_bound_interface.with([&device](auto& bound_device) {
bound_device = move(device);
});
return {}; return {};
} }
case SO_DEBUG: case SO_DEBUG:
@ -169,31 +171,35 @@ ErrorOr<void> Socket::getsockopt(OpenFileDescription&, int level, int option, Us
case SO_ERROR: { case SO_ERROR: {
if (size < sizeof(int)) if (size < sizeof(int))
return EINVAL; return EINVAL;
int errno = 0; return so_error().with([&size, value, value_size](auto& error) -> ErrorOr<void> {
if (auto const& error = so_error(); error.has_value()) int errno = 0;
errno = error.value(); if (error.has_value())
TRY(copy_to_user(static_ptr_cast<int*>(value), &errno)); errno = error.value();
size = sizeof(int); TRY(copy_to_user(static_ptr_cast<int*>(value), &errno));
TRY(copy_to_user(value_size, &size)); size = sizeof(int);
clear_so_error(); TRY(copy_to_user(value_size, &size));
return {}; error = {};
return {};
});
} }
case SO_BINDTODEVICE: case SO_BINDTODEVICE:
if (size < IFNAMSIZ) if (size < IFNAMSIZ)
return EINVAL; return EINVAL;
if (m_bound_interface) { return m_bound_interface.with([&](auto& bound_device) -> ErrorOr<void> {
auto name = m_bound_interface->name(); if (bound_device) {
auto length = name.length() + 1; auto name = bound_device->name();
auto characters = name.characters_without_null_termination(); auto length = name.length() + 1;
TRY(copy_to_user(static_ptr_cast<char*>(value), characters, length)); auto characters = name.characters_without_null_termination();
size = length; TRY(copy_to_user(static_ptr_cast<char*>(value), characters, length));
return copy_to_user(value_size, &size); size = length;
} else { return copy_to_user(value_size, &size);
size = 0; } else {
TRY(copy_to_user(value_size, &size)); size = 0;
// FIXME: This return value looks suspicious. TRY(copy_to_user(value_size, &size));
return EFAULT; // FIXME: This return value looks suspicious.
} return EFAULT;
}
});
case SO_TIMESTAMP: case SO_TIMESTAMP:
if (size < sizeof(int)) if (size < sizeof(int))
return EINVAL; return EINVAL;

View file

@ -7,9 +7,9 @@
#pragma once #pragma once
#include <AK/Error.h> #include <AK/Error.h>
#include <AK/RefPtr.h>
#include <AK/Time.h> #include <AK/Time.h>
#include <Kernel/FileSystem/File.h> #include <Kernel/FileSystem/File.h>
#include <Kernel/Library/LockRefPtr.h>
#include <Kernel/Locking/Mutex.h> #include <Kernel/Locking/Mutex.h>
#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/UnixTypes.h> #include <Kernel/UnixTypes.h>
@ -90,7 +90,7 @@ public:
ProcessID acceptor_pid() const { return m_acceptor.pid; } ProcessID acceptor_pid() const { return m_acceptor.pid; }
UserID acceptor_uid() const { return m_acceptor.uid; } UserID acceptor_uid() const { return m_acceptor.uid; }
GroupID acceptor_gid() const { return m_acceptor.gid; } GroupID acceptor_gid() const { return m_acceptor.gid; }
LockRefPtr<NetworkAdapter> const bound_interface() const { return m_bound_interface; } SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> const& bound_interface() const { return m_bound_interface; }
Mutex& mutex() { return m_mutex; } Mutex& mutex() { return m_mutex; }
@ -123,31 +123,29 @@ protected:
Role m_role { Role::None }; Role m_role { Role::None };
Optional<ErrnoCode> const& so_error() const SpinlockProtected<Optional<ErrnoCode>, LockRank::None>& so_error() { return m_so_error; }
{
VERIFY(m_mutex.is_exclusively_locked_by_current_thread());
return m_so_error;
}
Error set_so_error(ErrnoCode error_code) Error set_so_error(ErrnoCode error_code)
{ {
MutexLocker locker(mutex()); m_so_error.with([&error_code](auto& so_error) {
m_so_error = error_code; so_error = error_code;
});
return Error::from_errno(error_code); return Error::from_errno(error_code);
} }
Error set_so_error(Error error) Error set_so_error(Error error)
{ {
MutexLocker locker(mutex()); m_so_error.with([&error](auto& so_error) {
m_so_error = static_cast<ErrnoCode>(error.code()); so_error = static_cast<ErrnoCode>(error.code());
});
return error; return error;
} }
void clear_so_error() void clear_so_error()
{ {
m_so_error = {}; m_so_error.with([](auto& so_error) {
so_error = {};
});
} }
void set_origin(Process const&); void set_origin(Process const&);
@ -173,13 +171,13 @@ private:
bool m_shut_down_for_reading { false }; bool m_shut_down_for_reading { false };
bool m_shut_down_for_writing { false }; bool m_shut_down_for_writing { false };
LockRefPtr<NetworkAdapter> m_bound_interface { nullptr }; SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> m_bound_interface;
Time m_receive_timeout {}; Time m_receive_timeout {};
Time m_send_timeout {}; Time m_send_timeout {};
int m_timestamp { 0 }; int m_timestamp { 0 };
Optional<ErrnoCode> m_so_error; SpinlockProtected<Optional<ErrnoCode>, LockRank::None> m_so_error;
Vector<NonnullRefPtr<Socket>> m_pending; Vector<NonnullRefPtr<Socket>> m_pending;
}; };

View file

@ -202,7 +202,8 @@ ErrorOr<size_t> TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO
ErrorOr<size_t> TCPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length) ErrorOr<size_t> TCPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length)
{ {
RoutingDecision routing_decision = route_to(peer_address(), local_address(), bound_interface()); auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
RoutingDecision routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero()) if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH); return set_so_error(EHOSTUNREACH);
size_t mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket); size_t mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket);
@ -220,7 +221,8 @@ ErrorOr<void> TCPSocket::send_ack(bool allow_duplicate)
ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* payload, size_t payload_size, RoutingDecision* user_routing_decision) ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* payload, size_t payload_size, RoutingDecision* user_routing_decision)
{ {
RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), bound_interface()); auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero()) if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH); return set_so_error(EHOSTUNREACH);
@ -409,13 +411,14 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(IPv4Address const& source, I
ErrorOr<void> TCPSocket::protocol_bind() ErrorOr<void> TCPSocket::protocol_bind()
{ {
if (has_specific_local_address() && !m_adapter) { return m_adapter.with([this](auto& adapter) -> ErrorOr<void> {
m_adapter = NetworkingManagement::the().from_ipv4_address(local_address()); if (has_specific_local_address() && !adapter) {
if (!m_adapter) adapter = NetworkingManagement::the().from_ipv4_address(local_address());
return set_so_error(EADDRNOTAVAIL); if (!adapter)
} return set_so_error(EADDRNOTAVAIL);
}
return {}; return {};
});
} }
ErrorOr<void> TCPSocket::protocol_listen(bool did_allocate_port) ErrorOr<void> TCPSocket::protocol_listen(bool did_allocate_port)
@ -598,7 +601,8 @@ void TCPSocket::retransmit_packets()
return; return;
} }
auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
auto routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero()) if (routing_decision.is_zero())
return; return;

View file

@ -189,7 +189,7 @@ private:
HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept; HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept;
Direction m_direction { Direction::Unspecified }; Direction m_direction { Direction::Unspecified };
Error m_error { Error::None }; Error m_error { Error::None };
LockRefPtr<NetworkAdapter> m_adapter; SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> m_adapter;
u32 m_sequence_number { 0 }; u32 m_sequence_number { 0 };
u32 m_ack_number { 0 }; u32 m_ack_number { 0 };
State m_state { State::Closed }; State m_state { State::Closed };
@ -200,7 +200,7 @@ private:
struct OutgoingPacket { struct OutgoingPacket {
u32 ack_number { 0 }; u32 ack_number { 0 };
LockRefPtr<PacketWithTimestamp> buffer; RefPtr<PacketWithTimestamp> buffer;
size_t ipv4_payload_offset; size_t ipv4_payload_offset;
LockWeakPtr<NetworkAdapter> adapter; LockWeakPtr<NetworkAdapter> adapter;
int tx_counter { 0 }; int tx_counter { 0 };

View file

@ -84,7 +84,8 @@ ErrorOr<size_t> UDPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO
ErrorOr<size_t> UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length) ErrorOr<size_t> UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length)
{ {
auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
auto routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero()) if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH); return set_so_error(EHOSTUNREACH);
auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset(); auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset();