From 54e7df0586b06fd2bf56f8e986d4081138ba295e Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Thu, 14 Mar 2019 09:19:24 +0100 Subject: [PATCH] Kernel: Add SocketHandle helper class that wraps locked sockets. This allows us to have a comfy IPv4Socket::from_tcp_port() API that returns a socket that's locked and safe to access. No need to worry about locking at the client site. --- Kernel/IPv4Socket.cpp | 36 ++++++++++++++++++++++++++++++++++-- Kernel/IPv4Socket.h | 27 +++++++++++++++++++++++++++ Kernel/NetworkTask.cpp | 28 ++++++++-------------------- Kernel/Socket.h | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 22 deletions(-) diff --git a/Kernel/IPv4Socket.cpp b/Kernel/IPv4Socket.cpp index 8a090a35fc..d300c772db 100644 --- a/Kernel/IPv4Socket.cpp +++ b/Kernel/IPv4Socket.cpp @@ -27,6 +27,34 @@ Lockable>& IPv4Socket::sockets_by_tcp_port() return *s_map; } +IPv4SocketHandle IPv4Socket::from_tcp_port(word port) +{ + RetainPtr socket; + { + LOCKER(sockets_by_tcp_port().lock()); + auto it = sockets_by_tcp_port().resource().find(port); + if (it == sockets_by_tcp_port().resource().end()) + return { }; + socket = (*it).value; + ASSERT(socket); + } + return { move(socket) }; +} + +IPv4SocketHandle IPv4Socket::from_udp_port(word port) +{ + RetainPtr socket; + { + LOCKER(sockets_by_udp_port().lock()); + auto it = sockets_by_udp_port().resource().find(port); + if (it == sockets_by_udp_port().resource().end()) + return { }; + socket = (*it).value; + ASSERT(socket); + } + return { move(socket) }; +} + Lockable>& IPv4Socket::all_sockets() { static Lockable>* s_table; @@ -217,8 +245,12 @@ NetworkOrdered IPv4Socket::compute_tcp_checksum(const IPv4Address& source, if (checksum > 0xffff) checksum = (checksum >> 16) + (checksum & 0xffff); } - if (payload_size & 1) - ASSERT_NOT_REACHED(); + if (payload_size & 1) { + word expanded_byte = ((const byte*)packet.payload())[payload_size - 1]; + checksum += expanded_byte; + if (checksum > 0xffff) + checksum = (checksum >> 16) + (checksum & 0xffff); + } return ~(checksum & 0xffff); } diff --git a/Kernel/IPv4Socket.h b/Kernel/IPv4Socket.h index 6d948c78ed..2c84eabf9c 100644 --- a/Kernel/IPv4Socket.h +++ b/Kernel/IPv4Socket.h @@ -7,6 +7,7 @@ #include #include +class IPv4SocketHandle; class NetworkAdapter; class TCPPacket; @@ -28,6 +29,9 @@ public: static Lockable>& sockets_by_udp_port(); static Lockable>& sockets_by_tcp_port(); + static IPv4SocketHandle from_tcp_port(word); + static IPv4SocketHandle from_udp_port(word); + virtual KResult bind(const sockaddr*, socklen_t) override; virtual KResult connect(const sockaddr*, socklen_t) override; virtual bool get_address(sockaddr*, socklen_t*) override; @@ -79,3 +83,26 @@ private: bool m_can_read { false }; }; +class IPv4SocketHandle : public SocketHandle { +public: + IPv4SocketHandle() { } + + IPv4SocketHandle(RetainPtr&& socket) + : SocketHandle(move(socket)) + { + } + + IPv4SocketHandle(IPv4SocketHandle&& other) + : SocketHandle(move(other)) + { + } + + IPv4SocketHandle(const IPv4SocketHandle&) = delete; + IPv4SocketHandle& operator=(const IPv4SocketHandle&) = delete; + + IPv4Socket* operator->() { return &socket(); } + const IPv4Socket* operator->() const { return &socket(); } + + IPv4Socket& socket() { return static_cast(SocketHandle::socket()); } + const IPv4Socket& socket() const { return static_cast(SocketHandle::socket()); } +}; diff --git a/Kernel/NetworkTask.cpp b/Kernel/NetworkTask.cpp index 03e2f21a88..61b4464b29 100644 --- a/Kernel/NetworkTask.cpp +++ b/Kernel/NetworkTask.cpp @@ -234,17 +234,12 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size) ); #endif - RetainPtr socket; - { - LOCKER(IPv4Socket::sockets_by_udp_port().lock()); - auto it = IPv4Socket::sockets_by_udp_port().resource().find(udp_packet.destination_port()); - if (it == IPv4Socket::sockets_by_udp_port().resource().end()) - return; - ASSERT((*it).value); - socket = *(*it).value; + auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port()); + if (!socket) { + kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port()); + return; } - LOCKER(socket->lock()); ASSERT(socket->type() == SOCK_DGRAM); ASSERT(socket->source_port() == udp_packet.destination_port()); socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); @@ -280,19 +275,12 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) ); #endif - RetainPtr socket; - { - LOCKER(IPv4Socket::sockets_by_tcp_port().lock()); - auto it = IPv4Socket::sockets_by_tcp_port().resource().find(tcp_packet.destination_port()); - if (it == IPv4Socket::sockets_by_tcp_port().resource().end()) { - kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); - return; - } - ASSERT((*it).value); - socket = *(*it).value; + auto socket = IPv4Socket::from_tcp_port(tcp_packet.destination_port()); + if (!socket) { + kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); + return; } - LOCKER(socket->lock()); ASSERT(socket->type() == SOCK_STREAM); ASSERT(socket->source_port() == tcp_packet.destination_port()); diff --git a/Kernel/Socket.h b/Kernel/Socket.h index 127394ae2f..f741df270b 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -76,3 +76,40 @@ private: Vector> m_pending; Vector> m_clients; }; + +class SocketHandle { +public: + SocketHandle() { } + + SocketHandle(RetainPtr&& socket) + : m_socket(move(socket)) + { + if (m_socket) + m_socket->lock().lock(); + } + + SocketHandle(SocketHandle&& other) + : m_socket(move(other.m_socket)) + { + } + + ~SocketHandle() + { + if (m_socket) + m_socket->lock().unlock(); + } + + SocketHandle(const SocketHandle&) = delete; + SocketHandle& operator=(const SocketHandle&) = delete; + + operator bool() const { return m_socket; } + + Socket* operator->() { return &socket(); } + const Socket* operator->() const { return &socket(); } + + Socket& socket() { return *m_socket; } + const Socket& socket() const { return *m_socket; } + +private: + RetainPtr m_socket; +};