mirror of
				https://github.com/RGBCube/serenity
				synced 2025-10-31 06:32:44 +00:00 
			
		
		
		
	Kernel: Refactor TCP/IP stack
This has several significant changes to the networking stack. * Significant refactoring of the TCP state machine. Right now it's probably more fragile than it used to be, but handles quite a lot more of the handshake process. * `TCPSocket` holds a `NetworkAdapter*`, assigned during `connect()` or `bind()`, whichever comes first. * `listen()` is now virtual in `Socket` and intended to be implemented in its child classes * `listen()` no longer works without `bind()` - this is a bit of a regression, but listening sockets didn't work at all before, so it's not possible to observe the regression. * A file is exposed at `/proc/net_tcp`, which is a JSON document listing the current TCP sockets with a bit of metadata. * There's an `ETHERNET_VERY_DEBUG` flag for dumping packet's content out to `kprintf`. It is, indeed, _very debug_.
This commit is contained in:
		
							parent
							
								
									c973a51a23
								
							
						
					
					
						commit
						73c998dbfc
					
				
					 12 changed files with 446 additions and 84 deletions
				
			
		|  | @ -67,6 +67,7 @@ public: | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     in_addr_t to_in_addr_t() const { return m_data_as_u32; } |     in_addr_t to_in_addr_t() const { return m_data_as_u32; } | ||||||
|  |     u32 to_u32() const { return m_data_as_u32; } | ||||||
| 
 | 
 | ||||||
|     bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; } |     bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; } | ||||||
|     bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; } |     bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; } | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ | ||||||
| #include <Kernel/FileSystem/VirtualFileSystem.h> | #include <Kernel/FileSystem/VirtualFileSystem.h> | ||||||
| #include <Kernel/KParams.h> | #include <Kernel/KParams.h> | ||||||
| #include <Kernel/Net/NetworkAdapter.h> | #include <Kernel/Net/NetworkAdapter.h> | ||||||
|  | #include <Kernel/Net/TCPSocket.h> | ||||||
| #include <Kernel/PCI.h> | #include <Kernel/PCI.h> | ||||||
| #include <Kernel/VM/MemoryManager.h> | #include <Kernel/VM/MemoryManager.h> | ||||||
| #include <Kernel/kmalloc.h> | #include <Kernel/kmalloc.h> | ||||||
|  | @ -46,6 +47,7 @@ enum ProcFileType { | ||||||
|     FI_Root_uptime, |     FI_Root_uptime, | ||||||
|     FI_Root_cmdline, |     FI_Root_cmdline, | ||||||
|     FI_Root_netadapters, |     FI_Root_netadapters, | ||||||
|  |     FI_Root_net_tcp, | ||||||
|     FI_Root_self, // symlink
 |     FI_Root_self, // symlink
 | ||||||
|     FI_Root_sys,  // directory
 |     FI_Root_sys,  // directory
 | ||||||
|     __FI_Root_End, |     __FI_Root_End, | ||||||
|  | @ -278,6 +280,23 @@ Optional<KBuffer> procfs$netadapters(InodeIdentifier) | ||||||
|     return builder.to_byte_buffer(); |     return builder.to_byte_buffer(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Optional<KBuffer> procfs$net_tcp(InodeIdentifier) | ||||||
|  | { | ||||||
|  |     JsonArray json; | ||||||
|  |     TCPSocket::for_each([&json](auto& socket) { | ||||||
|  |         JsonObject obj; | ||||||
|  |         obj.set("local_address", socket->local_address().to_string()); | ||||||
|  |         obj.set("local_port", socket->local_port()); | ||||||
|  |         obj.set("peer_address", socket->peer_address().to_string()); | ||||||
|  |         obj.set("peer_port", socket->peer_port()); | ||||||
|  |         obj.set("state", TCPSocket::to_string(socket->state())); | ||||||
|  |         obj.set("ack_number", socket->ack_number()); | ||||||
|  |         obj.set("sequence_number", socket->sequence_number()); | ||||||
|  |         json.append(obj); | ||||||
|  |     }); | ||||||
|  |     return json.serialized().to_byte_buffer(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| Optional<KBuffer> procfs$pid_vmo(InodeIdentifier identifier) | Optional<KBuffer> procfs$pid_vmo(InodeIdentifier identifier) | ||||||
| { | { | ||||||
|     auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier)); |     auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier)); | ||||||
|  | @ -1077,6 +1096,7 @@ ProcFS::ProcFS() | ||||||
|     m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime }; |     m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime }; | ||||||
|     m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline }; |     m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline }; | ||||||
|     m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters }; |     m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters }; | ||||||
|  |     m_entries[FI_Root_net_tcp] = { "net_tcp", FI_Root_net_tcp, procfs$net_tcp }; | ||||||
|     m_entries[FI_Root_sys] = { "sys", FI_Root_sys }; |     m_entries[FI_Root_sys] = { "sys", FI_Root_sys }; | ||||||
| 
 | 
 | ||||||
|     m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm }; |     m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm }; | ||||||
|  |  | ||||||
|  | @ -89,6 +89,22 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) | ||||||
|     return protocol_bind(); |     return protocol_bind(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | KResult IPv4Socket::listen(int backlog) | ||||||
|  | { | ||||||
|  |     int rc = allocate_local_port_if_needed(); | ||||||
|  |     if (rc < 0) | ||||||
|  |         return KResult(-EADDRINUSE); | ||||||
|  | 
 | ||||||
|  |     if (m_local_address.to_u32() == 0) | ||||||
|  |         return KResult(-EADDRINUSE); | ||||||
|  | 
 | ||||||
|  |     set_backlog(backlog); | ||||||
|  | 
 | ||||||
|  |     kprintf("IPv4Socket{%p} listening with backlog=%d\n", this, backlog); | ||||||
|  | 
 | ||||||
|  |     return protocol_listen(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block) | KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block) | ||||||
| { | { | ||||||
|     if (address_size != sizeof(sockaddr_in)) |     if (address_size != sizeof(sockaddr_in)) | ||||||
|  | @ -157,6 +173,9 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt | ||||||
|     if (!adapter) |     if (!adapter) | ||||||
|         return -EHOSTUNREACH; |         return -EHOSTUNREACH; | ||||||
| 
 | 
 | ||||||
|  |     if (m_local_address.to_u32() == 0) | ||||||
|  |         m_local_address = adapter->ipv4_address(); | ||||||
|  | 
 | ||||||
|     int rc = allocate_local_port_if_needed(); |     int rc = allocate_local_port_if_needed(); | ||||||
|     if (rc < 0) |     if (rc < 0) | ||||||
|         return rc; |         return rc; | ||||||
|  |  | ||||||
|  | @ -2,10 +2,11 @@ | ||||||
| 
 | 
 | ||||||
| #include <AK/HashMap.h> | #include <AK/HashMap.h> | ||||||
| #include <AK/SinglyLinkedList.h> | #include <AK/SinglyLinkedList.h> | ||||||
| #include <Kernel/KBuffer.h> |  | ||||||
| #include <Kernel/DoubleBuffer.h> | #include <Kernel/DoubleBuffer.h> | ||||||
|  | #include <Kernel/KBuffer.h> | ||||||
| #include <Kernel/Lock.h> | #include <Kernel/Lock.h> | ||||||
| #include <Kernel/Net/IPv4.h> | #include <Kernel/Net/IPv4.h> | ||||||
|  | #include <Kernel/Net/IPv4SocketTuple.h> | ||||||
| #include <Kernel/Net/Socket.h> | #include <Kernel/Net/Socket.h> | ||||||
| 
 | 
 | ||||||
| class IPv4SocketHandle; | class IPv4SocketHandle; | ||||||
|  | @ -23,6 +24,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; |     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; |     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||||
|  |     virtual KResult listen(int) override; | ||||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) override; |     virtual bool get_local_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) override; |     virtual bool get_peer_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual void attach(FileDescription&) override; |     virtual void attach(FileDescription&) override; | ||||||
|  | @ -34,7 +36,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&); |     void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&); | ||||||
| 
 | 
 | ||||||
|     const IPv4Address& local_address() const; |     const IPv4Address& local_address() const { return m_local_address; } | ||||||
|     u16 local_port() const { return m_local_port; } |     u16 local_port() const { return m_local_port; } | ||||||
|     void set_local_port(u16 port) { m_local_port = port; } |     void set_local_port(u16 port) { m_local_port = port; } | ||||||
| 
 | 
 | ||||||
|  | @ -42,6 +44,8 @@ public: | ||||||
|     u16 peer_port() const { return m_peer_port; } |     u16 peer_port() const { return m_peer_port; } | ||||||
|     void set_peer_port(u16 port) { m_peer_port = port; } |     void set_peer_port(u16 port) { m_peer_port = port; } | ||||||
| 
 | 
 | ||||||
|  |     IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); } | ||||||
|  | 
 | ||||||
| protected: | protected: | ||||||
|     IPv4Socket(int type, int protocol); |     IPv4Socket(int type, int protocol); | ||||||
|     virtual const char* class_name() const override { return "IPv4Socket"; } |     virtual const char* class_name() const override { return "IPv4Socket"; } | ||||||
|  | @ -49,12 +53,16 @@ protected: | ||||||
|     int allocate_local_port_if_needed(); |     int allocate_local_port_if_needed(); | ||||||
| 
 | 
 | ||||||
|     virtual KResult protocol_bind() { return KSuccess; } |     virtual KResult protocol_bind() { return KSuccess; } | ||||||
|  |     virtual KResult protocol_listen() { return KSuccess; } | ||||||
|     virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; } |     virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; } | ||||||
|     virtual int protocol_send(const void*, int) { return -ENOTIMPL; } |     virtual int protocol_send(const void*, int) { return -ENOTIMPL; } | ||||||
|     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; } |     virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; } | ||||||
|     virtual int protocol_allocate_local_port() { return 0; } |     virtual int protocol_allocate_local_port() { return 0; } | ||||||
|     virtual bool protocol_is_disconnected() const { return false; } |     virtual bool protocol_is_disconnected() const { return false; } | ||||||
| 
 | 
 | ||||||
|  |     void set_local_address(IPv4Address address) { m_local_address = address; } | ||||||
|  |     void set_peer_address(IPv4Address address) { m_peer_address = address; } | ||||||
|  | 
 | ||||||
| private: | private: | ||||||
|     virtual bool is_ipv4() const override { return true; } |     virtual bool is_ipv4() const override { return true; } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										63
									
								
								Kernel/Net/IPv4SocketTuple.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								Kernel/Net/IPv4SocketTuple.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,63 @@ | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <AK/HashMap.h> | ||||||
|  | #include <AK/SinglyLinkedList.h> | ||||||
|  | #include <Kernel/DoubleBuffer.h> | ||||||
|  | #include <Kernel/KBuffer.h> | ||||||
|  | #include <Kernel/Lock.h> | ||||||
|  | #include <Kernel/Net/IPv4.h> | ||||||
|  | #include <Kernel/Net/Socket.h> | ||||||
|  | 
 | ||||||
|  | class IPv4SocketTuple { | ||||||
|  | public: | ||||||
|  |     IPv4SocketTuple(IPv4Address local_address, u16 local_port, IPv4Address peer_address, u16 peer_port) | ||||||
|  |         : m_local_address(local_address) | ||||||
|  |         , m_local_port(local_port) | ||||||
|  |         , m_peer_address(peer_address) | ||||||
|  |         , m_peer_port(peer_port) {}; | ||||||
|  | 
 | ||||||
|  |     IPv4Address local_address() const { return m_local_address; }; | ||||||
|  |     u16 local_port() const { return m_local_port; }; | ||||||
|  |     IPv4Address peer_address() const { return m_peer_address; }; | ||||||
|  |     u16 peer_port() const { return m_peer_port; }; | ||||||
|  | 
 | ||||||
|  |     bool operator==(const IPv4SocketTuple other) const | ||||||
|  |     { | ||||||
|  |         return other.local_address() == m_local_address && other.local_port() == m_local_port && other.peer_address() == m_peer_address && other.peer_port() == m_peer_port; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     String to_string() const | ||||||
|  |     { | ||||||
|  |         return String::format( | ||||||
|  |             "%s:%d -> %s:%d", | ||||||
|  |             m_local_address.to_string().characters(), | ||||||
|  |             m_local_port, | ||||||
|  |             m_peer_address.to_string().characters(), | ||||||
|  |             m_peer_port); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | private: | ||||||
|  |     IPv4Address m_local_address; | ||||||
|  |     u16 m_local_port { 0 }; | ||||||
|  |     IPv4Address m_peer_address; | ||||||
|  |     u16 m_peer_port { 0 }; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | namespace AK { | ||||||
|  | 
 | ||||||
|  | template<> | ||||||
|  | struct Traits<IPv4SocketTuple> : public GenericTraits<IPv4SocketTuple> { | ||||||
|  |     static unsigned hash(const IPv4SocketTuple& tuple) | ||||||
|  |     { | ||||||
|  |         auto h1 = pair_int_hash(tuple.local_address().to_u32(), tuple.local_port()); | ||||||
|  |         auto h2 = pair_int_hash(tuple.peer_address().to_u32(), tuple.peer_port()); | ||||||
|  |         return pair_int_hash(h1, h2); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static void dump(const IPv4SocketTuple& tuple) | ||||||
|  |     { | ||||||
|  |         kprintf("%s", tuple.to_string().characters()); | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | @ -114,6 +114,16 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre | ||||||
|     return KSuccess; |     return KSuccess; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | KResult LocalSocket::listen(int backlog) | ||||||
|  | { | ||||||
|  |     LOCKER(lock()); | ||||||
|  |     if (type() != SOCK_STREAM) | ||||||
|  |         return KResult(-EOPNOTSUPP); | ||||||
|  |     set_backlog(backlog); | ||||||
|  |     kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog); | ||||||
|  |     return KSuccess; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void LocalSocket::attach(FileDescription& description) | void LocalSocket::attach(FileDescription& description) | ||||||
| { | { | ||||||
|     switch (description.socket_role()) { |     switch (description.socket_role()) { | ||||||
|  |  | ||||||
|  | @ -13,6 +13,7 @@ public: | ||||||
|     // ^Socket
 |     // ^Socket
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) override; |     virtual KResult bind(const sockaddr*, socklen_t) override; | ||||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; |     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; | ||||||
|  |     virtual KResult listen(int) override; | ||||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) override; |     virtual bool get_local_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) override; |     virtual bool get_peer_address(sockaddr*, socklen_t*) override; | ||||||
|     virtual void attach(FileDescription&) override; |     virtual void attach(FileDescription&) override; | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ | ||||||
| #include <Kernel/Process.h> | #include <Kernel/Process.h> | ||||||
| 
 | 
 | ||||||
| //#define ETHERNET_DEBUG
 | //#define ETHERNET_DEBUG
 | ||||||
|  | //#define ETHERNET_VERY_DEBUG
 | ||||||
| //#define IPV4_DEBUG
 | //#define IPV4_DEBUG
 | ||||||
| //#define ICMP_DEBUG
 | //#define ICMP_DEBUG
 | ||||||
| //#define UDP_DEBUG
 | //#define UDP_DEBUG
 | ||||||
|  | @ -84,6 +85,28 @@ void NetworkTask_main() | ||||||
|             packet.size()); |             packet.size()); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|  | #ifdef ETHERNET_VERY_DEBUG | ||||||
|  |         u8* data = packet.data(); | ||||||
|  | 
 | ||||||
|  |         for (size_t i = 0; i < packet.size(); i++) { | ||||||
|  |             kprintf("%b", data[i]); | ||||||
|  | 
 | ||||||
|  |             switch (i % 16) { | ||||||
|  |             case 7: | ||||||
|  |                 kprintf("  "); | ||||||
|  |                 break; | ||||||
|  |             case 15: | ||||||
|  |                 kprintf("\n"); | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 kprintf(" "); | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         kprintf("\n"); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|         switch (eth.ether_type()) { |         switch (eth.ether_type()) { | ||||||
|         case EtherType::ARP: |         case EtherType::ARP: | ||||||
|             handle_arp(eth, packet.size()); |             handle_arp(eth, packet.size()); | ||||||
|  | @ -279,7 +302,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | ||||||
|     size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size(); |     size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size(); | ||||||
| 
 | 
 | ||||||
| #ifdef TCP_DEBUG | #ifdef TCP_DEBUG | ||||||
|     kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s %s), window_size=%u, payload_size=%u\n", |     kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n", | ||||||
|         ipv4_packet.source().to_string().characters(), |         ipv4_packet.source().to_string().characters(), | ||||||
|         tcp_packet.source_port(), |         tcp_packet.source_port(), | ||||||
|         ipv4_packet.destination().to_string().characters(), |         ipv4_packet.destination().to_string().characters(), | ||||||
|  | @ -287,15 +310,19 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | ||||||
|         tcp_packet.sequence_number(), |         tcp_packet.sequence_number(), | ||||||
|         tcp_packet.ack_number(), |         tcp_packet.ack_number(), | ||||||
|         tcp_packet.flags(), |         tcp_packet.flags(), | ||||||
|         tcp_packet.has_syn() ? "SYN" : "", |         tcp_packet.has_syn() ? "SYN " : "", | ||||||
|         tcp_packet.has_ack() ? "ACK" : "", |         tcp_packet.has_ack() ? "ACK " : "", | ||||||
|  |         tcp_packet.has_fin() ? "FIN " : "", | ||||||
|  |         tcp_packet.has_rst() ? "RST " : "", | ||||||
|         tcp_packet.window_size(), |         tcp_packet.window_size(), | ||||||
|         payload_size); |         payload_size); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|     auto socket = TCPSocket::from_port(tcp_packet.destination_port()); |     IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port()); | ||||||
|  | 
 | ||||||
|  |     auto socket = TCPSocket::from_tuple(tuple); | ||||||
|     if (!socket) { |     if (!socket) { | ||||||
|         kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); |         kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters()); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -307,39 +334,168 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (tcp_packet.has_syn() && tcp_packet.has_ack()) { | #ifdef TCP_DEBUG | ||||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); |     kprintf("handle_tcp: state=%s\n", TCPSocket::to_string(socket->state())); | ||||||
|         socket->send_tcp_packet(TCPFlags::ACK); | #endif | ||||||
|         socket->set_connected(true); |  | ||||||
|         kprintf("handle_tcp: Connection established!\n"); |  | ||||||
|         socket->set_state(TCPSocket::State::Connected); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     if (tcp_packet.has_fin()) { |     switch (socket->state()) { | ||||||
|         kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size); |     case TCPSocket::State::Closed: | ||||||
|  |         kprintf("handle_tcp: unexpected flags in Closed state\n"); | ||||||
|  |         socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |         socket->set_state(TCPSocket::State::Closed); | ||||||
|  |         kprintf("handle_tcp: Closed -> Closed\n"); | ||||||
|  |         return; | ||||||
|  |     case TCPSocket::State::TimeWait: | ||||||
|  |         kprintf("handle_tcp: unexpected flags in TimeWait state\n"); | ||||||
|  |         socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |         socket->set_state(TCPSocket::State::Closed); | ||||||
|  |         kprintf("handle_tcp: TimeWait -> Closed\n"); | ||||||
|  |         return; | ||||||
|  |     case TCPSocket::State::Listen: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::SYN: | ||||||
|  |             kprintf("handle_tcp: incoming connections not supported\n"); | ||||||
|  |             // socket->send_tcp_packet(TCPFlags::RST);
 | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in Listen state\n"); | ||||||
|  |             // socket->send_tcp_packet(TCPFlags::RST);
 | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::SynSent: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::SYN: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::ACK); | ||||||
|  |             socket->set_state(TCPSocket::State::SynReceived); | ||||||
|  |             kprintf("handle_tcp: SynSent -> SynReceived\n"); | ||||||
|  |             return; | ||||||
|  |         case TCPFlags::SYN | TCPFlags::ACK: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::ACK); | ||||||
|  |             socket->set_state(TCPSocket::State::Established); | ||||||
|  |             socket->set_connected(true); | ||||||
|  |             kprintf("handle_tcp: SynSent -> Established\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in SynSent state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: SynSent -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::SynReceived: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::ACK: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::Established); | ||||||
|  |             socket->set_connected(true); | ||||||
|  |             kprintf("handle_tcp: SynReceived -> Established\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in SynReceived state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: SynReceived -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::CloseWait: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in CloseWait state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: CloseWait -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::LastAck: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::ACK: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: LastAck -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in LastAck state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: LastAck -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::FinWait1: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::ACK: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::FinWait2); | ||||||
|  |             kprintf("handle_tcp: FinWait1 -> FinWait2\n"); | ||||||
|  |             return; | ||||||
|  |         case TCPFlags::FIN: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::Closing); | ||||||
|  |             kprintf("handle_tcp: FinWait1 -> Closing\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in FinWait1 state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: FinWait1 -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::FinWait2: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::FIN: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::TimeWait); | ||||||
|  |             kprintf("handle_tcp: FinWait2 -> TimeWait\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in FinWait2 state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: FinWait2 -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::Closing: | ||||||
|  |         switch (tcp_packet.flags()) { | ||||||
|  |         case TCPFlags::ACK: | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->set_state(TCPSocket::State::TimeWait); | ||||||
|  |             kprintf("handle_tcp: Closing -> TimeWait\n"); | ||||||
|  |             return; | ||||||
|  |         default: | ||||||
|  |             kprintf("handle_tcp: unexpected flags in Closing state\n"); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::RST); | ||||||
|  |             socket->set_state(TCPSocket::State::Closed); | ||||||
|  |             kprintf("handle_tcp: Closing -> Closed\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     case TCPSocket::State::Established: | ||||||
|  |         if (tcp_packet.has_fin()) { | ||||||
|  |             if (payload_size != 0) | ||||||
|  |                 socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); | ||||||
|  | 
 | ||||||
|  |             socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); | ||||||
|  |             socket->send_tcp_packet(TCPFlags::ACK); | ||||||
|  |             socket->set_state(TCPSocket::State::CloseWait); | ||||||
|  |             socket->set_connected(false); | ||||||
|  |             kprintf("handle_tcp: Established -> CloseWait\n"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         socket->set_ack_number(tcp_packet.sequence_number() + payload_size); | ||||||
|  | 
 | ||||||
|  | #ifdef TCP_DEBUG | ||||||
|  |         kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", | ||||||
|  |             tcp_packet.ack_number(), | ||||||
|  |             tcp_packet.sequence_number(), | ||||||
|  |             payload_size, | ||||||
|  |             socket->ack_number(), | ||||||
|  |             socket->sequence_number()); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |         socket->send_tcp_packet(TCPFlags::ACK); | ||||||
| 
 | 
 | ||||||
|         if (payload_size != 0) |         if (payload_size != 0) | ||||||
|             socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); |             socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); | ||||||
| 
 |  | ||||||
|         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); |  | ||||||
|         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); |  | ||||||
|         socket->set_state(TCPSocket::State::Disconnecting); |  | ||||||
|         socket->set_connected(false); |  | ||||||
|         return; |  | ||||||
|     } |     } | ||||||
| 
 |  | ||||||
|     socket->set_ack_number(tcp_packet.sequence_number() + payload_size); |  | ||||||
| #ifdef TCP_DEBUG |  | ||||||
|     kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", |  | ||||||
|         tcp_packet.ack_number(), |  | ||||||
|         tcp_packet.sequence_number(), |  | ||||||
|         payload_size, |  | ||||||
|         socket->ack_number(), |  | ||||||
|         socket->sequence_number()); |  | ||||||
| #endif |  | ||||||
|     socket->send_tcp_packet(TCPFlags::ACK); |  | ||||||
| 
 |  | ||||||
|     if (payload_size != 0) |  | ||||||
|         socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,8 +1,8 @@ | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
| #include <AK/HashTable.h> | #include <AK/HashTable.h> | ||||||
| #include <AK/RefPtr.h> |  | ||||||
| #include <AK/RefCounted.h> | #include <AK/RefCounted.h> | ||||||
|  | #include <AK/RefPtr.h> | ||||||
| #include <AK/Vector.h> | #include <AK/Vector.h> | ||||||
| #include <Kernel/FileSystem/File.h> | #include <Kernel/FileSystem/File.h> | ||||||
| #include <Kernel/KResult.h> | #include <Kernel/KResult.h> | ||||||
|  | @ -35,10 +35,10 @@ public: | ||||||
|     bool can_accept() const { return !m_pending.is_empty(); } |     bool can_accept() const { return !m_pending.is_empty(); } | ||||||
|     RefPtr<Socket> accept(); |     RefPtr<Socket> accept(); | ||||||
|     bool is_connected() const { return m_connected; } |     bool is_connected() const { return m_connected; } | ||||||
|     KResult listen(int backlog); |  | ||||||
| 
 | 
 | ||||||
|     virtual KResult bind(const sockaddr*, socklen_t) = 0; |     virtual KResult bind(const sockaddr*, socklen_t) = 0; | ||||||
|     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; |     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; | ||||||
|  |     virtual KResult listen(int) = 0; | ||||||
|     virtual bool get_local_address(sockaddr*, socklen_t*) = 0; |     virtual bool get_local_address(sockaddr*, socklen_t*) = 0; | ||||||
|     virtual bool get_peer_address(sockaddr*, socklen_t*) = 0; |     virtual bool get_peer_address(sockaddr*, socklen_t*) = 0; | ||||||
|     virtual bool is_local() const { return false; } |     virtual bool is_local() const { return false; } | ||||||
|  | @ -73,6 +73,9 @@ protected: | ||||||
|     void load_receive_deadline(); |     void load_receive_deadline(); | ||||||
|     void load_send_deadline(); |     void load_send_deadline(); | ||||||
| 
 | 
 | ||||||
|  |     int backlog() const { return m_backlog; } | ||||||
|  |     void set_backlog(int backlog) { m_backlog = backlog; } | ||||||
|  | 
 | ||||||
|     virtual const char* class_name() const override { return "Socket"; } |     virtual const char* class_name() const override { return "Socket"; } | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|  |  | ||||||
|  | @ -39,6 +39,7 @@ public: | ||||||
|     bool has_syn() const { return flags() & TCPFlags::SYN; } |     bool has_syn() const { return flags() & TCPFlags::SYN; } | ||||||
|     bool has_ack() const { return flags() & TCPFlags::ACK; } |     bool has_ack() const { return flags() & TCPFlags::ACK; } | ||||||
|     bool has_fin() const { return flags() & TCPFlags::FIN; } |     bool has_fin() const { return flags() & TCPFlags::FIN; } | ||||||
|  |     bool has_rst() const { return flags() & TCPFlags::RST; } | ||||||
| 
 | 
 | ||||||
|     u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; } |     u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; } | ||||||
|     void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; } |     void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; } | ||||||
|  |  | ||||||
|  | @ -1,28 +1,35 @@ | ||||||
| #include <Kernel/Devices/RandomDevice.h> | #include <Kernel/Devices/RandomDevice.h> | ||||||
|  | #include <Kernel/FileSystem/FileDescription.h> | ||||||
| #include <Kernel/Net/NetworkAdapter.h> | #include <Kernel/Net/NetworkAdapter.h> | ||||||
| #include <Kernel/Net/Routing.h> | #include <Kernel/Net/Routing.h> | ||||||
| #include <Kernel/Net/TCP.h> | #include <Kernel/Net/TCP.h> | ||||||
| #include <Kernel/Net/TCPSocket.h> | #include <Kernel/Net/TCPSocket.h> | ||||||
| #include <Kernel/FileSystem/FileDescription.h> |  | ||||||
| #include <Kernel/Process.h> | #include <Kernel/Process.h> | ||||||
| 
 | 
 | ||||||
| //#define TCP_SOCKET_DEBUG
 | //#define TCP_SOCKET_DEBUG
 | ||||||
| 
 | 
 | ||||||
| Lockable<HashMap<u16, TCPSocket*>>& TCPSocket::sockets_by_port() | void TCPSocket::for_each(Function<void(TCPSocket*&)> callback) | ||||||
| { | { | ||||||
|     static Lockable<HashMap<u16, TCPSocket*>>* s_map; |     LOCKER(sockets_by_tuple().lock()); | ||||||
|  |     for (auto& it : sockets_by_tuple().resource()) | ||||||
|  |         callback(it.value); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple() | ||||||
|  | { | ||||||
|  |     static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>* s_map; | ||||||
|     if (!s_map) |     if (!s_map) | ||||||
|         s_map = new Lockable<HashMap<u16, TCPSocket*>>; |         s_map = new Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>; | ||||||
|     return *s_map; |     return *s_map; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TCPSocketHandle TCPSocket::from_port(u16 port) | TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple) | ||||||
| { | { | ||||||
|     RefPtr<TCPSocket> socket; |     RefPtr<TCPSocket> socket; | ||||||
|     { |     { | ||||||
|         LOCKER(sockets_by_port().lock()); |         LOCKER(sockets_by_tuple().lock()); | ||||||
|         auto it = sockets_by_port().resource().find(port); |         auto it = sockets_by_tuple().resource().find(tuple); | ||||||
|         if (it == sockets_by_port().resource().end()) |         if (it == sockets_by_tuple().resource().end()) | ||||||
|             return {}; |             return {}; | ||||||
|         socket = (*it).value; |         socket = (*it).value; | ||||||
|         ASSERT(socket); |         ASSERT(socket); | ||||||
|  | @ -30,6 +37,11 @@ TCPSocketHandle TCPSocket::from_port(u16 port) | ||||||
|     return { move(socket) }; |     return { move(socket) }; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port) | ||||||
|  | { | ||||||
|  |     return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TCPSocket::TCPSocket(int protocol) | TCPSocket::TCPSocket(int protocol) | ||||||
|     : IPv4Socket(SOCK_STREAM, protocol) |     : IPv4Socket(SOCK_STREAM, protocol) | ||||||
| { | { | ||||||
|  | @ -37,8 +49,8 @@ TCPSocket::TCPSocket(int protocol) | ||||||
| 
 | 
 | ||||||
| TCPSocket::~TCPSocket() | TCPSocket::~TCPSocket() | ||||||
| { | { | ||||||
|     LOCKER(sockets_by_port().lock()); |     LOCKER(sockets_by_tuple().lock()); | ||||||
|     sockets_by_port().resource().remove(local_port()); |     sockets_by_tuple().resource().remove(tuple()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol) | NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol) | ||||||
|  | @ -62,18 +74,13 @@ int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size | ||||||
| 
 | 
 | ||||||
| int TCPSocket::protocol_send(const void* data, int data_length) | int TCPSocket::protocol_send(const void* data, int data_length) | ||||||
| { | { | ||||||
|     auto* adapter = adapter_for_route_to(peer_address()); |  | ||||||
|     if (!adapter) |  | ||||||
|         return -EHOSTUNREACH; |  | ||||||
|     send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length); |     send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length); | ||||||
|     return data_length; |     return data_length; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) | void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) | ||||||
| { | { | ||||||
|     // FIXME: Maybe the socket should be bound to an adapter instead of looking it up every time?
 |     ASSERT(m_adapter); | ||||||
|     auto* adapter = adapter_for_route_to(peer_address()); |  | ||||||
|     ASSERT(adapter); |  | ||||||
| 
 | 
 | ||||||
|     auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); |     auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); | ||||||
|     auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); |     auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); | ||||||
|  | @ -95,19 +102,21 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     memcpy(tcp_packet.payload(), payload, payload_size); |     memcpy(tcp_packet.payload(), payload, payload_size); | ||||||
|     tcp_packet.set_checksum(compute_tcp_checksum(adapter->ipv4_address(), peer_address(), tcp_packet, payload_size)); |     tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size)); | ||||||
| #ifdef TCP_SOCKET_DEBUG | #ifdef TCP_SOCKET_DEBUG | ||||||
|     kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n", |     kprintf("sending tcp packet from %s:%u to %s:%u with (%s%s%s%s) seq_no=%u, ack_no=%u\n", | ||||||
|         adapter->ipv4_address().to_string().characters(), |         local_address().to_string().characters(), | ||||||
|         local_port(), |         local_port(), | ||||||
|         peer_address().to_string().characters(), |         peer_address().to_string().characters(), | ||||||
|         peer_port(), |         peer_port(), | ||||||
|         tcp_packet.has_syn() ? "SYN" : "", |         tcp_packet.has_syn() ? "SYN" : "", | ||||||
|         tcp_packet.has_ack() ? "ACK" : "", |         tcp_packet.has_ack() ? "ACK" : "", | ||||||
|  |         tcp_packet.has_fin() ? "FIN" : "", | ||||||
|  |         tcp_packet.has_rst() ? "RST" : "", | ||||||
|         tcp_packet.sequence_number(), |         tcp_packet.sequence_number(), | ||||||
|         tcp_packet.ack_number()); |         tcp_packet.ack_number()); | ||||||
| #endif | #endif | ||||||
|     adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); |     m_adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size) | NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size) | ||||||
|  | @ -152,11 +161,36 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, c | ||||||
|     return ~(checksum & 0xffff); |     return ~(checksum & 0xffff); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | KResult TCPSocket::protocol_bind() | ||||||
|  | { | ||||||
|  |     if (!m_adapter) { | ||||||
|  |         m_adapter = NetworkAdapter::from_ipv4_address(local_address()); | ||||||
|  |         if (!m_adapter) | ||||||
|  |             return KResult(-EADDRNOTAVAIL); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return KSuccess; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | KResult TCPSocket::protocol_listen() | ||||||
|  | { | ||||||
|  |     LOCKER(sockets_by_tuple().lock()); | ||||||
|  |     if (sockets_by_tuple().resource().contains(tuple())) | ||||||
|  |         return KResult(-EADDRINUSE); | ||||||
|  |     sockets_by_tuple().resource().set(tuple(), this); | ||||||
|  |     set_state(State::Listen); | ||||||
|  |     return KSuccess; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block) | KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block) | ||||||
| { | { | ||||||
|     auto* adapter = adapter_for_route_to(peer_address()); |     if (!m_adapter) { | ||||||
|     if (!adapter) |         m_adapter = adapter_for_route_to(peer_address()); | ||||||
|         return KResult(-EHOSTUNREACH); |         if (!m_adapter) | ||||||
|  |             return KResult(-EHOSTUNREACH); | ||||||
|  | 
 | ||||||
|  |         set_local_address(m_adapter->ipv4_address()); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     allocate_local_port_if_needed(); |     allocate_local_port_if_needed(); | ||||||
| 
 | 
 | ||||||
|  | @ -164,7 +198,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh | ||||||
|     m_ack_number = 0; |     m_ack_number = 0; | ||||||
| 
 | 
 | ||||||
|     send_tcp_packet(TCPFlags::SYN); |     send_tcp_packet(TCPFlags::SYN); | ||||||
|     m_state = State::Connecting; |     m_state = State::SynSent; | ||||||
| 
 | 
 | ||||||
|     if (should_block == ShouldBlock::Yes) { |     if (should_block == ShouldBlock::Yes) { | ||||||
|         if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal) |         if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal) | ||||||
|  | @ -183,12 +217,14 @@ int TCPSocket::protocol_allocate_local_port() | ||||||
|     static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; |     static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; | ||||||
|     u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size; |     u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size; | ||||||
| 
 | 
 | ||||||
|     LOCKER(sockets_by_port().lock()); |     LOCKER(sockets_by_tuple().lock()); | ||||||
|     for (u16 port = first_scan_port;;) { |     for (u16 port = first_scan_port;;) { | ||||||
|         auto it = sockets_by_port().resource().find(port); |         IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); | ||||||
|         if (it == sockets_by_port().resource().end()) { | 
 | ||||||
|  |         auto it = sockets_by_tuple().resource().find(proposed_tuple); | ||||||
|  |         if (it == sockets_by_tuple().resource().end()) { | ||||||
|             set_local_port(port); |             set_local_port(port); | ||||||
|             sockets_by_port().resource().set(port, this); |             sockets_by_tuple().resource().set(proposed_tuple, this); | ||||||
|             return port; |             return port; | ||||||
|         } |         } | ||||||
|         ++port; |         ++port; | ||||||
|  | @ -202,14 +238,16 @@ int TCPSocket::protocol_allocate_local_port() | ||||||
| 
 | 
 | ||||||
| bool TCPSocket::protocol_is_disconnected() const | bool TCPSocket::protocol_is_disconnected() const | ||||||
| { | { | ||||||
|     return m_state == State::Disconnecting || m_state == State::Disconnected; |     switch (m_state) { | ||||||
| } |     case State::Closed: | ||||||
| 
 |     case State::CloseWait: | ||||||
| KResult TCPSocket::protocol_bind() |     case State::LastAck: | ||||||
| { |     case State::FinWait1: | ||||||
|     LOCKER(sockets_by_port().lock()); |     case State::FinWait2: | ||||||
|     if (sockets_by_port().resource().contains(local_port())) |     case State::Closing: | ||||||
|         return KResult(-EADDRINUSE); |     case State::TimeWait: | ||||||
|     sockets_by_port().resource().set(local_port(), this); |         return true; | ||||||
|     return KSuccess; |     default: | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,19 +1,58 @@ | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
|  | #include <AK/Function.h> | ||||||
| #include <Kernel/Net/IPv4Socket.h> | #include <Kernel/Net/IPv4Socket.h> | ||||||
| 
 | 
 | ||||||
| class TCPSocket final : public IPv4Socket { | class TCPSocket final : public IPv4Socket { | ||||||
| public: | public: | ||||||
|  |     static void for_each(Function<void(TCPSocket*&)>); | ||||||
|     static NonnullRefPtr<TCPSocket> create(int protocol); |     static NonnullRefPtr<TCPSocket> create(int protocol); | ||||||
|     virtual ~TCPSocket() override; |     virtual ~TCPSocket() override; | ||||||
| 
 | 
 | ||||||
|     enum class State { |     enum class State { | ||||||
|         Disconnected, |         Closed, | ||||||
|         Connecting, |         Listen, | ||||||
|         Connected, |         SynSent, | ||||||
|         Disconnecting, |         SynReceived, | ||||||
|  |         Established, | ||||||
|  |         CloseWait, | ||||||
|  |         LastAck, | ||||||
|  |         FinWait1, | ||||||
|  |         FinWait2, | ||||||
|  |         Closing, | ||||||
|  |         TimeWait, | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     static const char* to_string(State state) | ||||||
|  |     { | ||||||
|  |         switch (state) { | ||||||
|  |         case State::Closed: | ||||||
|  |             return "Closed"; | ||||||
|  |         case State::Listen: | ||||||
|  |             return "Listen"; | ||||||
|  |         case State::SynSent: | ||||||
|  |             return "SynSent"; | ||||||
|  |         case State::SynReceived: | ||||||
|  |             return "SynReceived"; | ||||||
|  |         case State::Established: | ||||||
|  |             return "Established"; | ||||||
|  |         case State::CloseWait: | ||||||
|  |             return "CloseWait"; | ||||||
|  |         case State::LastAck: | ||||||
|  |             return "LastAck"; | ||||||
|  |         case State::FinWait1: | ||||||
|  |             return "FinWait1"; | ||||||
|  |         case State::FinWait2: | ||||||
|  |             return "FinWait2"; | ||||||
|  |         case State::Closing: | ||||||
|  |             return "Closing"; | ||||||
|  |         case State::TimeWait: | ||||||
|  |             return "TimeWait"; | ||||||
|  |         default: | ||||||
|  |             return "None"; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     State state() const { return m_state; } |     State state() const { return m_state; } | ||||||
|     void set_state(State state) { m_state = state; } |     void set_state(State state) { m_state = state; } | ||||||
| 
 | 
 | ||||||
|  | @ -24,8 +63,9 @@ public: | ||||||
| 
 | 
 | ||||||
|     void send_tcp_packet(u16 flags, const void* = nullptr, int = 0); |     void send_tcp_packet(u16 flags, const void* = nullptr, int = 0); | ||||||
| 
 | 
 | ||||||
|     static Lockable<HashMap<u16, TCPSocket*>>& sockets_by_port(); |     static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple(); | ||||||
|     static TCPSocketHandle from_port(u16); |     static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple); | ||||||
|  |     static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     explicit TCPSocket(int protocol); |     explicit TCPSocket(int protocol); | ||||||
|  | @ -39,10 +79,12 @@ private: | ||||||
|     virtual int protocol_allocate_local_port() override; |     virtual int protocol_allocate_local_port() override; | ||||||
|     virtual bool protocol_is_disconnected() const override; |     virtual bool protocol_is_disconnected() const override; | ||||||
|     virtual KResult protocol_bind() override; |     virtual KResult protocol_bind() override; | ||||||
|  |     virtual KResult protocol_listen() override; | ||||||
| 
 | 
 | ||||||
|  |     NetworkAdapter* m_adapter { nullptr }; | ||||||
|     u32 m_sequence_number { 0 }; |     u32 m_sequence_number { 0 }; | ||||||
|     u32 m_ack_number { 0 }; |     u32 m_ack_number { 0 }; | ||||||
|     State m_state { State::Disconnected }; |     State m_state { State::Closed }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class TCPSocketHandle : public SocketHandle { | class TCPSocketHandle : public SocketHandle { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Conrad Pankoff
						Conrad Pankoff