diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 67ab93c33d..61bdb59441 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -212,7 +212,8 @@ ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons return set_so_error(EPIPE); auto allow_using_gateway = ((flags & MSG_DONTROUTE) || m_routing_disabled) ? AllowUsingGateway::No : AllowUsingGateway::Yes; - auto routing_decision = route_to(m_peer_address, m_local_address, bound_interface(), allow_using_gateway); + auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr { return bound_device; }); + auto routing_decision = route_to(m_peer_address, m_local_address, adapter, allow_using_gateway); if (routing_decision.is_zero()) return set_so_error(EHOSTUNREACH); diff --git a/Kernel/Net/NetworkAdapter.cpp b/Kernel/Net/NetworkAdapter.cpp index b1139ba8ad..9ccfd88d84 100644 --- a/Kernel/Net/NetworkAdapter.cpp +++ b/Kernel/Net/NetworkAdapter.cpp @@ -111,9 +111,9 @@ size_t NetworkAdapter::dequeue_packet(u8* buffer, size_t buffer_size, Time& pack return packet_size; } -LockRefPtr NetworkAdapter::acquire_packet_buffer(size_t size) +RefPtr NetworkAdapter::acquire_packet_buffer(size_t size) { - auto packet = m_unused_packets.with([size](auto& unused_packets) -> LockRefPtr { + auto packet = m_unused_packets.with([size](auto& unused_packets) -> RefPtr { if (unused_packets.is_empty()) return nullptr; @@ -135,7 +135,7 @@ LockRefPtr NetworkAdapter::acquire_packet_buffer(size_t siz auto buffer_or_error = KBuffer::try_create_with_size("NetworkAdapter: Packet buffer"sv, size, Memory::Region::Access::ReadWrite, AllocationStrategy::AllocateNow); if (buffer_or_error.is_error()) return {}; - packet = adopt_lock_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() }); + packet = adopt_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() }); if (!packet) return {}; packet->buffer->set_size(size); diff --git a/Kernel/Net/NetworkAdapter.h b/Kernel/Net/NetworkAdapter.h index bfc9314fd7..6fd6cabbd3 100644 --- a/Kernel/Net/NetworkAdapter.h +++ b/Kernel/Net/NetworkAdapter.h @@ -39,7 +39,7 @@ struct PacketWithTimestamp final : public AtomicRefCounted NonnullOwnPtr buffer; Time timestamp; - IntrusiveListNode> packet_node; + IntrusiveListNode> packet_node; }; class NetworkingManagement; @@ -91,7 +91,7 @@ public: u32 packets_out() const { return m_packets_out; } u32 bytes_out() const { return m_bytes_out; } - LockRefPtr acquire_packet_buffer(size_t); + RefPtr acquire_packet_buffer(size_t); void release_packet_buffer(PacketWithTimestamp&); constexpr size_t layer3_payload_offset() const { return sizeof(EthernetFrameHeader); } diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index b5de06dda8..a89860452c 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -31,7 +31,7 @@ static void handle_icmp(EthernetFrameHeader const&, IPv4Packet const&, Time cons static void handle_udp(IPv4Packet const&, Time const& packet_timestamp); static void handle_tcp(IPv4Packet const&, Time const& packet_timestamp); static void send_delayed_tcp_ack(TCPSocket& socket); -static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr adapter); +static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr adapter); static void flush_delayed_tcp_acks(); static void retransmit_tcp_packets(); @@ -333,7 +333,7 @@ void flush_delayed_tcp_acks() } } -void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr adapter) +void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr adapter) { auto routing_decision = route_to(ipv4_packet.source(), ipv4_packet.destination(), adapter); if (routing_decision.is_zero()) diff --git a/Kernel/Net/NetworkingManagement.cpp b/Kernel/Net/NetworkingManagement.cpp index 70958c28fc..111696cbbd 100644 --- a/Kernel/Net/NetworkingManagement.cpp +++ b/Kernel/Net/NetworkingManagement.cpp @@ -35,7 +35,7 @@ UNMAP_AFTER_INIT NetworkingManagement::NetworkingManagement() { } -NonnullLockRefPtr NetworkingManagement::loopback_adapter() const +NonnullRefPtr NetworkingManagement::loopback_adapter() const { return *m_loopback_adapter; } @@ -56,13 +56,13 @@ ErrorOr NetworkingManagement::try_for_each(Function(NetworkA }); } -LockRefPtr NetworkingManagement::from_ipv4_address(IPv4Address const& address) const +RefPtr NetworkingManagement::from_ipv4_address(IPv4Address const& address) const { if (address[0] == 0 && address[1] == 0 && address[2] == 0 && address[3] == 0) return m_loopback_adapter; if (address[0] == 127) return m_loopback_adapter; - return m_adapters.with([&](auto& adapters) -> LockRefPtr { + return m_adapters.with([&](auto& adapters) -> RefPtr { for (auto& adapter : adapters) { if (adapter->ipv4_address() == address || adapter->ipv4_broadcast() == address) return adapter; @@ -71,9 +71,9 @@ LockRefPtr NetworkingManagement::from_ipv4_address(IPv4Address c }); } -LockRefPtr NetworkingManagement::lookup_by_name(StringView name) const +RefPtr NetworkingManagement::lookup_by_name(StringView name) const { - return m_adapters.with([&](auto& adapters) -> LockRefPtr { + return m_adapters.with([&](auto& adapters) -> RefPtr { for (auto& adapter : adapters) { if (adapter->name() == name) return adapter; diff --git a/Kernel/Net/NetworkingManagement.h b/Kernel/Net/NetworkingManagement.h index 587773030a..ad18a7513e 100644 --- a/Kernel/Net/NetworkingManagement.h +++ b/Kernel/Net/NetworkingManagement.h @@ -8,9 +8,9 @@ #include #include +#include #include #include -#include #include #include #include @@ -33,16 +33,16 @@ public: void for_each(Function); ErrorOr try_for_each(Function(NetworkAdapter&)>); - LockRefPtr from_ipv4_address(IPv4Address const&) const; - LockRefPtr lookup_by_name(StringView) const; + RefPtr from_ipv4_address(IPv4Address const&) const; + RefPtr lookup_by_name(StringView) const; - NonnullLockRefPtr loopback_adapter() const; + NonnullRefPtr loopback_adapter() const; private: ErrorOr> determine_network_device(PCI::DeviceIdentifier const&) const; - SpinlockProtected>, LockRank::None> m_adapters {}; - LockRefPtr m_loopback_adapter; + SpinlockProtected>, LockRank::None> m_adapters {}; + RefPtr m_loopback_adapter; }; } diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp index 81ca0b000c..9ceaa71bc5 100644 --- a/Kernel/Net/Routing.cpp +++ b/Kernel/Net/Routing.cpp @@ -135,11 +135,11 @@ SpinlockProtected& routing_table() return *s_routing_table; } -ErrorOr update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr adapter, UpdateTable update) +ErrorOr update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr adapter, UpdateTable update) { dbgln_if(ROUTING_DEBUG, "update_routing_table {} {} {} {} {} {}", destination, gateway, netmask, flags, adapter, update == UpdateTable::Set ? "Set" : "Delete"); - auto route_entry = adopt_lock_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() }); + auto route_entry = adopt_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() }); if (!route_entry) return ENOMEM; @@ -178,7 +178,7 @@ static MACAddress multicast_ethernet_address(IPv4Address const& address) return MACAddress { 0x01, 0x00, 0x5e, (u8)(address[1] & 0x7f), address[2], address[3] }; } -RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr const through, AllowUsingGateway allow_using_gateway) +RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr const through, AllowUsingGateway allow_using_gateway) { auto matches = [&](auto& adapter) { if (!through) @@ -200,8 +200,8 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L auto target_addr = target.to_u32(); auto source_addr = source.to_u32(); - LockRefPtr local_adapter = nullptr; - LockRefPtr chosen_route = nullptr; + RefPtr local_adapter = nullptr; + RefPtr chosen_route = nullptr; NetworkingManagement::the().for_each([source_addr, &target_addr, &local_adapter, &matches, &through](NetworkAdapter& adapter) { auto adapter_addr = adapter.ipv4_address().to_u32(); @@ -263,7 +263,7 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L return { nullptr, {} }; } - LockRefPtr adapter = nullptr; + RefPtr adapter = nullptr; IPv4Address next_hop_ip; if (local_adapter) { diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h index 8732184901..2ad6a08d90 100644 --- a/Kernel/Net/Routing.h +++ b/Kernel/Net/Routing.h @@ -7,7 +7,7 @@ #pragma once #include -#include +#include #include #include #include @@ -15,7 +15,7 @@ namespace Kernel { struct Route final : public AtomicRefCounted { - Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullLockRefPtr adapter) + Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullRefPtr adapter) : destination(destination) , gateway(gateway) , netmask(netmask) @@ -38,14 +38,14 @@ struct Route final : public AtomicRefCounted { const IPv4Address gateway; const IPv4Address netmask; const u16 flags; - NonnullLockRefPtr adapter; + NonnullRefPtr const adapter; - IntrusiveListNode> route_list_node {}; + IntrusiveListNode> route_list_node {}; using RouteList = IntrusiveList<&Route::route_list_node>; }; struct RoutingDecision { - LockRefPtr adapter; + RefPtr adapter; MACAddress next_hop; bool is_zero() const; @@ -57,14 +57,14 @@ enum class UpdateTable { }; void update_arp_table(IPv4Address const&, MACAddress const&, UpdateTable update); -ErrorOr update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr const adapter, UpdateTable update); +ErrorOr update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr const adapter, UpdateTable update); enum class AllowUsingGateway { Yes, No, }; -RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes); +RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes); SpinlockProtected, LockRank::None>& arp_table(); SpinlockProtected& routing_table(); diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index 9ae1585b94..f5e8357a70 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -100,7 +100,9 @@ ErrorOr Socket::setsockopt(int level, int option, Userspace u auto device = NetworkingManagement::the().lookup_by_name(ifname->view()); if (!device) return ENODEV; - m_bound_interface = move(device); + m_bound_interface.with([&device](auto& bound_device) { + bound_device = move(device); + }); return {}; } case SO_DEBUG: @@ -169,31 +171,35 @@ ErrorOr Socket::getsockopt(OpenFileDescription&, int level, int option, Us case SO_ERROR: { if (size < sizeof(int)) return EINVAL; - int errno = 0; - if (auto const& error = so_error(); error.has_value()) - errno = error.value(); - TRY(copy_to_user(static_ptr_cast(value), &errno)); - size = sizeof(int); - TRY(copy_to_user(value_size, &size)); - clear_so_error(); - return {}; + return so_error().with([&size, value, value_size](auto& error) -> ErrorOr { + int errno = 0; + if (error.has_value()) + errno = error.value(); + TRY(copy_to_user(static_ptr_cast(value), &errno)); + size = sizeof(int); + TRY(copy_to_user(value_size, &size)); + error = {}; + return {}; + }); } case SO_BINDTODEVICE: if (size < IFNAMSIZ) return EINVAL; - if (m_bound_interface) { - auto name = m_bound_interface->name(); - auto length = name.length() + 1; - auto characters = name.characters_without_null_termination(); - TRY(copy_to_user(static_ptr_cast(value), characters, length)); - size = length; - return copy_to_user(value_size, &size); - } else { - size = 0; - TRY(copy_to_user(value_size, &size)); - // FIXME: This return value looks suspicious. - return EFAULT; - } + return m_bound_interface.with([&](auto& bound_device) -> ErrorOr { + if (bound_device) { + auto name = bound_device->name(); + auto length = name.length() + 1; + auto characters = name.characters_without_null_termination(); + TRY(copy_to_user(static_ptr_cast(value), characters, length)); + size = length; + return copy_to_user(value_size, &size); + } else { + size = 0; + TRY(copy_to_user(value_size, &size)); + // FIXME: This return value looks suspicious. + return EFAULT; + } + }); case SO_TIMESTAMP: if (size < sizeof(int)) return EINVAL; diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index 746274e0ff..ff071d2d83 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -7,9 +7,9 @@ #pragma once #include +#include #include #include -#include #include #include #include @@ -90,7 +90,7 @@ public: ProcessID acceptor_pid() const { return m_acceptor.pid; } UserID acceptor_uid() const { return m_acceptor.uid; } GroupID acceptor_gid() const { return m_acceptor.gid; } - LockRefPtr const bound_interface() const { return m_bound_interface; } + SpinlockProtected, LockRank::None> const& bound_interface() const { return m_bound_interface; } Mutex& mutex() { return m_mutex; } @@ -123,31 +123,29 @@ protected: Role m_role { Role::None }; - Optional const& so_error() const - { - VERIFY(m_mutex.is_exclusively_locked_by_current_thread()); - return m_so_error; - } + SpinlockProtected, LockRank::None>& so_error() { return m_so_error; } Error set_so_error(ErrnoCode error_code) { - MutexLocker locker(mutex()); - m_so_error = error_code; - + m_so_error.with([&error_code](auto& so_error) { + so_error = error_code; + }); return Error::from_errno(error_code); } Error set_so_error(Error error) { - MutexLocker locker(mutex()); - m_so_error = static_cast(error.code()); - + m_so_error.with([&error](auto& so_error) { + so_error = static_cast(error.code()); + }); return error; } void clear_so_error() { - m_so_error = {}; + m_so_error.with([](auto& so_error) { + so_error = {}; + }); } void set_origin(Process const&); @@ -173,13 +171,13 @@ private: bool m_shut_down_for_reading { false }; bool m_shut_down_for_writing { false }; - LockRefPtr m_bound_interface { nullptr }; + SpinlockProtected, LockRank::None> m_bound_interface; Time m_receive_timeout {}; Time m_send_timeout {}; int m_timestamp { 0 }; - Optional m_so_error; + SpinlockProtected, LockRank::None> m_so_error; Vector> m_pending; }; diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index f238c47a7d..c42e9c6b65 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -202,7 +202,8 @@ ErrorOr TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO ErrorOr TCPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length) { - RoutingDecision routing_decision = route_to(peer_address(), local_address(), bound_interface()); + auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr { return bound_device; }); + RoutingDecision routing_decision = route_to(peer_address(), local_address(), adapter); if (routing_decision.is_zero()) return set_so_error(EHOSTUNREACH); size_t mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket); @@ -220,7 +221,8 @@ ErrorOr TCPSocket::send_ack(bool allow_duplicate) ErrorOr TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* payload, size_t payload_size, RoutingDecision* user_routing_decision) { - RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), bound_interface()); + auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr { return bound_device; }); + RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), adapter); if (routing_decision.is_zero()) return set_so_error(EHOSTUNREACH); @@ -409,13 +411,14 @@ NetworkOrdered TCPSocket::compute_tcp_checksum(IPv4Address const& source, I ErrorOr TCPSocket::protocol_bind() { - if (has_specific_local_address() && !m_adapter) { - m_adapter = NetworkingManagement::the().from_ipv4_address(local_address()); - if (!m_adapter) - return set_so_error(EADDRNOTAVAIL); - } - - return {}; + return m_adapter.with([this](auto& adapter) -> ErrorOr { + if (has_specific_local_address() && !adapter) { + adapter = NetworkingManagement::the().from_ipv4_address(local_address()); + if (!adapter) + return set_so_error(EADDRNOTAVAIL); + } + return {}; + }); } ErrorOr TCPSocket::protocol_listen(bool did_allocate_port) @@ -598,7 +601,8 @@ void TCPSocket::retransmit_packets() return; } - auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); + auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr { return bound_device; }); + auto routing_decision = route_to(peer_address(), local_address(), adapter); if (routing_decision.is_zero()) return; diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index ae73102832..a949a0e92b 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -189,7 +189,7 @@ private: HashMap> m_pending_release_for_accept; Direction m_direction { Direction::Unspecified }; Error m_error { Error::None }; - LockRefPtr m_adapter; + SpinlockProtected, LockRank::None> m_adapter; u32 m_sequence_number { 0 }; u32 m_ack_number { 0 }; State m_state { State::Closed }; @@ -200,7 +200,7 @@ private: struct OutgoingPacket { u32 ack_number { 0 }; - LockRefPtr buffer; + RefPtr buffer; size_t ipv4_payload_offset; LockWeakPtr adapter; int tx_counter { 0 }; diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 88fc3388e2..2372422cc2 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -84,7 +84,8 @@ ErrorOr UDPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO ErrorOr UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length) { - auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); + auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr { return bound_device; }); + auto routing_decision = route_to(peer_address(), local_address(), adapter); if (routing_decision.is_zero()) return set_so_error(EHOSTUNREACH); auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset();