From 54ceabd48d49fdc44a70a2ff66f98d9bb1ff2ee2 Mon Sep 17 00:00:00 2001 From: Conrad Pankoff Date: Fri, 9 Aug 2019 12:34:32 +1000 Subject: [PATCH] Kernel: Use WeakPtr instead of NetworkAdapter* in net code --- Kernel/Net/IPv4Socket.cpp | 2 +- Kernel/Net/NetworkAdapter.cpp | 4 ++-- Kernel/Net/NetworkAdapter.h | 6 ++++-- Kernel/Net/NetworkTask.cpp | 10 +++++----- Kernel/Net/Routing.cpp | 4 ++-- Kernel/Net/Routing.h | 2 +- Kernel/Net/TCPSocket.cpp | 11 ++++++++++- Kernel/Net/TCPSocket.h | 3 ++- Kernel/Net/UDPSocket.cpp | 2 +- 9 files changed, 28 insertions(+), 16 deletions(-) diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index a97005bb11..f33e90224c 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -169,7 +169,7 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt m_peer_port = ntohs(ia.sin_port); } - auto* adapter = adapter_for_route_to(m_peer_address); + auto adapter = adapter_for_route_to(m_peer_address); if (!adapter) return -EHOSTUNREACH; diff --git a/Kernel/Net/NetworkAdapter.cpp b/Kernel/Net/NetworkAdapter.cpp index f77804a914..14ed365f1a 100644 --- a/Kernel/Net/NetworkAdapter.cpp +++ b/Kernel/Net/NetworkAdapter.cpp @@ -22,12 +22,12 @@ void NetworkAdapter::for_each(Function callback) callback(*it); } -NetworkAdapter* NetworkAdapter::from_ipv4_address(const IPv4Address& address) +WeakPtr NetworkAdapter::from_ipv4_address(const IPv4Address& address) { LOCKER(all_adapters().lock()); for (auto* adapter : all_adapters().resource()) { if (adapter->ipv4_address() == address) - return adapter; + return adapter->make_weak_ptr(); } return nullptr; } diff --git a/Kernel/Net/NetworkAdapter.h b/Kernel/Net/NetworkAdapter.h index c1d99b3f9e..4860e8b647 100644 --- a/Kernel/Net/NetworkAdapter.h +++ b/Kernel/Net/NetworkAdapter.h @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include #include @@ -12,10 +14,10 @@ class NetworkAdapter; -class NetworkAdapter { +class NetworkAdapter : public Weakable { public: static void for_each(Function); - static NetworkAdapter* from_ipv4_address(const IPv4Address&); + static WeakPtr from_ipv4_address(const IPv4Address&); virtual ~NetworkAdapter(); virtual const char* class_name() const = 0; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index c57562b1aa..38bd3b93d4 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -38,7 +38,7 @@ void NetworkTask_main() { LoopbackAdapter::the(); - auto* adapter = E1000NetworkAdapter::the(); + auto adapter = E1000NetworkAdapter::the(); if (!adapter) dbgprintf("E1000 network card not found!\n"); @@ -150,7 +150,7 @@ void handle_arp(const EthernetFrameHeader& eth, int frame_size) if (packet.operation() == ARPOperation::Request) { // Who has this IP address? - if (auto* adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) { + if (auto adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) { // We do! kprintf("handle_arp: Responding to ARP request for my IPv4 address (%s)\n", adapter->ipv4_address().to_string().characters()); @@ -231,7 +231,7 @@ void handle_icmp(const EthernetFrameHeader& eth, int frame_size) } } - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) return; @@ -260,7 +260,7 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size) (void)frame_size; auto& ipv4_packet = *static_cast(eth.payload()); - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) { kprintf("handle_udp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters()); return; @@ -292,7 +292,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) (void)frame_size; auto& ipv4_packet = *static_cast(eth.payload()); - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) { kprintf("handle_tcp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters()); return; diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp index b2c3d8a782..c9bf54f67e 100644 --- a/Kernel/Net/Routing.cpp +++ b/Kernel/Net/Routing.cpp @@ -1,10 +1,10 @@ #include #include -NetworkAdapter* adapter_for_route_to(const IPv4Address& ipv4_address) +WeakPtr adapter_for_route_to(const IPv4Address& ipv4_address) { // FIXME: Have an actual routing table. if (ipv4_address == IPv4Address(127, 0, 0, 1)) - return &LoopbackAdapter::the(); + return LoopbackAdapter::the().make_weak_ptr(); return NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); } diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h index 48143520eb..0feed2cf64 100644 --- a/Kernel/Net/Routing.h +++ b/Kernel/Net/Routing.h @@ -2,4 +2,4 @@ #include -NetworkAdapter* adapter_for_route_to(const IPv4Address&); +WeakPtr adapter_for_route_to(const IPv4Address&); diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index b6c1b726af..df998389f4 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -80,7 +80,16 @@ int TCPSocket::protocol_send(const void* data, int data_length) void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) { - ASSERT(m_adapter); + if (!m_adapter) { + if (has_specific_local_address()) { + m_adapter = NetworkAdapter::from_ipv4_address(local_address()); + } else { + m_adapter = adapter_for_route_to(peer_address()); + if (m_adapter) + set_local_address(m_adapter->ipv4_address()); + } + } + ASSERT(!!m_adapter); auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index c6450e850b..48a47807b3 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include class TCPSocket final : public IPv4Socket { @@ -86,7 +87,7 @@ private: virtual KResult protocol_bind() override; virtual KResult protocol_listen() override; - NetworkAdapter* m_adapter { nullptr }; + WeakPtr m_adapter; u32 m_sequence_number { 0 }; u32 m_ack_number { 0 }; State m_state { State::Closed }; diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index c3acf1261f..7ceebbb1d9 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -56,7 +56,7 @@ int UDPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size int UDPSocket::protocol_send(const void* data, int data_length) { - auto* adapter = adapter_for_route_to(peer_address()); + auto adapter = adapter_for_route_to(peer_address()); if (!adapter) return -EHOSTUNREACH; auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length);