diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index d5038d0a86..1cabc22551 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -121,8 +121,9 @@ KResult IPv4Socket::bind(Userspace user_address, socklen_t addr KResult IPv4Socket::listen(size_t backlog) { Locker locker(lock()); - if (auto result = allocate_local_port_if_needed(); result.is_error() && result.error() != -ENOPROTOOPT) - return result.error(); + auto result = allocate_local_port_if_needed(); + if (result.error_or_port.is_error() && result.error_or_port.error() != -ENOPROTOOPT) + return result.error_or_port.error(); set_backlog(backlog); m_role = Role::Listener; @@ -130,7 +131,7 @@ KResult IPv4Socket::listen(size_t backlog) dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog); - return protocol_listen(); + return protocol_listen(result.did_allocate); } KResult IPv4Socket::connect(FileDescription& description, Userspace address, socklen_t address_size, ShouldBlock should_block) @@ -172,16 +173,16 @@ bool IPv4Socket::can_write(const FileDescription&, size_t) const return is_connected(); } -KResultOr IPv4Socket::allocate_local_port_if_needed() +PortAllocationResult IPv4Socket::allocate_local_port_if_needed() { Locker locker(lock()); if (m_local_port) - return m_local_port; + return { m_local_port, false }; auto port_or_error = protocol_allocate_local_port(); if (port_or_error.is_error()) - return port_or_error.error(); + return { port_or_error.error(), false }; m_local_port = port_or_error.value(); - return port_or_error.value(); + return { m_local_port, true }; } KResultOr IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& data, size_t data_length, [[maybe_unused]] int flags, Userspace addr, socklen_t addr_length) @@ -212,8 +213,8 @@ KResultOr IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& if (m_local_address.to_u32() == 0) m_local_address = routing_decision.adapter->ipv4_address(); - if (auto result = allocate_local_port_if_needed(); result.is_error() && result.error() != -ENOPROTOOPT) - return result.error(); + if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error() && result.error_or_port.error() != -ENOPROTOOPT) + return result.error_or_port.error(); dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port); diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index 2ec52a1860..f089cb409f 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -21,6 +21,11 @@ class NetworkAdapter; class TCPPacket; class TCPSocket; +struct PortAllocationResult { + KResultOr error_or_port; + bool did_allocate; +}; + class IPv4Socket : public Socket { public: static KResultOr> create(int type, int protocol); @@ -72,10 +77,10 @@ protected: IPv4Socket(int type, int protocol); virtual const char* class_name() const override { return "IPv4Socket"; } - KResultOr allocate_local_port_if_needed(); + PortAllocationResult allocate_local_port_if_needed(); virtual KResult protocol_bind() { return KSuccess; } - virtual KResult protocol_listen() { return KSuccess; } + virtual KResult protocol_listen([[maybe_unused]] bool did_allocate_port) { return KSuccess; } virtual KResultOr protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; } virtual KResultOr protocol_send(const UserOrKernelBuffer&, size_t) { return ENOTIMPL; } virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index ae3767f97e..6d002d1539 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -365,12 +365,15 @@ KResult TCPSocket::protocol_bind() return KSuccess; } -KResult TCPSocket::protocol_listen() +KResult TCPSocket::protocol_listen(bool did_allocate_port) { - Locker locker(sockets_by_tuple().lock()); - if (sockets_by_tuple().resource().contains(tuple())) - return EADDRINUSE; - sockets_by_tuple().resource().set(tuple(), this); + if (!did_allocate_port) { + Locker socket_locker(sockets_by_tuple().lock()); + if (sockets_by_tuple().resource().contains(tuple())) + return EADDRINUSE; + sockets_by_tuple().resource().set(tuple(), this); + } + set_direction(Direction::Passive); set_state(State::Listen); set_setup_state(SetupState::Completed); @@ -387,8 +390,8 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh if (!has_specific_local_address()) set_local_address(routing_decision.adapter->ipv4_address()); - if (auto result = allocate_local_port_if_needed(); result.is_error()) - return result.error(); + if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error()) + return result.error_or_port.error(); m_sequence_number = get_good_random(); m_ack_number = 0; diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index c3ac398301..024543139f 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -176,7 +176,7 @@ private: virtual KResultOr protocol_allocate_local_port() override; virtual bool protocol_is_disconnected() const override; virtual KResult protocol_bind() override; - virtual KResult protocol_listen() override; + virtual KResult protocol_listen(bool did_allocate_port) override; void enqueue_for_retransmit(); void dequeue_for_retransmit();