diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index d9453c5e1f..11e18fd7e9 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,6 +8,21 @@ //#define DEBUG_LOCAL_SOCKET +Lockable>& LocalSocket::all_sockets() +{ + static Lockable>* s_list; + if (!s_list) + s_list = new Lockable>(); + return *s_list; +} + +void LocalSocket::for_each(Function callback) +{ + LOCKER(all_sockets().lock()); + for (auto& socket : all_sockets().resource()) + callback(socket); +} + NonnullRefPtr LocalSocket::create(int type) { return adopt(*new LocalSocket(type)); @@ -15,6 +31,8 @@ NonnullRefPtr LocalSocket::create(int type) LocalSocket::LocalSocket(int type) : Socket(AF_LOCAL, type, 0) { + LOCKER(all_sockets().lock()); + all_sockets().resource().append(this); #ifdef DEBUG_LOCAL_SOCKET kprintf("%s(%u) LocalSocket{%p} created with type=%u\n", current->process().name().characters(), current->pid(), this, type); #endif @@ -22,6 +40,8 @@ LocalSocket::LocalSocket(int type) LocalSocket::~LocalSocket() { + LOCKER(all_sockets().lock()); + all_sockets().resource().remove(this); } bool LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size) @@ -91,6 +111,7 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre auto description_or_error = VFS::the().open(safe_address, 0, 0, current->process().current_directory()); if (description_or_error.is_error()) return KResult(-ECONNREFUSED); + m_file = move(description_or_error.value()); ASSERT(m_file->inode()); diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 84c470173d..5eac9f7d80 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -1,15 +1,18 @@ #pragma once +#include #include #include class FileDescription; -class LocalSocket final : public Socket { +class LocalSocket final : public Socket, public InlineLinkedListNode { + friend class InlineLinkedListNode; public: static NonnullRefPtr create(int type); virtual ~LocalSocket() override; + static void for_each(Function); StringView socket_path() const; // ^Socket @@ -30,6 +33,7 @@ private: virtual const char* class_name() const override { return "LocalSocket"; } virtual bool is_local() const override { return true; } bool has_attached_peer(const FileDescription&) const; + static Lockable>& all_sockets(); // An open socket file on the filesystem. RefPtr m_file; @@ -54,4 +58,8 @@ private: DoubleBuffer m_for_client; DoubleBuffer m_for_server; + + // for InlineLinkedList + LocalSocket* m_prev { nullptr }; + LocalSocket* m_next { nullptr }; };