diff --git a/Kernel/FileDescriptor.cpp b/Kernel/FileDescriptor.cpp index 53560b0e8b..7d842c4196 100644 --- a/Kernel/FileDescriptor.cpp +++ b/Kernel/FileDescriptor.cpp @@ -49,14 +49,14 @@ FileDescriptor::FileDescriptor(RetainPtr&& device) FileDescriptor::FileDescriptor(RetainPtr&& socket, SocketRole role) : m_socket(move(socket)) - , m_socket_role(role) { + set_socket_role(role); } FileDescriptor::~FileDescriptor() { if (m_socket) { - m_socket->close(m_socket_role); + m_socket->detach_fd(m_socket_role); m_socket = nullptr; } if (m_device) { @@ -70,6 +70,16 @@ FileDescriptor::~FileDescriptor() m_inode = nullptr; } +void FileDescriptor::set_socket_role(SocketRole role) +{ + if (role == m_socket_role) + return; + + ASSERT(m_socket); + m_socket_role = role; + m_socket->attach_fd(role); +} + RetainPtr FileDescriptor::clone() { RetainPtr descriptor; @@ -81,6 +91,9 @@ RetainPtr FileDescriptor::clone() if (m_device) { descriptor = FileDescriptor::create(m_device.copy_ref()); descriptor->m_inode = m_inode.copy_ref(); + } else if (m_socket) { + descriptor = FileDescriptor::create(m_socket.copy_ref(), m_socket_role); + descriptor->m_inode = m_inode.copy_ref(); } else { descriptor = FileDescriptor::create(m_inode.copy_ref()); } diff --git a/Kernel/FileDescriptor.h b/Kernel/FileDescriptor.h index 760f543199..b069875d20 100644 --- a/Kernel/FileDescriptor.h +++ b/Kernel/FileDescriptor.h @@ -86,7 +86,7 @@ public: void set_original_inode(Badge, RetainPtr&& inode) { m_inode = move(inode); } - void set_socket_role(SocketRole role) { m_socket_role = role; } + void set_socket_role(SocketRole); private: friend class VFS; diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index cfc9e23d95..8ee414dd79 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -106,23 +106,34 @@ bool LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& return true; } -void LocalSocket::close(SocketRole role) +void LocalSocket::attach_fd(SocketRole role) { - if (role == SocketRole::Accepted) - m_server_closed = true; - else if (role == SocketRole::Connected) - m_client_closed = true; + if (role == SocketRole::Accepted) { + ++m_accepted_fds_open; + } else if (role == SocketRole::Connected) { + ++m_connected_fds_open; + } +} + +void LocalSocket::detach_fd(SocketRole role) +{ + if (role == SocketRole::Accepted) { + ASSERT(m_accepted_fds_open); + --m_accepted_fds_open; + } else if (role == SocketRole::Connected) { + ASSERT(m_connected_fds_open); + --m_connected_fds_open; + } } bool LocalSocket::can_read(SocketRole role) const { - if (m_bound && is_listening()) + if (role == SocketRole::Listener) return can_accept(); - if (role == SocketRole::Accepted) - return m_client_closed || !m_for_server.is_empty(); - else if (role == SocketRole::Connected) - return m_server_closed || !m_for_client.is_empty(); + return !m_connected_fds_open || !m_for_server.is_empty(); + if (role == SocketRole::Connected) + return !m_accepted_fds_open || !m_for_client.is_empty(); ASSERT_NOT_REACHED(); } @@ -130,7 +141,7 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size) { if (role == SocketRole::Accepted) return m_for_server.read(buffer, size); - else if (role == SocketRole::Connected) + if (role == SocketRole::Connected) return m_for_client.read(buffer, size); ASSERT_NOT_REACHED(); } @@ -138,11 +149,12 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size) ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size) { if (role == SocketRole::Accepted) { - if (m_client_closed) + if (!m_accepted_fds_open) return -EPIPE; return m_for_client.write(data, size); - } else if (role == SocketRole::Connected) { - if (m_client_closed) + } + if (role == SocketRole::Connected) { + if (!m_connected_fds_open) return -EPIPE; return m_for_server.write(data, size); } @@ -152,8 +164,8 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size) bool LocalSocket::can_write(SocketRole role) const { if (role == SocketRole::Accepted) - return m_client_closed || m_for_client.bytes_in_write_buffer() < 4096; - else if (role == SocketRole::Connected) - return m_server_closed || m_for_server.bytes_in_write_buffer() < 4096; + return !m_connected_fds_open || m_for_client.bytes_in_write_buffer() < 4096; + if (role == SocketRole::Connected) + return !m_accepted_fds_open || m_for_server.bytes_in_write_buffer() < 4096; ASSERT_NOT_REACHED(); } diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index a56d42d7df..51ceaeccd8 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -13,7 +13,8 @@ public: virtual bool bind(const sockaddr*, socklen_t, int& error) override; virtual bool connect(const sockaddr*, socklen_t, int& error) override; virtual bool get_address(sockaddr*, socklen_t*) override; - virtual void close(SocketRole) override; + virtual void attach_fd(SocketRole) override; + virtual void detach_fd(SocketRole) override; virtual bool can_read(SocketRole) const override; virtual ssize_t read(SocketRole, byte*, size_t) override; virtual ssize_t write(SocketRole, const byte*, size_t) override; @@ -27,8 +28,8 @@ private: RetainPtr m_peer; bool m_bound { false }; - bool m_server_closed { false }; - bool m_client_closed { false }; + int m_accepted_fds_open { 0 }; + int m_connected_fds_open { 0 }; sockaddr_un m_address; DoubleBuffer m_for_client; diff --git a/Kernel/Socket.cpp b/Kernel/Socket.cpp index 7262fffcce..e6cf5dd26b 100644 --- a/Kernel/Socket.cpp +++ b/Kernel/Socket.cpp @@ -36,7 +36,6 @@ bool Socket::listen(int backlog, int& error) return false; } m_backlog = backlog; - m_listening = true; kprintf("Socket{%p} listening with backlog=%d\n", this, m_backlog); return true; } diff --git a/Kernel/Socket.h b/Kernel/Socket.h index 285db684e0..33767b6ad7 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -14,7 +14,6 @@ public: static RetainPtr create(int domain, int type, int protocol, int& error); virtual ~Socket(); - bool is_listening() const { return m_listening; } int domain() const { return m_domain; } int type() const { return m_type; } int protocol() const { return m_protocol; } @@ -28,7 +27,8 @@ public: virtual bool connect(const sockaddr*, socklen_t, int& error) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0; virtual bool is_local() const { return false; } - virtual void close(SocketRole) = 0; + virtual void attach_fd(SocketRole) = 0; + virtual void detach_fd(SocketRole) = 0; virtual bool can_read(SocketRole) const = 0; virtual ssize_t read(SocketRole, byte*, size_t) = 0; virtual ssize_t write(SocketRole, const byte*, size_t) = 0; @@ -48,7 +48,6 @@ private: int m_type { 0 }; int m_protocol { 0 }; int m_backlog { 0 }; - bool m_listening { false }; bool m_connected { false }; Vector> m_pending;