mirror of
				https://github.com/RGBCube/serenity
				synced 2025-10-31 17:02:45 +00:00 
			
		
		
		
	Kernel: Support non-blocking connect().
If connect() is called on a non-blocking socket, it will "fail" immediately with -EINPROGRESS. After that, you select() on the socket and wait for it to become writable.
This commit is contained in:
		
							parent
							
								
									7fcca0ce4b
								
							
						
					
					
						commit
						65d6318c33
					
				
					 11 changed files with 22 additions and 17 deletions
				
			
		|  | @ -66,7 +66,7 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) | ||||||
|     ASSERT_NOT_REACHED(); |     ASSERT_NOT_REACHED(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size) | KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock should_block) | ||||||
| { | { | ||||||
|     ASSERT(!m_bound); |     ASSERT(!m_bound); | ||||||
|     if (address_size != sizeof(sockaddr_in)) |     if (address_size != sizeof(sockaddr_in)) | ||||||
|  | @ -78,7 +78,7 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size) | ||||||
|     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr); |     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr); | ||||||
|     m_destination_port = ntohs(ia.sin_port); |     m_destination_port = ntohs(ia.sin_port); | ||||||
| 
 | 
 | ||||||
|     return protocol_connect(); |     return protocol_connect(should_block); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void IPv4Socket::attach_fd(SocketRole) | void IPv4Socket::attach_fd(SocketRole) | ||||||
|  | @ -110,7 +110,7 @@ ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size) | ||||||
| 
 | 
 | ||||||
| bool IPv4Socket::can_write(SocketRole) const | bool IPv4Socket::can_write(SocketRole) const | ||||||
| { | { | ||||||
|     return true; |     return is_connected(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int IPv4Socket::allocate_source_port_if_needed() | int IPv4Socket::allocate_source_port_if_needed() | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ public: | ||||||
|     static Lockable<HashTable<IPv4Socket*>>& all_sockets(); |     static Lockable<HashTable<IPv4Socket*>>& all_sockets(); | ||||||
| 
 | 
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; |     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||||
|     virtual KResult connect(const sockaddr*, socklen_t) override; |     virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||||
|     virtual bool get_address(sockaddr*, socklen_t*) override; |     virtual bool get_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual void attach_fd(SocketRole) override; |     virtual void attach_fd(SocketRole) override; | ||||||
|     virtual void detach_fd(SocketRole) override; |     virtual void detach_fd(SocketRole) override; | ||||||
|  | @ -49,7 +49,7 @@ protected: | ||||||
| 
 | 
 | ||||||
|     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } |     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } | ||||||
|     virtual int protocol_send(const void*, int) { return -ENOTIMPL; } |     virtual int protocol_send(const void*, int) { return -ENOTIMPL; } | ||||||
|     virtual KResult protocol_connect() { return KSuccess; } |     virtual KResult protocol_connect(ShouldBlock) { return KSuccess; } | ||||||
|     virtual int protocol_allocate_source_port() { return 0; } |     virtual int protocol_allocate_source_port() { return 0; } | ||||||
|     virtual bool protocol_is_disconnected() const { return false; } |     virtual bool protocol_is_disconnected() const { return false; } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -65,7 +65,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size) | ||||||
|     return KSuccess; |     return KSuccess; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size) | KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock) | ||||||
| { | { | ||||||
|     ASSERT(!m_bound); |     ASSERT(!m_bound); | ||||||
|     if (address_size != sizeof(sockaddr_un)) |     if (address_size != sizeof(sockaddr_un)) | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ public: | ||||||
|     virtual ~LocalSocket() override; |     virtual ~LocalSocket() override; | ||||||
| 
 | 
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; |     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||||
|     virtual KResult connect(const sockaddr*, socklen_t) override; |     virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||||
|     virtual bool get_address(sockaddr*, socklen_t*) override; |     virtual bool get_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual void attach_fd(SocketRole) override; |     virtual void attach_fd(SocketRole) override; | ||||||
|     virtual void detach_fd(SocketRole) override; |     virtual void detach_fd(SocketRole) override; | ||||||
|  |  | ||||||
|  | @ -342,6 +342,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | ||||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); |         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); |         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); | ||||||
|         socket->set_state(TCPSocket::State::Disconnecting); |         socket->set_state(TCPSocket::State::Disconnecting); | ||||||
|  |         socket->set_connected(false); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ | ||||||
| #include <Kernel/KResult.h> | #include <Kernel/KResult.h> | ||||||
| 
 | 
 | ||||||
| enum class SocketRole { None, Listener, Accepted, Connected, Connecting }; | enum class SocketRole { None, Listener, Accepted, Connected, Connecting }; | ||||||
|  | enum class ShouldBlock { No = 0, Yes = 1 }; | ||||||
| 
 | 
 | ||||||
| class Socket : public Retainable<Socket> { | class Socket : public Retainable<Socket> { | ||||||
| public: | public: | ||||||
|  | @ -25,7 +26,7 @@ public: | ||||||
|     KResult listen(int backlog); |     KResult listen(int backlog); | ||||||
| 
 | 
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) = 0; |     virtual KResult bind(const sockaddr*, socklen_t) = 0; | ||||||
|     virtual KResult connect(const sockaddr*, socklen_t) = 0; |     virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock) = 0; | ||||||
|     virtual bool get_address(sockaddr*, socklen_t*) = 0; |     virtual bool get_address(sockaddr*, socklen_t*) = 0; | ||||||
|     virtual bool is_local() const { return false; } |     virtual bool is_local() const { return false; } | ||||||
|     virtual bool is_ipv4() const { return false; } |     virtual bool is_ipv4() const { return false; } | ||||||
|  |  | ||||||
|  | @ -152,7 +152,7 @@ NetworkOrdered<word> TCPSocket::compute_tcp_checksum(const IPv4Address& source, | ||||||
|     return ~(checksum & 0xffff); |     return ~(checksum & 0xffff); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| KResult TCPSocket::protocol_connect() | KResult TCPSocket::protocol_connect(ShouldBlock should_block) | ||||||
| { | { | ||||||
|     auto* adapter = adapter_for_route_to(destination_address()); |     auto* adapter = adapter_for_route_to(destination_address()); | ||||||
|     if (!adapter) |     if (!adapter) | ||||||
|  | @ -166,11 +166,14 @@ KResult TCPSocket::protocol_connect() | ||||||
|     send_tcp_packet(TCPFlags::SYN); |     send_tcp_packet(TCPFlags::SYN); | ||||||
|     m_state = State::Connecting; |     m_state = State::Connecting; | ||||||
| 
 | 
 | ||||||
|     current->set_blocked_socket(this); |     if (should_block == ShouldBlock::Yes) { | ||||||
|     current->block(Thread::BlockedConnect); |         current->set_blocked_socket(this); | ||||||
|  |         current->block(Thread::BlockedConnect); | ||||||
|  |         ASSERT(is_connected()); | ||||||
|  |         return KSuccess; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     ASSERT(is_connected()); |     return KResult(-EINPROGRESS); | ||||||
|     return KSuccess; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int TCPSocket::protocol_allocate_source_port() | int TCPSocket::protocol_allocate_source_port() | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ private: | ||||||
| 
 | 
 | ||||||
|     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; |     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; | ||||||
|     virtual int protocol_send(const void*, int) override; |     virtual int protocol_send(const void*, int) override; | ||||||
|     virtual KResult protocol_connect() override; |     virtual KResult protocol_connect(ShouldBlock) override; | ||||||
|     virtual int protocol_allocate_source_port() override; |     virtual int protocol_allocate_source_port() override; | ||||||
|     virtual bool protocol_is_disconnected() const override; |     virtual bool protocol_is_disconnected() const override; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -81,7 +81,7 @@ int UDPSocket::protocol_send(const void* data, int data_length) | ||||||
|     return data_length; |     return data_length; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| KResult UDPSocket::protocol_connect() | KResult UDPSocket::protocol_connect(ShouldBlock) | ||||||
| { | { | ||||||
|     return KSuccess; |     return KSuccess; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -17,7 +17,7 @@ private: | ||||||
| 
 | 
 | ||||||
|     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; |     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; | ||||||
|     virtual int protocol_send(const void*, int) override; |     virtual int protocol_send(const void*, int) override; | ||||||
|     virtual KResult protocol_connect() override; |     virtual KResult protocol_connect(ShouldBlock) override; | ||||||
|     virtual int protocol_allocate_source_port() override; |     virtual int protocol_allocate_source_port() override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2038,7 +2038,7 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_ | ||||||
|         return -EISCONN; |         return -EISCONN; | ||||||
|     auto& socket = *descriptor->socket(); |     auto& socket = *descriptor->socket(); | ||||||
|     descriptor->set_socket_role(SocketRole::Connecting); |     descriptor->set_socket_role(SocketRole::Connecting); | ||||||
|     auto result = socket.connect(address, address_size); |     auto result = socket.connect(address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No); | ||||||
|     if (result.is_error()) { |     if (result.is_error()) { | ||||||
|         descriptor->set_socket_role(SocketRole::None); |         descriptor->set_socket_role(SocketRole::None); | ||||||
|         return result; |         return result; | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Andreas Kling
						Andreas Kling