mirror of
https://github.com/RGBCube/serenity
synced 2025-05-31 02:58:12 +00:00
Kernel: Add SocketHandle helper class that wraps locked sockets.
This allows us to have a comfy IPv4Socket::from_tcp_port() API that returns a socket that's locked and safe to access. No need to worry about locking at the client site.
This commit is contained in:
parent
3d5296a901
commit
54e7df0586
4 changed files with 106 additions and 22 deletions
|
@ -27,6 +27,34 @@ Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_tcp_port()
|
|||
return *s_map;
|
||||
}
|
||||
|
||||
IPv4SocketHandle IPv4Socket::from_tcp_port(word port)
|
||||
{
|
||||
RetainPtr<IPv4Socket> socket;
|
||||
{
|
||||
LOCKER(sockets_by_tcp_port().lock());
|
||||
auto it = sockets_by_tcp_port().resource().find(port);
|
||||
if (it == sockets_by_tcp_port().resource().end())
|
||||
return { };
|
||||
socket = (*it).value;
|
||||
ASSERT(socket);
|
||||
}
|
||||
return { move(socket) };
|
||||
}
|
||||
|
||||
IPv4SocketHandle IPv4Socket::from_udp_port(word port)
|
||||
{
|
||||
RetainPtr<IPv4Socket> socket;
|
||||
{
|
||||
LOCKER(sockets_by_udp_port().lock());
|
||||
auto it = sockets_by_udp_port().resource().find(port);
|
||||
if (it == sockets_by_udp_port().resource().end())
|
||||
return { };
|
||||
socket = (*it).value;
|
||||
ASSERT(socket);
|
||||
}
|
||||
return { move(socket) };
|
||||
}
|
||||
|
||||
Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
|
||||
{
|
||||
static Lockable<HashTable<IPv4Socket*>>* s_table;
|
||||
|
@ -217,8 +245,12 @@ NetworkOrdered<word> IPv4Socket::compute_tcp_checksum(const IPv4Address& source,
|
|||
if (checksum > 0xffff)
|
||||
checksum = (checksum >> 16) + (checksum & 0xffff);
|
||||
}
|
||||
if (payload_size & 1)
|
||||
ASSERT_NOT_REACHED();
|
||||
if (payload_size & 1) {
|
||||
word expanded_byte = ((const byte*)packet.payload())[payload_size - 1];
|
||||
checksum += expanded_byte;
|
||||
if (checksum > 0xffff)
|
||||
checksum = (checksum >> 16) + (checksum & 0xffff);
|
||||
}
|
||||
return ~(checksum & 0xffff);
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <AK/Lock.h>
|
||||
#include <AK/SinglyLinkedList.h>
|
||||
|
||||
class IPv4SocketHandle;
|
||||
class NetworkAdapter;
|
||||
class TCPPacket;
|
||||
|
||||
|
@ -28,6 +29,9 @@ public:
|
|||
static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_udp_port();
|
||||
static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_tcp_port();
|
||||
|
||||
static IPv4SocketHandle from_tcp_port(word);
|
||||
static IPv4SocketHandle from_udp_port(word);
|
||||
|
||||
virtual KResult bind(const sockaddr*, socklen_t) override;
|
||||
virtual KResult connect(const sockaddr*, socklen_t) override;
|
||||
virtual bool get_address(sockaddr*, socklen_t*) override;
|
||||
|
@ -79,3 +83,26 @@ private:
|
|||
bool m_can_read { false };
|
||||
};
|
||||
|
||||
class IPv4SocketHandle : public SocketHandle {
|
||||
public:
|
||||
IPv4SocketHandle() { }
|
||||
|
||||
IPv4SocketHandle(RetainPtr<IPv4Socket>&& socket)
|
||||
: SocketHandle(move(socket))
|
||||
{
|
||||
}
|
||||
|
||||
IPv4SocketHandle(IPv4SocketHandle&& other)
|
||||
: SocketHandle(move(other))
|
||||
{
|
||||
}
|
||||
|
||||
IPv4SocketHandle(const IPv4SocketHandle&) = delete;
|
||||
IPv4SocketHandle& operator=(const IPv4SocketHandle&) = delete;
|
||||
|
||||
IPv4Socket* operator->() { return &socket(); }
|
||||
const IPv4Socket* operator->() const { return &socket(); }
|
||||
|
||||
IPv4Socket& socket() { return static_cast<IPv4Socket&>(SocketHandle::socket()); }
|
||||
const IPv4Socket& socket() const { return static_cast<const IPv4Socket&>(SocketHandle::socket()); }
|
||||
};
|
||||
|
|
|
@ -234,17 +234,12 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size)
|
|||
);
|
||||
#endif
|
||||
|
||||
RetainPtr<IPv4Socket> socket;
|
||||
{
|
||||
LOCKER(IPv4Socket::sockets_by_udp_port().lock());
|
||||
auto it = IPv4Socket::sockets_by_udp_port().resource().find(udp_packet.destination_port());
|
||||
if (it == IPv4Socket::sockets_by_udp_port().resource().end())
|
||||
auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port());
|
||||
if (!socket) {
|
||||
kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port());
|
||||
return;
|
||||
ASSERT((*it).value);
|
||||
socket = *(*it).value;
|
||||
}
|
||||
|
||||
LOCKER(socket->lock());
|
||||
ASSERT(socket->type() == SOCK_DGRAM);
|
||||
ASSERT(socket->source_port() == udp_packet.destination_port());
|
||||
socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
|
||||
|
@ -280,19 +275,12 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
|
|||
);
|
||||
#endif
|
||||
|
||||
RetainPtr<IPv4Socket> socket;
|
||||
{
|
||||
LOCKER(IPv4Socket::sockets_by_tcp_port().lock());
|
||||
auto it = IPv4Socket::sockets_by_tcp_port().resource().find(tcp_packet.destination_port());
|
||||
if (it == IPv4Socket::sockets_by_tcp_port().resource().end()) {
|
||||
auto socket = IPv4Socket::from_tcp_port(tcp_packet.destination_port());
|
||||
if (!socket) {
|
||||
kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
|
||||
return;
|
||||
}
|
||||
ASSERT((*it).value);
|
||||
socket = *(*it).value;
|
||||
}
|
||||
|
||||
LOCKER(socket->lock());
|
||||
ASSERT(socket->type() == SOCK_STREAM);
|
||||
ASSERT(socket->source_port() == tcp_packet.destination_port());
|
||||
|
||||
|
|
|
@ -76,3 +76,40 @@ private:
|
|||
Vector<RetainPtr<Socket>> m_pending;
|
||||
Vector<RetainPtr<Socket>> m_clients;
|
||||
};
|
||||
|
||||
class SocketHandle {
|
||||
public:
|
||||
SocketHandle() { }
|
||||
|
||||
SocketHandle(RetainPtr<Socket>&& socket)
|
||||
: m_socket(move(socket))
|
||||
{
|
||||
if (m_socket)
|
||||
m_socket->lock().lock();
|
||||
}
|
||||
|
||||
SocketHandle(SocketHandle&& other)
|
||||
: m_socket(move(other.m_socket))
|
||||
{
|
||||
}
|
||||
|
||||
~SocketHandle()
|
||||
{
|
||||
if (m_socket)
|
||||
m_socket->lock().unlock();
|
||||
}
|
||||
|
||||
SocketHandle(const SocketHandle&) = delete;
|
||||
SocketHandle& operator=(const SocketHandle&) = delete;
|
||||
|
||||
operator bool() const { return m_socket; }
|
||||
|
||||
Socket* operator->() { return &socket(); }
|
||||
const Socket* operator->() const { return &socket(); }
|
||||
|
||||
Socket& socket() { return *m_socket; }
|
||||
const Socket& socket() const { return *m_socket; }
|
||||
|
||||
private:
|
||||
RetainPtr<Socket> m_socket;
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue