1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-29 14:05:11 +00:00

Kernel: Give each FileDescriptor a chance to co-open sockets.

Track how many fds are open for a socket's Accepted and Connected roles.
This allows fork() to clone a socket fd without a subsequent close() walking
all over the parent process's fd.
This commit is contained in:
Andreas Kling 2019-02-17 11:00:35 +01:00
parent b0be3299b5
commit d5f515cf6c
6 changed files with 51 additions and 27 deletions

View file

@ -49,14 +49,14 @@ FileDescriptor::FileDescriptor(RetainPtr<Device>&& device)
FileDescriptor::FileDescriptor(RetainPtr<Socket>&& socket, SocketRole role) FileDescriptor::FileDescriptor(RetainPtr<Socket>&& socket, SocketRole role)
: m_socket(move(socket)) : m_socket(move(socket))
, m_socket_role(role)
{ {
set_socket_role(role);
} }
FileDescriptor::~FileDescriptor() FileDescriptor::~FileDescriptor()
{ {
if (m_socket) { if (m_socket) {
m_socket->close(m_socket_role); m_socket->detach_fd(m_socket_role);
m_socket = nullptr; m_socket = nullptr;
} }
if (m_device) { if (m_device) {
@ -70,6 +70,16 @@ FileDescriptor::~FileDescriptor()
m_inode = nullptr; m_inode = nullptr;
} }
void FileDescriptor::set_socket_role(SocketRole role)
{
if (role == m_socket_role)
return;
ASSERT(m_socket);
m_socket_role = role;
m_socket->attach_fd(role);
}
RetainPtr<FileDescriptor> FileDescriptor::clone() RetainPtr<FileDescriptor> FileDescriptor::clone()
{ {
RetainPtr<FileDescriptor> descriptor; RetainPtr<FileDescriptor> descriptor;
@ -81,6 +91,9 @@ RetainPtr<FileDescriptor> FileDescriptor::clone()
if (m_device) { if (m_device) {
descriptor = FileDescriptor::create(m_device.copy_ref()); descriptor = FileDescriptor::create(m_device.copy_ref());
descriptor->m_inode = m_inode.copy_ref(); descriptor->m_inode = m_inode.copy_ref();
} else if (m_socket) {
descriptor = FileDescriptor::create(m_socket.copy_ref(), m_socket_role);
descriptor->m_inode = m_inode.copy_ref();
} else { } else {
descriptor = FileDescriptor::create(m_inode.copy_ref()); descriptor = FileDescriptor::create(m_inode.copy_ref());
} }

View file

@ -86,7 +86,7 @@ public:
void set_original_inode(Badge<VFS>, RetainPtr<Inode>&& inode) { m_inode = move(inode); } void set_original_inode(Badge<VFS>, RetainPtr<Inode>&& inode) { m_inode = move(inode); }
void set_socket_role(SocketRole role) { m_socket_role = role; } void set_socket_role(SocketRole);
private: private:
friend class VFS; friend class VFS;

View file

@ -106,23 +106,34 @@ bool LocalSocket::connect(const sockaddr* address, socklen_t address_size, int&
return true; return true;
} }
void LocalSocket::close(SocketRole role) void LocalSocket::attach_fd(SocketRole role)
{ {
if (role == SocketRole::Accepted) if (role == SocketRole::Accepted) {
m_server_closed = true; ++m_accepted_fds_open;
else if (role == SocketRole::Connected) } else if (role == SocketRole::Connected) {
m_client_closed = true; ++m_connected_fds_open;
}
}
void LocalSocket::detach_fd(SocketRole role)
{
if (role == SocketRole::Accepted) {
ASSERT(m_accepted_fds_open);
--m_accepted_fds_open;
} else if (role == SocketRole::Connected) {
ASSERT(m_connected_fds_open);
--m_connected_fds_open;
}
} }
bool LocalSocket::can_read(SocketRole role) const bool LocalSocket::can_read(SocketRole role) const
{ {
if (m_bound && is_listening()) if (role == SocketRole::Listener)
return can_accept(); return can_accept();
if (role == SocketRole::Accepted) if (role == SocketRole::Accepted)
return m_client_closed || !m_for_server.is_empty(); return !m_connected_fds_open || !m_for_server.is_empty();
else if (role == SocketRole::Connected) if (role == SocketRole::Connected)
return m_server_closed || !m_for_client.is_empty(); return !m_accepted_fds_open || !m_for_client.is_empty();
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
@ -130,7 +141,7 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size)
{ {
if (role == SocketRole::Accepted) if (role == SocketRole::Accepted)
return m_for_server.read(buffer, size); return m_for_server.read(buffer, size);
else if (role == SocketRole::Connected) if (role == SocketRole::Connected)
return m_for_client.read(buffer, size); return m_for_client.read(buffer, size);
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
@ -138,11 +149,12 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size)
ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size) ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size)
{ {
if (role == SocketRole::Accepted) { if (role == SocketRole::Accepted) {
if (m_client_closed) if (!m_accepted_fds_open)
return -EPIPE; return -EPIPE;
return m_for_client.write(data, size); return m_for_client.write(data, size);
} else if (role == SocketRole::Connected) { }
if (m_client_closed) if (role == SocketRole::Connected) {
if (!m_connected_fds_open)
return -EPIPE; return -EPIPE;
return m_for_server.write(data, size); return m_for_server.write(data, size);
} }
@ -152,8 +164,8 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size)
bool LocalSocket::can_write(SocketRole role) const bool LocalSocket::can_write(SocketRole role) const
{ {
if (role == SocketRole::Accepted) if (role == SocketRole::Accepted)
return m_client_closed || m_for_client.bytes_in_write_buffer() < 4096; return !m_connected_fds_open || m_for_client.bytes_in_write_buffer() < 4096;
else if (role == SocketRole::Connected) if (role == SocketRole::Connected)
return m_server_closed || m_for_server.bytes_in_write_buffer() < 4096; return !m_accepted_fds_open || m_for_server.bytes_in_write_buffer() < 4096;
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }

View file

@ -13,7 +13,8 @@ public:
virtual bool bind(const sockaddr*, socklen_t, int& error) override; virtual bool bind(const sockaddr*, socklen_t, int& error) override;
virtual bool connect(const sockaddr*, socklen_t, int& error) override; virtual bool connect(const sockaddr*, socklen_t, int& error) override;
virtual bool get_address(sockaddr*, socklen_t*) override; virtual bool get_address(sockaddr*, socklen_t*) override;
virtual void close(SocketRole) override; virtual void attach_fd(SocketRole) override;
virtual void detach_fd(SocketRole) override;
virtual bool can_read(SocketRole) const override; virtual bool can_read(SocketRole) const override;
virtual ssize_t read(SocketRole, byte*, size_t) override; virtual ssize_t read(SocketRole, byte*, size_t) override;
virtual ssize_t write(SocketRole, const byte*, size_t) override; virtual ssize_t write(SocketRole, const byte*, size_t) override;
@ -27,8 +28,8 @@ private:
RetainPtr<LocalSocket> m_peer; RetainPtr<LocalSocket> m_peer;
bool m_bound { false }; bool m_bound { false };
bool m_server_closed { false }; int m_accepted_fds_open { 0 };
bool m_client_closed { false }; int m_connected_fds_open { 0 };
sockaddr_un m_address; sockaddr_un m_address;
DoubleBuffer m_for_client; DoubleBuffer m_for_client;

View file

@ -36,7 +36,6 @@ bool Socket::listen(int backlog, int& error)
return false; return false;
} }
m_backlog = backlog; m_backlog = backlog;
m_listening = true;
kprintf("Socket{%p} listening with backlog=%d\n", this, m_backlog); kprintf("Socket{%p} listening with backlog=%d\n", this, m_backlog);
return true; return true;
} }

View file

@ -14,7 +14,6 @@ public:
static RetainPtr<Socket> create(int domain, int type, int protocol, int& error); static RetainPtr<Socket> create(int domain, int type, int protocol, int& error);
virtual ~Socket(); virtual ~Socket();
bool is_listening() const { return m_listening; }
int domain() const { return m_domain; } int domain() const { return m_domain; }
int type() const { return m_type; } int type() const { return m_type; }
int protocol() const { return m_protocol; } int protocol() const { return m_protocol; }
@ -28,7 +27,8 @@ public:
virtual bool connect(const sockaddr*, socklen_t, int& error) = 0; virtual bool connect(const sockaddr*, socklen_t, int& error) = 0;
virtual bool get_address(sockaddr*, socklen_t*) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0;
virtual bool is_local() const { return false; } virtual bool is_local() const { return false; }
virtual void close(SocketRole) = 0; virtual void attach_fd(SocketRole) = 0;
virtual void detach_fd(SocketRole) = 0;
virtual bool can_read(SocketRole) const = 0; virtual bool can_read(SocketRole) const = 0;
virtual ssize_t read(SocketRole, byte*, size_t) = 0; virtual ssize_t read(SocketRole, byte*, size_t) = 0;
virtual ssize_t write(SocketRole, const byte*, size_t) = 0; virtual ssize_t write(SocketRole, const byte*, size_t) = 0;
@ -48,7 +48,6 @@ private:
int m_type { 0 }; int m_type { 0 };
int m_protocol { 0 }; int m_protocol { 0 };
int m_backlog { 0 }; int m_backlog { 0 };
bool m_listening { false };
bool m_connected { false }; bool m_connected { false };
Vector<RetainPtr<Socket>> m_pending; Vector<RetainPtr<Socket>> m_pending;