diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 019417f813..9c80ac6ff0 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -95,8 +95,23 @@ void IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size) *address_size = sizeof(sockaddr_in); } +ErrorOr IPv4Socket::ensure_bound() +{ + dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket::ensure_bound() m_bound {}", m_bound); + if (m_bound) + return {}; + + auto result = protocol_bind(); + if (!result.is_error()) + m_bound = true; + return result; +} + ErrorOr IPv4Socket::bind(Credentials const& credentials, Userspace user_address, socklen_t address_size) { + if (m_bound) + return set_so_error(EINVAL); + VERIFY(setup_state() == SetupState::Unstarted); if (address_size != sizeof(sockaddr_in)) return set_so_error(EINVAL); @@ -120,23 +135,20 @@ ErrorOr IPv4Socket::bind(Credentials const& credentials, Userspace IPv4Socket::listen(size_t backlog) { MutexLocker locker(mutex()); - auto result = allocate_local_port_if_needed(); - if (result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT) - return result.error_or_port.release_error(); - + TRY(ensure_bound()); set_backlog(backlog); set_role(Role::Listener); evaluate_block_conditions(); dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog); - return protocol_listen(result.did_allocate); + return protocol_listen(); } ErrorOr IPv4Socket::connect(Credentials const&, OpenFileDescription& description, Userspace address, socklen_t address_size) @@ -176,18 +188,6 @@ bool IPv4Socket::can_write(OpenFileDescription const&, u64) const return true; } -PortAllocationResult IPv4Socket::allocate_local_port_if_needed() -{ - MutexLocker locker(mutex()); - if (m_local_port) - return { m_local_port, false }; - auto port_or_error = protocol_allocate_local_port(); - if (port_or_error.is_error()) - return { port_or_error.release_error(), false }; - m_local_port = port_or_error.release_value(); - return { m_local_port, true }; -} - ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer const& data, size_t data_length, [[maybe_unused]] int flags, Userspace addr, socklen_t addr_length) { MutexLocker locker(mutex()); @@ -220,8 +220,7 @@ ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons if (m_local_address.to_u32() == 0) m_local_address = routing_decision.adapter->ipv4_address(); - if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT) - return result.error_or_port.release_error(); + TRY(ensure_bound()); dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port); diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index aa0b5878a2..e9ef4eef59 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -21,11 +21,6 @@ class NetworkAdapter; class TCPPacket; class TCPSocket; -struct PortAllocationResult { - ErrorOr error_or_port; - bool did_allocate; -}; - class IPv4Socket : public Socket { public: static ErrorOr> create(int type, int protocol); @@ -76,14 +71,14 @@ protected: IPv4Socket(int type, int protocol, NonnullOwnPtr receive_buffer, OwnPtr optional_scratch_buffer); virtual StringView class_name() const override { return "IPv4Socket"sv; } - PortAllocationResult allocate_local_port_if_needed(); + void set_bound(bool bound) { m_bound = bound; } + ErrorOr ensure_bound(); virtual ErrorOr protocol_bind() { return {}; } - virtual ErrorOr protocol_listen([[maybe_unused]] bool did_allocate_port) { return {}; } + virtual ErrorOr protocol_listen() { return {}; } virtual ErrorOr protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; } virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) { return ENOTIMPL; } virtual ErrorOr protocol_connect(OpenFileDescription&) { return {}; } - virtual ErrorOr protocol_allocate_local_port() { return ENOPROTOOPT; } virtual ErrorOr protocol_size(ReadonlyBytes /* raw_ipv4_packet */) { return ENOTIMPL; } virtual bool protocol_is_disconnected() const { return false; } @@ -108,6 +103,7 @@ private: Vector m_multicast_memberships; bool m_multicast_loop { true }; + bool m_bound { false }; struct ReceivedPacket { IPv4Address peer_address; diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 2027f969ad..b0dc783807 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -137,6 +137,7 @@ ErrorOr> TCPSocket::try_create_client(IPv4Address const client->set_local_port(new_local_port); client->set_peer_address(new_peer_address); client->set_peer_port(new_peer_port); + client->set_bound(true); client->set_direction(Direction::Incoming); client->set_originator(*this); @@ -414,19 +415,46 @@ NetworkOrdered TCPSocket::compute_tcp_checksum(IPv4Address const& source, I ErrorOr TCPSocket::protocol_bind() { - return m_adapter.with([this](auto& adapter) -> ErrorOr { + dbgln_if(TCP_SOCKET_DEBUG, "TCPSocket::protocol_bind(), local_port() is {}", local_port()); + // Check that we do have the address we're trying to bind to. + TRY(m_adapter.with([this](auto& adapter) -> ErrorOr { if (has_specific_local_address() && !adapter) { adapter = NetworkingManagement::the().from_ipv4_address(local_address()); if (!adapter) return set_so_error(EADDRNOTAVAIL); } return {}; - }); -} + })); -ErrorOr TCPSocket::protocol_listen(bool did_allocate_port) -{ - if (!did_allocate_port) { + if (local_port() == 0) { + // Allocate an unused ephemeral port. + constexpr u16 first_ephemeral_port = 32768; + constexpr u16 last_ephemeral_port = 60999; + constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; + u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; + + return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr { + u16 port = first_scan_port; + while (true) { + IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); + + auto it = table.find(proposed_tuple); + if (it == table.end()) { + set_local_port(port); + table.set(proposed_tuple, this); + dbgln_if(TCP_SOCKET_DEBUG, "...allocated port {}, tuple {}", port, proposed_tuple.to_string()); + return {}; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; + } + return set_so_error(EADDRINUSE); + }); + } else { + // Verify that the user-supplied port is not already used by someone else. bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool { if (table.contains(tuple())) return false; @@ -435,8 +463,12 @@ ErrorOr TCPSocket::protocol_listen(bool did_allocate_port) }); if (!ok) return set_so_error(EADDRINUSE); + return {}; } +} +ErrorOr TCPSocket::protocol_listen() +{ set_direction(Direction::Passive); set_state(State::Listen); set_setup_state(SetupState::Completed); @@ -453,8 +485,7 @@ ErrorOr TCPSocket::protocol_connect(OpenFileDescription& description) if (!has_specific_local_address()) set_local_address(routing_decision.adapter->ipv4_address()); - if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error()) - return result.error_or_port.release_error(); + TRY(ensure_bound()); m_sequence_number = get_good_random(); m_ack_number = 0; @@ -487,33 +518,6 @@ ErrorOr TCPSocket::protocol_connect(OpenFileDescription& description) return set_so_error(EINPROGRESS); } -ErrorOr TCPSocket::protocol_allocate_local_port() -{ - constexpr u16 first_ephemeral_port = 32768; - constexpr u16 last_ephemeral_port = 60999; - constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; - u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - - return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr { - for (u16 port = first_scan_port;;) { - IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); - - auto it = table.find(proposed_tuple); - if (it == table.end()) { - set_local_port(port); - table.set(proposed_tuple, this); - return port; - } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return set_so_error(EADDRINUSE); - }); -} - bool TCPSocket::protocol_is_disconnected() const { switch (m_state) { diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index a2d58df05e..d018db3d96 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -176,11 +176,10 @@ private: virtual ErrorOr protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) override; virtual ErrorOr protocol_connect(OpenFileDescription&) override; - virtual ErrorOr protocol_allocate_local_port() override; virtual ErrorOr protocol_size(ReadonlyBytes raw_ipv4_packet) override; virtual bool protocol_is_disconnected() const override; virtual ErrorOr protocol_bind() override; - virtual ErrorOr protocol_listen(bool did_allocate_port) override; + virtual ErrorOr protocol_listen() override; void enqueue_for_retransmit(); void dequeue_for_retransmit(); diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index a9c1c46eef..5cb5814cce 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -108,44 +108,47 @@ ErrorOr UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t ErrorOr UDPSocket::protocol_connect(OpenFileDescription&) { + TRY(ensure_bound()); set_role(Role::Connected); set_connected(true); return {}; } -ErrorOr UDPSocket::protocol_allocate_local_port() -{ - constexpr u16 first_ephemeral_port = 32768; - constexpr u16 last_ephemeral_port = 60999; - constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; - u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; - - return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { - for (u16 port = first_scan_port;;) { - auto it = table.find(port); - if (it == table.end()) { - set_local_port(port); - table.set(port, this); - return port; - } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return set_so_error(EADDRINUSE); - }); -} - ErrorOr UDPSocket::protocol_bind() { - return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { - if (table.contains(local_port())) + if (local_port() == 0) { + // Allocate an unused ephemeral port. + constexpr u16 first_ephemeral_port = 32768; + constexpr u16 last_ephemeral_port = 60999; + constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; + u16 first_scan_port = first_ephemeral_port + get_good_random() % ephemeral_port_range_size; + + return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { + u16 port = first_scan_port; + while (true) { + auto it = table.find(port); + if (it == table.end()) { + set_local_port(port); + table.set(port, this); + return {}; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; + } return set_so_error(EADDRINUSE); - table.set(local_port(), this); - return {}; - }); + }); + } else { + // Verify that the user-supplied port is not already used by someone else. + return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr { + if (table.contains(local_port())) + return set_so_error(EADDRINUSE); + table.set(local_port(), this); + return {}; + }); + } } } diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index d904f97e24..32b5b05334 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -30,7 +30,6 @@ private: virtual ErrorOr protocol_send(UserOrKernelBuffer const&, size_t) override; virtual ErrorOr protocol_size(ReadonlyBytes raw_ipv4_packet) override; virtual ErrorOr protocol_connect(OpenFileDescription&) override; - virtual ErrorOr protocol_allocate_local_port() override; virtual ErrorOr protocol_bind() override; };