mirror of
				https://github.com/RGBCube/serenity
				synced 2025-10-31 16:02:45 +00:00 
			
		
		
		
	Kernel: Refactor TCP/IP stack
This has several significant changes to the networking stack. * Significant refactoring of the TCP state machine. Right now it's probably more fragile than it used to be, but handles quite a lot more of the handshake process. * `TCPSocket` holds a `NetworkAdapter*`, assigned during `connect()` or `bind()`, whichever comes first. * `listen()` is now virtual in `Socket` and intended to be implemented in its child classes * `listen()` no longer works without `bind()` - this is a bit of a regression, but listening sockets didn't work at all before, so it's not possible to observe the regression. * A file is exposed at `/proc/net_tcp`, which is a JSON document listing the current TCP sockets with a bit of metadata. * There's an `ETHERNET_VERY_DEBUG` flag for dumping packet's content out to `kprintf`. It is, indeed, _very debug_.
This commit is contained in:
		
							parent
							
								
									c973a51a23
								
							
						
					
					
						commit
						73c998dbfc
					
				
					 12 changed files with 446 additions and 84 deletions
				
			
		|  | @ -67,6 +67,7 @@ public: | |||
|     } | ||||
| 
 | ||||
|     in_addr_t to_in_addr_t() const { return m_data_as_u32; } | ||||
|     u32 to_u32() const { return m_data_as_u32; } | ||||
| 
 | ||||
|     bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; } | ||||
|     bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; } | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
| #include <Kernel/FileSystem/VirtualFileSystem.h> | ||||
| #include <Kernel/KParams.h> | ||||
| #include <Kernel/Net/NetworkAdapter.h> | ||||
| #include <Kernel/Net/TCPSocket.h> | ||||
| #include <Kernel/PCI.h> | ||||
| #include <Kernel/VM/MemoryManager.h> | ||||
| #include <Kernel/kmalloc.h> | ||||
|  | @ -46,6 +47,7 @@ enum ProcFileType { | |||
|     FI_Root_uptime, | ||||
|     FI_Root_cmdline, | ||||
|     FI_Root_netadapters, | ||||
|     FI_Root_net_tcp, | ||||
|     FI_Root_self, // symlink
 | ||||
|     FI_Root_sys,  // directory
 | ||||
|     __FI_Root_End, | ||||
|  | @ -278,6 +280,23 @@ Optional<KBuffer> procfs$netadapters(InodeIdentifier) | |||
|     return builder.to_byte_buffer(); | ||||
| } | ||||
| 
 | ||||
| Optional<KBuffer> procfs$net_tcp(InodeIdentifier) | ||||
| { | ||||
|     JsonArray json; | ||||
|     TCPSocket::for_each([&json](auto& socket) { | ||||
|         JsonObject obj; | ||||
|         obj.set("local_address", socket->local_address().to_string()); | ||||
|         obj.set("local_port", socket->local_port()); | ||||
|         obj.set("peer_address", socket->peer_address().to_string()); | ||||
|         obj.set("peer_port", socket->peer_port()); | ||||
|         obj.set("state", TCPSocket::to_string(socket->state())); | ||||
|         obj.set("ack_number", socket->ack_number()); | ||||
|         obj.set("sequence_number", socket->sequence_number()); | ||||
|         json.append(obj); | ||||
|     }); | ||||
|     return json.serialized().to_byte_buffer(); | ||||
| } | ||||
| 
 | ||||
| Optional<KBuffer> procfs$pid_vmo(InodeIdentifier identifier) | ||||
| { | ||||
|     auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier)); | ||||
|  | @ -1077,6 +1096,7 @@ ProcFS::ProcFS() | |||
|     m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime }; | ||||
|     m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline }; | ||||
|     m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters }; | ||||
|     m_entries[FI_Root_net_tcp] = { "net_tcp", FI_Root_net_tcp, procfs$net_tcp }; | ||||
|     m_entries[FI_Root_sys] = { "sys", FI_Root_sys }; | ||||
| 
 | ||||
|     m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm }; | ||||
|  |  | |||
|  | @ -89,6 +89,22 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) | |||
|     return protocol_bind(); | ||||
| } | ||||
| 
 | ||||
| KResult IPv4Socket::listen(int backlog) | ||||
| { | ||||
|     int rc = allocate_local_port_if_needed(); | ||||
|     if (rc < 0) | ||||
|         return KResult(-EADDRINUSE); | ||||
| 
 | ||||
|     if (m_local_address.to_u32() == 0) | ||||
|         return KResult(-EADDRINUSE); | ||||
| 
 | ||||
|     set_backlog(backlog); | ||||
| 
 | ||||
|     kprintf("IPv4Socket{%p} listening with backlog=%d\n", this, backlog); | ||||
| 
 | ||||
|     return protocol_listen(); | ||||
| } | ||||
| 
 | ||||
| KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block) | ||||
| { | ||||
|     if (address_size != sizeof(sockaddr_in)) | ||||
|  | @ -157,6 +173,9 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt | |||
|     if (!adapter) | ||||
|         return -EHOSTUNREACH; | ||||
| 
 | ||||
|     if (m_local_address.to_u32() == 0) | ||||
|         m_local_address = adapter->ipv4_address(); | ||||
| 
 | ||||
|     int rc = allocate_local_port_if_needed(); | ||||
|     if (rc < 0) | ||||
|         return rc; | ||||
|  |  | |||
|  | @ -2,10 +2,11 @@ | |||
| 
 | ||||
| #include <AK/HashMap.h> | ||||
| #include <AK/SinglyLinkedList.h> | ||||
| #include <Kernel/KBuffer.h> | ||||
| #include <Kernel/DoubleBuffer.h> | ||||
| #include <Kernel/KBuffer.h> | ||||
| #include <Kernel/Lock.h> | ||||
| #include <Kernel/Net/IPv4.h> | ||||
| #include <Kernel/Net/IPv4SocketTuple.h> | ||||
| #include <Kernel/Net/Socket.h> | ||||
| 
 | ||||
| class IPv4SocketHandle; | ||||
|  | @ -23,6 +24,7 @@ public: | |||
| 
 | ||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||
|     virtual KResult listen(int) override; | ||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) override; | ||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) override; | ||||
|     virtual void attach(FileDescription&) override; | ||||
|  | @ -34,7 +36,7 @@ public: | |||
| 
 | ||||
|     void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&); | ||||
| 
 | ||||
|     const IPv4Address& local_address() const; | ||||
|     const IPv4Address& local_address() const { return m_local_address; } | ||||
|     u16 local_port() const { return m_local_port; } | ||||
|     void set_local_port(u16 port) { m_local_port = port; } | ||||
| 
 | ||||
|  | @ -42,6 +44,8 @@ public: | |||
|     u16 peer_port() const { return m_peer_port; } | ||||
|     void set_peer_port(u16 port) { m_peer_port = port; } | ||||
| 
 | ||||
|     IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); } | ||||
| 
 | ||||
| protected: | ||||
|     IPv4Socket(int type, int protocol); | ||||
|     virtual const char* class_name() const override { return "IPv4Socket"; } | ||||
|  | @ -49,12 +53,16 @@ protected: | |||
|     int allocate_local_port_if_needed(); | ||||
| 
 | ||||
|     virtual KResult protocol_bind() { return KSuccess; } | ||||
|     virtual KResult protocol_listen() { return KSuccess; } | ||||
|     virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; } | ||||
|     virtual int protocol_send(const void*, int) { return -ENOTIMPL; } | ||||
|     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; } | ||||
|     virtual int protocol_allocate_local_port() { return 0; } | ||||
|     virtual bool protocol_is_disconnected() const { return false; } | ||||
| 
 | ||||
|     void set_local_address(IPv4Address address) { m_local_address = address; } | ||||
|     void set_peer_address(IPv4Address address) { m_peer_address = address; } | ||||
| 
 | ||||
| private: | ||||
|     virtual bool is_ipv4() const override { return true; } | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										63
									
								
								Kernel/Net/IPv4SocketTuple.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								Kernel/Net/IPv4SocketTuple.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,63 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <AK/HashMap.h> | ||||
| #include <AK/SinglyLinkedList.h> | ||||
| #include <Kernel/DoubleBuffer.h> | ||||
| #include <Kernel/KBuffer.h> | ||||
| #include <Kernel/Lock.h> | ||||
| #include <Kernel/Net/IPv4.h> | ||||
| #include <Kernel/Net/Socket.h> | ||||
| 
 | ||||
| class IPv4SocketTuple { | ||||
| public: | ||||
|     IPv4SocketTuple(IPv4Address local_address, u16 local_port, IPv4Address peer_address, u16 peer_port) | ||||
|         : m_local_address(local_address) | ||||
|         , m_local_port(local_port) | ||||
|         , m_peer_address(peer_address) | ||||
|         , m_peer_port(peer_port) {}; | ||||
| 
 | ||||
|     IPv4Address local_address() const { return m_local_address; }; | ||||
|     u16 local_port() const { return m_local_port; }; | ||||
|     IPv4Address peer_address() const { return m_peer_address; }; | ||||
|     u16 peer_port() const { return m_peer_port; }; | ||||
| 
 | ||||
|     bool operator==(const IPv4SocketTuple other) const | ||||
|     { | ||||
|         return other.local_address() == m_local_address && other.local_port() == m_local_port && other.peer_address() == m_peer_address && other.peer_port() == m_peer_port; | ||||
|     }; | ||||
| 
 | ||||
|     String to_string() const | ||||
|     { | ||||
|         return String::format( | ||||
|             "%s:%d -> %s:%d", | ||||
|             m_local_address.to_string().characters(), | ||||
|             m_local_port, | ||||
|             m_peer_address.to_string().characters(), | ||||
|             m_peer_port); | ||||
|     } | ||||
| 
 | ||||
| private: | ||||
|     IPv4Address m_local_address; | ||||
|     u16 m_local_port { 0 }; | ||||
|     IPv4Address m_peer_address; | ||||
|     u16 m_peer_port { 0 }; | ||||
| }; | ||||
| 
 | ||||
| namespace AK { | ||||
| 
 | ||||
| template<> | ||||
| struct Traits<IPv4SocketTuple> : public GenericTraits<IPv4SocketTuple> { | ||||
|     static unsigned hash(const IPv4SocketTuple& tuple) | ||||
|     { | ||||
|         auto h1 = pair_int_hash(tuple.local_address().to_u32(), tuple.local_port()); | ||||
|         auto h2 = pair_int_hash(tuple.peer_address().to_u32(), tuple.peer_port()); | ||||
|         return pair_int_hash(h1, h2); | ||||
|     } | ||||
| 
 | ||||
|     static void dump(const IPv4SocketTuple& tuple) | ||||
|     { | ||||
|         kprintf("%s", tuple.to_string().characters()); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| } | ||||
|  | @ -114,6 +114,16 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre | |||
|     return KSuccess; | ||||
| } | ||||
| 
 | ||||
| KResult LocalSocket::listen(int backlog) | ||||
| { | ||||
|     LOCKER(lock()); | ||||
|     if (type() != SOCK_STREAM) | ||||
|         return KResult(-EOPNOTSUPP); | ||||
|     set_backlog(backlog); | ||||
|     kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog); | ||||
|     return KSuccess; | ||||
| } | ||||
| 
 | ||||
| void LocalSocket::attach(FileDescription& description) | ||||
| { | ||||
|     switch (description.socket_role()) { | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ public: | |||
|     // ^Socket
 | ||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||
|     virtual KResult listen(int) override; | ||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) override; | ||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) override; | ||||
|     virtual void attach(FileDescription&) override; | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
| #include <Kernel/Process.h> | ||||
| 
 | ||||
| //#define ETHERNET_DEBUG
 | ||||
| //#define ETHERNET_VERY_DEBUG
 | ||||
| //#define IPV4_DEBUG
 | ||||
| //#define ICMP_DEBUG
 | ||||
| //#define UDP_DEBUG
 | ||||
|  | @ -84,6 +85,28 @@ void NetworkTask_main() | |||
|             packet.size()); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef ETHERNET_VERY_DEBUG | ||||
|         u8* data = packet.data(); | ||||
| 
 | ||||
|         for (size_t i = 0; i < packet.size(); i++) { | ||||
|             kprintf("%b", data[i]); | ||||
| 
 | ||||
|             switch (i % 16) { | ||||
|             case 7: | ||||
|                 kprintf("  "); | ||||
|                 break; | ||||
|             case 15: | ||||
|                 kprintf("\n"); | ||||
|                 break; | ||||
|             default: | ||||
|                 kprintf(" "); | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         kprintf("\n"); | ||||
| #endif | ||||
| 
 | ||||
|         switch (eth.ether_type()) { | ||||
|         case EtherType::ARP: | ||||
|             handle_arp(eth, packet.size()); | ||||
|  | @ -279,7 +302,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | |||
|     size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size(); | ||||
| 
 | ||||
| #ifdef TCP_DEBUG | ||||
|     kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s %s), window_size=%u, payload_size=%u\n", | ||||
|     kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n", | ||||
|         ipv4_packet.source().to_string().characters(), | ||||
|         tcp_packet.source_port(), | ||||
|         ipv4_packet.destination().to_string().characters(), | ||||
|  | @ -287,15 +310,19 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | |||
|         tcp_packet.sequence_number(), | ||||
|         tcp_packet.ack_number(), | ||||
|         tcp_packet.flags(), | ||||
|         tcp_packet.has_syn() ? "SYN" : "", | ||||
|         tcp_packet.has_ack() ? "ACK" : "", | ||||
|         tcp_packet.has_syn() ? "SYN " : "", | ||||
|         tcp_packet.has_ack() ? "ACK " : "", | ||||
|         tcp_packet.has_fin() ? "FIN " : "", | ||||
|         tcp_packet.has_rst() ? "RST " : "", | ||||
|         tcp_packet.window_size(), | ||||
|         payload_size); | ||||
| #endif | ||||
| 
 | ||||
|     auto socket = TCPSocket::from_port(tcp_packet.destination_port()); | ||||
|     IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port()); | ||||
| 
 | ||||
|     auto socket = TCPSocket::from_tuple(tuple); | ||||
|     if (!socket) { | ||||
|         kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); | ||||
|         kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters()); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|  | @ -307,39 +334,168 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | |||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     if (tcp_packet.has_syn() && tcp_packet.has_ack()) { | ||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|         socket->send_tcp_packet(TCPFlags::ACK); | ||||
|         socket->set_connected(true); | ||||
|         kprintf("handle_tcp: Connection established!\n"); | ||||
|         socket->set_state(TCPSocket::State::Connected); | ||||
|         return; | ||||
|     } | ||||
| #ifdef TCP_DEBUG | ||||
|     kprintf("handle_tcp: state=%s\n", TCPSocket::to_string(socket->state())); | ||||
| #endif | ||||
| 
 | ||||
|     if (tcp_packet.has_fin()) { | ||||
|         kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size); | ||||
|     switch (socket->state()) { | ||||
|     case TCPSocket::State::Closed: | ||||
|         kprintf("handle_tcp: unexpected flags in Closed state\n"); | ||||
|         socket->send_tcp_packet(TCPFlags::RST); | ||||
|         socket->set_state(TCPSocket::State::Closed); | ||||
|         kprintf("handle_tcp: Closed -> Closed\n"); | ||||
|         return; | ||||
|     case TCPSocket::State::TimeWait: | ||||
|         kprintf("handle_tcp: unexpected flags in TimeWait state\n"); | ||||
|         socket->send_tcp_packet(TCPFlags::RST); | ||||
|         socket->set_state(TCPSocket::State::Closed); | ||||
|         kprintf("handle_tcp: TimeWait -> Closed\n"); | ||||
|         return; | ||||
|     case TCPSocket::State::Listen: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::SYN: | ||||
|             kprintf("handle_tcp: incoming connections not supported\n"); | ||||
|             // socket->send_tcp_packet(TCPFlags::RST);
 | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in Listen state\n"); | ||||
|             // socket->send_tcp_packet(TCPFlags::RST);
 | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::SynSent: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::SYN: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->send_tcp_packet(TCPFlags::ACK); | ||||
|             socket->set_state(TCPSocket::State::SynReceived); | ||||
|             kprintf("handle_tcp: SynSent -> SynReceived\n"); | ||||
|             return; | ||||
|         case TCPFlags::SYN | TCPFlags::ACK: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->send_tcp_packet(TCPFlags::ACK); | ||||
|             socket->set_state(TCPSocket::State::Established); | ||||
|             socket->set_connected(true); | ||||
|             kprintf("handle_tcp: SynSent -> Established\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in SynSent state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: SynSent -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::SynReceived: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::ACK: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::Established); | ||||
|             socket->set_connected(true); | ||||
|             kprintf("handle_tcp: SynReceived -> Established\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in SynReceived state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: SynReceived -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::CloseWait: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in CloseWait state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: CloseWait -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::LastAck: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::ACK: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: LastAck -> Closed\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in LastAck state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: LastAck -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::FinWait1: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::ACK: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::FinWait2); | ||||
|             kprintf("handle_tcp: FinWait1 -> FinWait2\n"); | ||||
|             return; | ||||
|         case TCPFlags::FIN: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::Closing); | ||||
|             kprintf("handle_tcp: FinWait1 -> Closing\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in FinWait1 state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: FinWait1 -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::FinWait2: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::FIN: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::TimeWait); | ||||
|             kprintf("handle_tcp: FinWait2 -> TimeWait\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in FinWait2 state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: FinWait2 -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::Closing: | ||||
|         switch (tcp_packet.flags()) { | ||||
|         case TCPFlags::ACK: | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->set_state(TCPSocket::State::TimeWait); | ||||
|             kprintf("handle_tcp: Closing -> TimeWait\n"); | ||||
|             return; | ||||
|         default: | ||||
|             kprintf("handle_tcp: unexpected flags in Closing state\n"); | ||||
|             socket->send_tcp_packet(TCPFlags::RST); | ||||
|             socket->set_state(TCPSocket::State::Closed); | ||||
|             kprintf("handle_tcp: Closing -> Closed\n"); | ||||
|             return; | ||||
|         } | ||||
|     case TCPSocket::State::Established: | ||||
|         if (tcp_packet.has_fin()) { | ||||
|             if (payload_size != 0) | ||||
|                 socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); | ||||
| 
 | ||||
|             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|             socket->send_tcp_packet(TCPFlags::ACK); | ||||
|             socket->set_state(TCPSocket::State::CloseWait); | ||||
|             socket->set_connected(false); | ||||
|             kprintf("handle_tcp: Established -> CloseWait\n"); | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size); | ||||
| 
 | ||||
| #ifdef TCP_DEBUG | ||||
|         kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", | ||||
|             tcp_packet.ack_number(), | ||||
|             tcp_packet.sequence_number(), | ||||
|             payload_size, | ||||
|             socket->ack_number(), | ||||
|             socket->sequence_number()); | ||||
| #endif | ||||
| 
 | ||||
|         socket->send_tcp_packet(TCPFlags::ACK); | ||||
| 
 | ||||
|         if (payload_size != 0) | ||||
|             socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); | ||||
| 
 | ||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||
|         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); | ||||
|         socket->set_state(TCPSocket::State::Disconnecting); | ||||
|         socket->set_connected(false); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     socket->set_ack_number(tcp_packet.sequence_number() + payload_size); | ||||
| #ifdef TCP_DEBUG | ||||
|     kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", | ||||
|         tcp_packet.ack_number(), | ||||
|         tcp_packet.sequence_number(), | ||||
|         payload_size, | ||||
|         socket->ack_number(), | ||||
|         socket->sequence_number()); | ||||
| #endif | ||||
|     socket->send_tcp_packet(TCPFlags::ACK); | ||||
| 
 | ||||
|     if (payload_size != 0) | ||||
|         socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); | ||||
| } | ||||
|  |  | |||
|  | @ -1,8 +1,8 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <AK/HashTable.h> | ||||
| #include <AK/RefPtr.h> | ||||
| #include <AK/RefCounted.h> | ||||
| #include <AK/RefPtr.h> | ||||
| #include <AK/Vector.h> | ||||
| #include <Kernel/FileSystem/File.h> | ||||
| #include <Kernel/KResult.h> | ||||
|  | @ -35,10 +35,10 @@ public: | |||
|     bool can_accept() const { return !m_pending.is_empty(); } | ||||
|     RefPtr<Socket> accept(); | ||||
|     bool is_connected() const { return m_connected; } | ||||
|     KResult listen(int backlog); | ||||
| 
 | ||||
|     virtual KResult bind(const sockaddr*, socklen_t) = 0; | ||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; | ||||
|     virtual KResult listen(int) = 0; | ||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) = 0; | ||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) = 0; | ||||
|     virtual bool is_local() const { return false; } | ||||
|  | @ -73,6 +73,9 @@ protected: | |||
|     void load_receive_deadline(); | ||||
|     void load_send_deadline(); | ||||
| 
 | ||||
|     int backlog() const { return m_backlog; } | ||||
|     void set_backlog(int backlog) { m_backlog = backlog; } | ||||
| 
 | ||||
|     virtual const char* class_name() const override { return "Socket"; } | ||||
| 
 | ||||
| private: | ||||
|  |  | |||
|  | @ -39,6 +39,7 @@ public: | |||
|     bool has_syn() const { return flags() & TCPFlags::SYN; } | ||||
|     bool has_ack() const { return flags() & TCPFlags::ACK; } | ||||
|     bool has_fin() const { return flags() & TCPFlags::FIN; } | ||||
|     bool has_rst() const { return flags() & TCPFlags::RST; } | ||||
| 
 | ||||
|     u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; } | ||||
|     void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; } | ||||
|  |  | |||
|  | @ -1,28 +1,35 @@ | |||
| #include <Kernel/Devices/RandomDevice.h> | ||||
| #include <Kernel/FileSystem/FileDescription.h> | ||||
| #include <Kernel/Net/NetworkAdapter.h> | ||||
| #include <Kernel/Net/Routing.h> | ||||
| #include <Kernel/Net/TCP.h> | ||||
| #include <Kernel/Net/TCPSocket.h> | ||||
| #include <Kernel/FileSystem/FileDescription.h> | ||||
| #include <Kernel/Process.h> | ||||
| 
 | ||||
| //#define TCP_SOCKET_DEBUG
 | ||||
| 
 | ||||
| Lockable<HashMap<u16, TCPSocket*>>& TCPSocket::sockets_by_port() | ||||
| void TCPSocket::for_each(Function<void(TCPSocket*&)> callback) | ||||
| { | ||||
|     static Lockable<HashMap<u16, TCPSocket*>>* s_map; | ||||
|     LOCKER(sockets_by_tuple().lock()); | ||||
|     for (auto& it : sockets_by_tuple().resource()) | ||||
|         callback(it.value); | ||||
| } | ||||
| 
 | ||||
| Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple() | ||||
| { | ||||
|     static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>* s_map; | ||||
|     if (!s_map) | ||||
|         s_map = new Lockable<HashMap<u16, TCPSocket*>>; | ||||
|         s_map = new Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>; | ||||
|     return *s_map; | ||||
| } | ||||
| 
 | ||||
| TCPSocketHandle TCPSocket::from_port(u16 port) | ||||
| TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple) | ||||
| { | ||||
|     RefPtr<TCPSocket> socket; | ||||
|     { | ||||
|         LOCKER(sockets_by_port().lock()); | ||||
|         auto it = sockets_by_port().resource().find(port); | ||||
|         if (it == sockets_by_port().resource().end()) | ||||
|         LOCKER(sockets_by_tuple().lock()); | ||||
|         auto it = sockets_by_tuple().resource().find(tuple); | ||||
|         if (it == sockets_by_tuple().resource().end()) | ||||
|             return {}; | ||||
|         socket = (*it).value; | ||||
|         ASSERT(socket); | ||||
|  | @ -30,6 +37,11 @@ TCPSocketHandle TCPSocket::from_port(u16 port) | |||
|     return { move(socket) }; | ||||
| } | ||||
| 
 | ||||
| TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port) | ||||
| { | ||||
|     return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port)); | ||||
| } | ||||
| 
 | ||||
| TCPSocket::TCPSocket(int protocol) | ||||
|     : IPv4Socket(SOCK_STREAM, protocol) | ||||
| { | ||||
|  | @ -37,8 +49,8 @@ TCPSocket::TCPSocket(int protocol) | |||
| 
 | ||||
| TCPSocket::~TCPSocket() | ||||
| { | ||||
|     LOCKER(sockets_by_port().lock()); | ||||
|     sockets_by_port().resource().remove(local_port()); | ||||
|     LOCKER(sockets_by_tuple().lock()); | ||||
|     sockets_by_tuple().resource().remove(tuple()); | ||||
| } | ||||
| 
 | ||||
| NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol) | ||||
|  | @ -62,18 +74,13 @@ int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size | |||
| 
 | ||||
| int TCPSocket::protocol_send(const void* data, int data_length) | ||||
| { | ||||
|     auto* adapter = adapter_for_route_to(peer_address()); | ||||
|     if (!adapter) | ||||
|         return -EHOSTUNREACH; | ||||
|     send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length); | ||||
|     return data_length; | ||||
| } | ||||
| 
 | ||||
| void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) | ||||
| { | ||||
|     // FIXME: Maybe the socket should be bound to an adapter instead of looking it up every time?
 | ||||
|     auto* adapter = adapter_for_route_to(peer_address()); | ||||
|     ASSERT(adapter); | ||||
|     ASSERT(m_adapter); | ||||
| 
 | ||||
|     auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); | ||||
|     auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); | ||||
|  | @ -95,19 +102,21 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size | |||
|     } | ||||
| 
 | ||||
|     memcpy(tcp_packet.payload(), payload, payload_size); | ||||
|     tcp_packet.set_checksum(compute_tcp_checksum(adapter->ipv4_address(), peer_address(), tcp_packet, payload_size)); | ||||
|     tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size)); | ||||
| #ifdef TCP_SOCKET_DEBUG | ||||
|     kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n", | ||||
|         adapter->ipv4_address().to_string().characters(), | ||||
|     kprintf("sending tcp packet from %s:%u to %s:%u with (%s%s%s%s) seq_no=%u, ack_no=%u\n", | ||||
|         local_address().to_string().characters(), | ||||
|         local_port(), | ||||
|         peer_address().to_string().characters(), | ||||
|         peer_port(), | ||||
|         tcp_packet.has_syn() ? "SYN" : "", | ||||
|         tcp_packet.has_ack() ? "ACK" : "", | ||||
|         tcp_packet.has_fin() ? "FIN" : "", | ||||
|         tcp_packet.has_rst() ? "RST" : "", | ||||
|         tcp_packet.sequence_number(), | ||||
|         tcp_packet.ack_number()); | ||||
| #endif | ||||
|     adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); | ||||
|     m_adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); | ||||
| } | ||||
| 
 | ||||
| NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size) | ||||
|  | @ -152,11 +161,36 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, c | |||
|     return ~(checksum & 0xffff); | ||||
| } | ||||
| 
 | ||||
| KResult TCPSocket::protocol_bind() | ||||
| { | ||||
|     if (!m_adapter) { | ||||
|         m_adapter = NetworkAdapter::from_ipv4_address(local_address()); | ||||
|         if (!m_adapter) | ||||
|             return KResult(-EADDRNOTAVAIL); | ||||
|     } | ||||
| 
 | ||||
|     return KSuccess; | ||||
| } | ||||
| 
 | ||||
| KResult TCPSocket::protocol_listen() | ||||
| { | ||||
|     LOCKER(sockets_by_tuple().lock()); | ||||
|     if (sockets_by_tuple().resource().contains(tuple())) | ||||
|         return KResult(-EADDRINUSE); | ||||
|     sockets_by_tuple().resource().set(tuple(), this); | ||||
|     set_state(State::Listen); | ||||
|     return KSuccess; | ||||
| } | ||||
| 
 | ||||
| KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block) | ||||
| { | ||||
|     auto* adapter = adapter_for_route_to(peer_address()); | ||||
|     if (!adapter) | ||||
|         return KResult(-EHOSTUNREACH); | ||||
|     if (!m_adapter) { | ||||
|         m_adapter = adapter_for_route_to(peer_address()); | ||||
|         if (!m_adapter) | ||||
|             return KResult(-EHOSTUNREACH); | ||||
| 
 | ||||
|         set_local_address(m_adapter->ipv4_address()); | ||||
|     } | ||||
| 
 | ||||
|     allocate_local_port_if_needed(); | ||||
| 
 | ||||
|  | @ -164,7 +198,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh | |||
|     m_ack_number = 0; | ||||
| 
 | ||||
|     send_tcp_packet(TCPFlags::SYN); | ||||
|     m_state = State::Connecting; | ||||
|     m_state = State::SynSent; | ||||
| 
 | ||||
|     if (should_block == ShouldBlock::Yes) { | ||||
|         if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal) | ||||
|  | @ -183,12 +217,14 @@ int TCPSocket::protocol_allocate_local_port() | |||
|     static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; | ||||
|     u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size; | ||||
| 
 | ||||
|     LOCKER(sockets_by_port().lock()); | ||||
|     LOCKER(sockets_by_tuple().lock()); | ||||
|     for (u16 port = first_scan_port;;) { | ||||
|         auto it = sockets_by_port().resource().find(port); | ||||
|         if (it == sockets_by_port().resource().end()) { | ||||
|         IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); | ||||
| 
 | ||||
|         auto it = sockets_by_tuple().resource().find(proposed_tuple); | ||||
|         if (it == sockets_by_tuple().resource().end()) { | ||||
|             set_local_port(port); | ||||
|             sockets_by_port().resource().set(port, this); | ||||
|             sockets_by_tuple().resource().set(proposed_tuple, this); | ||||
|             return port; | ||||
|         } | ||||
|         ++port; | ||||
|  | @ -202,14 +238,16 @@ int TCPSocket::protocol_allocate_local_port() | |||
| 
 | ||||
| bool TCPSocket::protocol_is_disconnected() const | ||||
| { | ||||
|     return m_state == State::Disconnecting || m_state == State::Disconnected; | ||||
| } | ||||
| 
 | ||||
| KResult TCPSocket::protocol_bind() | ||||
| { | ||||
|     LOCKER(sockets_by_port().lock()); | ||||
|     if (sockets_by_port().resource().contains(local_port())) | ||||
|         return KResult(-EADDRINUSE); | ||||
|     sockets_by_port().resource().set(local_port(), this); | ||||
|     return KSuccess; | ||||
|     switch (m_state) { | ||||
|     case State::Closed: | ||||
|     case State::CloseWait: | ||||
|     case State::LastAck: | ||||
|     case State::FinWait1: | ||||
|     case State::FinWait2: | ||||
|     case State::Closing: | ||||
|     case State::TimeWait: | ||||
|         return true; | ||||
|     default: | ||||
|         return false; | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,19 +1,58 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <AK/Function.h> | ||||
| #include <Kernel/Net/IPv4Socket.h> | ||||
| 
 | ||||
| class TCPSocket final : public IPv4Socket { | ||||
| public: | ||||
|     static void for_each(Function<void(TCPSocket*&)>); | ||||
|     static NonnullRefPtr<TCPSocket> create(int protocol); | ||||
|     virtual ~TCPSocket() override; | ||||
| 
 | ||||
|     enum class State { | ||||
|         Disconnected, | ||||
|         Connecting, | ||||
|         Connected, | ||||
|         Disconnecting, | ||||
|         Closed, | ||||
|         Listen, | ||||
|         SynSent, | ||||
|         SynReceived, | ||||
|         Established, | ||||
|         CloseWait, | ||||
|         LastAck, | ||||
|         FinWait1, | ||||
|         FinWait2, | ||||
|         Closing, | ||||
|         TimeWait, | ||||
|     }; | ||||
| 
 | ||||
|     static const char* to_string(State state) | ||||
|     { | ||||
|         switch (state) { | ||||
|         case State::Closed: | ||||
|             return "Closed"; | ||||
|         case State::Listen: | ||||
|             return "Listen"; | ||||
|         case State::SynSent: | ||||
|             return "SynSent"; | ||||
|         case State::SynReceived: | ||||
|             return "SynReceived"; | ||||
|         case State::Established: | ||||
|             return "Established"; | ||||
|         case State::CloseWait: | ||||
|             return "CloseWait"; | ||||
|         case State::LastAck: | ||||
|             return "LastAck"; | ||||
|         case State::FinWait1: | ||||
|             return "FinWait1"; | ||||
|         case State::FinWait2: | ||||
|             return "FinWait2"; | ||||
|         case State::Closing: | ||||
|             return "Closing"; | ||||
|         case State::TimeWait: | ||||
|             return "TimeWait"; | ||||
|         default: | ||||
|             return "None"; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     State state() const { return m_state; } | ||||
|     void set_state(State state) { m_state = state; } | ||||
| 
 | ||||
|  | @ -24,8 +63,9 @@ public: | |||
| 
 | ||||
|     void send_tcp_packet(u16 flags, const void* = nullptr, int = 0); | ||||
| 
 | ||||
|     static Lockable<HashMap<u16, TCPSocket*>>& sockets_by_port(); | ||||
|     static TCPSocketHandle from_port(u16); | ||||
|     static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple(); | ||||
|     static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple); | ||||
|     static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); | ||||
| 
 | ||||
| private: | ||||
|     explicit TCPSocket(int protocol); | ||||
|  | @ -39,10 +79,12 @@ private: | |||
|     virtual int protocol_allocate_local_port() override; | ||||
|     virtual bool protocol_is_disconnected() const override; | ||||
|     virtual KResult protocol_bind() override; | ||||
|     virtual KResult protocol_listen() override; | ||||
| 
 | ||||
|     NetworkAdapter* m_adapter { nullptr }; | ||||
|     u32 m_sequence_number { 0 }; | ||||
|     u32 m_ack_number { 0 }; | ||||
|     State m_state { State::Disconnected }; | ||||
|     State m_state { State::Closed }; | ||||
| }; | ||||
| 
 | ||||
| class TCPSocketHandle : public SocketHandle { | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Conrad Pankoff
						Conrad Pankoff