diff --git a/AK/NonnullRefPtr.h b/AK/NonnullRefPtr.h index 094d8bba44..fb2632a66e 100644 --- a/AK/NonnullRefPtr.h +++ b/AK/NonnullRefPtr.h @@ -6,15 +6,14 @@ #pragma once -#include -#include -#include -#include -#include #ifdef KERNEL -# include -# include -#endif +# include +#else +# include +# include +# include +# include +# include namespace AK { @@ -51,52 +50,55 @@ public: enum AdoptTag { Adopt }; - ALWAYS_INLINE NonnullRefPtr(const T& object) - : m_bits((FlatPtr)&object) + ALWAYS_INLINE NonnullRefPtr(T const& object) + : m_ptr(const_cast(&object)) { - VERIFY(!(m_bits & 1)); - const_cast(object).ref(); + m_ptr->ref(); } + template - ALWAYS_INLINE NonnullRefPtr(const U& object) requires(IsConvertible) - : m_bits((FlatPtr) static_cast(&object)) + ALWAYS_INLINE NonnullRefPtr(U const& object) requires(IsConvertible) + : m_ptr(const_cast(static_cast(&object))) { - VERIFY(!(m_bits & 1)); - const_cast(static_cast(object)).ref(); + m_ptr->ref(); } + ALWAYS_INLINE NonnullRefPtr(AdoptTag, T& object) - : m_bits((FlatPtr)&object) + : m_ptr(&object) { - VERIFY(!(m_bits & 1)); } + ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr&& other) - : m_bits((FlatPtr)&other.leak_ref()) + : m_ptr(&other.leak_ref()) { - VERIFY(!(m_bits & 1)); } + template ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr&& other) requires(IsConvertible) - : m_bits((FlatPtr)&other.leak_ref()) + : m_ptr(static_cast(&other.leak_ref())) { - VERIFY(!(m_bits & 1)); } - ALWAYS_INLINE NonnullRefPtr(const NonnullRefPtr& other) - : m_bits((FlatPtr)other.add_ref()) + + ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr const& other) + : m_ptr(const_cast(other.ptr())) { - VERIFY(!(m_bits & 1)); + m_ptr->ref(); } + template - ALWAYS_INLINE NonnullRefPtr(const NonnullRefPtr& other) requires(IsConvertible) - : m_bits((FlatPtr)other.add_ref()) + ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr const& other) requires(IsConvertible) + : m_ptr(const_cast(static_cast(other.ptr()))) { - VERIFY(!(m_bits & 1)); + m_ptr->ref(); } + ALWAYS_INLINE ~NonnullRefPtr() { - assign(nullptr); -#ifdef SANITIZE_PTRS - m_bits.store(explode_byte(0xb0), AK::MemoryOrder::memory_order_relaxed); -#endif + unref_if_not_null(m_ptr); + m_ptr = nullptr; +# ifdef SANITIZE_PTRS + m_ptr = reinterpret_cast(explode_byte(0xb0)); +# endif } template @@ -111,44 +113,46 @@ public: NonnullRefPtr(const RefPtr&) = delete; NonnullRefPtr& operator=(const RefPtr&) = delete; - NonnullRefPtr& operator=(const NonnullRefPtr& other) + NonnullRefPtr& operator=(NonnullRefPtr const& other) { - if (this != &other) - assign(other.add_ref()); + NonnullRefPtr tmp { other }; + swap(tmp); return *this; } template - NonnullRefPtr& operator=(const NonnullRefPtr& other) requires(IsConvertible) + NonnullRefPtr& operator=(NonnullRefPtr const& other) requires(IsConvertible) { - assign(other.add_ref()); + NonnullRefPtr tmp { other }; + swap(tmp); return *this; } ALWAYS_INLINE NonnullRefPtr& operator=(NonnullRefPtr&& other) { - if (this != &other) - assign(&other.leak_ref()); + NonnullRefPtr tmp { move(other) }; + swap(tmp); return *this; } template NonnullRefPtr& operator=(NonnullRefPtr&& other) requires(IsConvertible) { - assign(&other.leak_ref()); + NonnullRefPtr tmp { move(other) }; + swap(tmp); return *this; } - NonnullRefPtr& operator=(const T& object) + NonnullRefPtr& operator=(T const& object) { - const_cast(object).ref(); - assign(const_cast(&object)); + NonnullRefPtr tmp { object }; + swap(tmp); return *this; } [[nodiscard]] ALWAYS_INLINE T& leak_ref() { - T* ptr = exchange(nullptr); + T* ptr = exchange(m_ptr, nullptr); VERIFY(ptr); return *ptr; } @@ -203,113 +207,24 @@ public: void swap(NonnullRefPtr& other) { - if (this == &other) - return; - - // NOTE: swap is not atomic! - T* other_ptr = other.exchange(nullptr); - T* ptr = exchange(other_ptr); - other.exchange(ptr); + AK::swap(m_ptr, other.m_ptr); } template void swap(NonnullRefPtr& other) requires(IsConvertible) { - // NOTE: swap is not atomic! - U* other_ptr = other.exchange(nullptr); - T* ptr = exchange(other_ptr); - other.exchange(ptr); + AK::swap(m_ptr, other.m_ptr); } private: NonnullRefPtr() = delete; - - ALWAYS_INLINE T* as_ptr() const - { - return (T*)(m_bits.load(AK::MemoryOrder::memory_order_relaxed) & ~(FlatPtr)1); - } - ALWAYS_INLINE RETURNS_NONNULL T* as_nonnull_ptr() const { - T* ptr = (T*)(m_bits.load(AK::MemoryOrder::memory_order_relaxed) & ~(FlatPtr)1); - VERIFY(ptr); - return ptr; + VERIFY(m_ptr); + return m_ptr; } - template - void do_while_locked(F f) const - { -#ifdef KERNEL - // We don't want to be pre-empted while we have the lock bit set - Kernel::ScopedCritical critical; -#endif - FlatPtr bits; - for (;;) { - bits = m_bits.fetch_or(1, AK::MemoryOrder::memory_order_acq_rel); - if (!(bits & 1)) - break; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - VERIFY(!(bits & 1)); - f((T*)bits); - m_bits.store(bits, AK::MemoryOrder::memory_order_release); - } - - ALWAYS_INLINE void assign(T* new_ptr) - { - T* prev_ptr = exchange(new_ptr); - unref_if_not_null(prev_ptr); - } - - ALWAYS_INLINE T* exchange(T* new_ptr) - { - VERIFY(!((FlatPtr)new_ptr & 1)); -#ifdef KERNEL - // We don't want to be pre-empted while we have the lock bit set - Kernel::ScopedCritical critical; -#endif - // Only exchange while not locked - FlatPtr expected = m_bits.load(AK::MemoryOrder::memory_order_relaxed); - for (;;) { - expected &= ~(FlatPtr)1; // only if lock bit is not set - if (m_bits.compare_exchange_strong(expected, (FlatPtr)new_ptr, AK::MemoryOrder::memory_order_acq_rel)) - break; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - VERIFY(!(expected & 1)); - return (T*)expected; - } - - T* add_ref() const - { -#ifdef KERNEL - // We don't want to be pre-empted while we have the lock bit set - Kernel::ScopedCritical critical; -#endif - // Lock the pointer - FlatPtr expected = m_bits.load(AK::MemoryOrder::memory_order_relaxed); - for (;;) { - expected &= ~(FlatPtr)1; // only if lock bit is not set - if (m_bits.compare_exchange_strong(expected, expected | 1, AK::MemoryOrder::memory_order_acq_rel)) - break; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - - // Add a reference now that we locked the pointer - ref_if_not_null((T*)expected); - - // Unlock the pointer again - m_bits.store(expected, AK::MemoryOrder::memory_order_release); - return (T*)expected; - } - - mutable Atomic m_bits { 0 }; + T* m_ptr { nullptr }; }; template @@ -357,3 +272,5 @@ struct Traits> : public GenericTraits> { using AK::adopt_ref; using AK::make_ref_counted; using AK::NonnullRefPtr; + +#endif diff --git a/AK/RefCounted.h b/AK/RefCounted.h index 802dd8574a..3417333d1a 100644 --- a/AK/RefCounted.h +++ b/AK/RefCounted.h @@ -6,12 +6,15 @@ #pragma once -#include -#include -#include -#include -#include -#include +#ifdef KERNEL +# include +#else + +# include +# include +# include +# include +# include namespace AK { @@ -49,43 +52,32 @@ public: void ref() const { - auto old_ref_count = m_ref_count.fetch_add(1, AK::MemoryOrder::memory_order_relaxed); - VERIFY(old_ref_count > 0); - VERIFY(!Checked::addition_would_overflow(old_ref_count, 1)); + VERIFY(m_ref_count > 0); + VERIFY(!Checked::addition_would_overflow(m_ref_count, 1)); + ++m_ref_count; } [[nodiscard]] bool try_ref() const { - RefCountType expected = m_ref_count.load(AK::MemoryOrder::memory_order_relaxed); - for (;;) { - if (expected == 0) - return false; - VERIFY(!Checked::addition_would_overflow(expected, 1)); - if (m_ref_count.compare_exchange_strong(expected, expected + 1, AK::MemoryOrder::memory_order_acquire)) - return true; - } + if (m_ref_count == 0) + return false; + ref(); + return true; } - [[nodiscard]] RefCountType ref_count() const - { - return m_ref_count.load(AK::MemoryOrder::memory_order_relaxed); - } + [[nodiscard]] RefCountType ref_count() const { return m_ref_count; } protected: RefCountedBase() = default; - ~RefCountedBase() - { - VERIFY(m_ref_count.load(AK::MemoryOrder::memory_order_relaxed) == 0); - } + ~RefCountedBase() { VERIFY(!m_ref_count); } RefCountType deref_base() const { - auto old_ref_count = m_ref_count.fetch_sub(1, AK::MemoryOrder::memory_order_acq_rel); - VERIFY(old_ref_count > 0); - return old_ref_count - 1; + VERIFY(m_ref_count); + return --m_ref_count; } - mutable Atomic m_ref_count { 1 }; + RefCountType mutable m_ref_count { 1 }; }; template @@ -109,3 +101,5 @@ public: using AK::RefCounted; using AK::RefCountedBase; + +#endif diff --git a/AK/RefPtr.h b/AK/RefPtr.h index e8fa9fc5ec..7076a2da1a 100644 --- a/AK/RefPtr.h +++ b/AK/RefPtr.h @@ -6,114 +6,23 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include #ifdef KERNEL -# include -# include -# include -#endif +# include +#else + +# include +# include +# include +# include +# include +# include +# include namespace AK { template class OwnPtr; -template -struct RefPtrTraits { - ALWAYS_INLINE static T* as_ptr(FlatPtr bits) - { - return (T*)(bits & ~(FlatPtr)1); - } - - ALWAYS_INLINE static FlatPtr as_bits(T* ptr) - { - VERIFY(!((FlatPtr)ptr & 1)); - return (FlatPtr)ptr; - } - - template - ALWAYS_INLINE static FlatPtr convert_from(FlatPtr bits) - { - if (PtrTraits::is_null(bits)) - return default_null_value; - return as_bits(PtrTraits::as_ptr(bits)); - } - - ALWAYS_INLINE static bool is_null(FlatPtr bits) - { - return !(bits & ~(FlatPtr)1); - } - - ALWAYS_INLINE static FlatPtr exchange(Atomic& atomic_var, FlatPtr new_value) - { - // Only exchange when lock is not held - VERIFY(!(new_value & 1)); - FlatPtr expected = atomic_var.load(AK::MemoryOrder::memory_order_relaxed); - for (;;) { - expected &= ~(FlatPtr)1; // only if lock bit is not set - if (atomic_var.compare_exchange_strong(expected, new_value, AK::MemoryOrder::memory_order_acq_rel)) - break; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - return expected; - } - - ALWAYS_INLINE static bool exchange_if_null(Atomic& atomic_var, FlatPtr new_value) - { - // Only exchange when lock is not held - VERIFY(!(new_value & 1)); - for (;;) { - FlatPtr expected = default_null_value; // only if lock bit is not set - if (atomic_var.compare_exchange_strong(expected, new_value, AK::MemoryOrder::memory_order_acq_rel)) - break; - if (!is_null(expected)) - return false; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - return true; - } - - ALWAYS_INLINE static FlatPtr lock(Atomic& atomic_var) - { - // This sets the lock bit atomically, preventing further modifications. - // This is important when e.g. copying a RefPtr where the source - // might be released and freed too quickly. This allows us - // to temporarily lock the pointer so we can add a reference, then - // unlock it - FlatPtr bits; - for (;;) { - bits = atomic_var.fetch_or(1, AK::MemoryOrder::memory_order_acq_rel); - if (!(bits & 1)) - break; -#ifdef KERNEL - Kernel::Processor::wait_check(); -#endif - } - VERIFY(!(bits & 1)); - return bits; - } - - ALWAYS_INLINE static void unlock(Atomic& atomic_var, FlatPtr new_value) - { - VERIFY(!(new_value & 1)); - atomic_var.store(new_value, AK::MemoryOrder::memory_order_release); - } - - static constexpr FlatPtr default_null_value = 0; - - using NullType = std::nullptr_t; -}; - template class RefPtr { template @@ -127,149 +36,154 @@ public: }; RefPtr() = default; - RefPtr(const T* ptr) - : m_bits(PtrTraits::as_bits(const_cast(ptr))) + RefPtr(T const* ptr) + : m_ptr(const_cast(ptr)) { - ref_if_not_null(const_cast(ptr)); + ref_if_not_null(m_ptr); } - RefPtr(const T& object) - : m_bits(PtrTraits::as_bits(const_cast(&object))) + + RefPtr(T const& object) + : m_ptr(const_cast(&object)) { - T* ptr = const_cast(&object); - VERIFY(ptr); - VERIFY(!is_null()); - ptr->ref(); + m_ptr->ref(); } + RefPtr(AdoptTag, T& object) - : m_bits(PtrTraits::as_bits(&object)) + : m_ptr(&object) { - VERIFY(!is_null()); } + RefPtr(RefPtr&& other) - : m_bits(other.leak_ref_raw()) + : m_ptr(other.leak_ref()) { } - ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) - : m_bits(PtrTraits::as_bits(const_cast(other.add_ref()))) + + ALWAYS_INLINE RefPtr(NonnullRefPtr const& other) + : m_ptr(const_cast(other.ptr())) { + m_ptr->ref(); } + template - ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) requires(IsConvertible) - : m_bits(PtrTraits::as_bits(const_cast(other.add_ref()))) + ALWAYS_INLINE RefPtr(NonnullRefPtr const& other) requires(IsConvertible) + : m_ptr(const_cast(static_cast(other.ptr()))) { + m_ptr->ref(); } + template ALWAYS_INLINE RefPtr(NonnullRefPtr&& other) requires(IsConvertible) - : m_bits(PtrTraits::as_bits(&other.leak_ref())) + : m_ptr(static_cast(&other.leak_ref())) { - VERIFY(!is_null()); } + template> RefPtr(RefPtr&& other) requires(IsConvertible) - : m_bits(PtrTraits::template convert_from(other.leak_ref_raw())) + : m_ptr(static_cast(other.leak_ref())) { } - RefPtr(const RefPtr& other) - : m_bits(other.add_ref_raw()) + + RefPtr(RefPtr const& other) + : m_ptr(other.m_ptr) { + ref_if_not_null(m_ptr); } + template> - RefPtr(const RefPtr& other) requires(IsConvertible) - : m_bits(other.add_ref_raw()) + RefPtr(RefPtr const& other) requires(IsConvertible) + : m_ptr(const_cast(static_cast(other.ptr()))) { + ref_if_not_null(m_ptr); } + ALWAYS_INLINE ~RefPtr() { clear(); -#ifdef SANITIZE_PTRS - m_bits.store(explode_byte(0xe0), AK::MemoryOrder::memory_order_relaxed); -#endif +# ifdef SANITIZE_PTRS + m_ptr = reinterpret_cast(explode_byte(0xe0)); +# endif } template - RefPtr(const OwnPtr&) = delete; + RefPtr(OwnPtr const&) = delete; template - RefPtr& operator=(const OwnPtr&) = delete; + RefPtr& operator=(OwnPtr const&) = delete; void swap(RefPtr& other) { - if (this == &other) - return; - - // NOTE: swap is not atomic! - FlatPtr other_bits = PtrTraits::exchange(other.m_bits, PtrTraits::default_null_value); - FlatPtr bits = PtrTraits::exchange(m_bits, other_bits); - PtrTraits::exchange(other.m_bits, bits); + AK::swap(m_ptr, other.m_ptr); } template> void swap(RefPtr& other) requires(IsConvertible) { - // NOTE: swap is not atomic! - FlatPtr other_bits = P::exchange(other.m_bits, P::default_null_value); - FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::template convert_from(other_bits)); - P::exchange(other.m_bits, P::template convert_from(bits)); + AK::swap(m_ptr, other.m_ptr); } ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) { - if (this != &other) - assign_raw(other.leak_ref_raw()); + RefPtr tmp { move(other) }; + swap(tmp); return *this; } template> ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) requires(IsConvertible) { - assign_raw(PtrTraits::template convert_from(other.leak_ref_raw())); + RefPtr tmp { move(other) }; + swap(tmp); return *this; } template ALWAYS_INLINE RefPtr& operator=(NonnullRefPtr&& other) requires(IsConvertible) { - assign_raw(PtrTraits::as_bits(&other.leak_ref())); + RefPtr tmp { move(other) }; + swap(tmp); return *this; } - ALWAYS_INLINE RefPtr& operator=(const NonnullRefPtr& other) + ALWAYS_INLINE RefPtr& operator=(NonnullRefPtr const& other) { - assign_raw(PtrTraits::as_bits(other.add_ref())); + RefPtr tmp { other }; + swap(tmp); return *this; } template - ALWAYS_INLINE RefPtr& operator=(const NonnullRefPtr& other) requires(IsConvertible) + ALWAYS_INLINE RefPtr& operator=(NonnullRefPtr const& other) requires(IsConvertible) { - assign_raw(PtrTraits::as_bits(other.add_ref())); + RefPtr tmp { other }; + swap(tmp); return *this; } - ALWAYS_INLINE RefPtr& operator=(const RefPtr& other) + ALWAYS_INLINE RefPtr& operator=(RefPtr const& other) { - if (this != &other) - assign_raw(other.add_ref_raw()); + RefPtr tmp { other }; + swap(tmp); return *this; } template - ALWAYS_INLINE RefPtr& operator=(const RefPtr& other) requires(IsConvertible) + ALWAYS_INLINE RefPtr& operator=(RefPtr const& other) requires(IsConvertible) { - assign_raw(other.add_ref_raw()); + RefPtr tmp { other }; + swap(tmp); return *this; } - ALWAYS_INLINE RefPtr& operator=(const T* ptr) + ALWAYS_INLINE RefPtr& operator=(T const* ptr) { - ref_if_not_null(const_cast(ptr)); - assign_raw(PtrTraits::as_bits(const_cast(ptr))); + RefPtr tmp { ptr }; + swap(tmp); return *this; } - ALWAYS_INLINE RefPtr& operator=(const T& object) + ALWAYS_INLINE RefPtr& operator=(T const& object) { - const_cast(object).ref(); - assign_raw(PtrTraits::as_bits(const_cast(&object))); + RefPtr tmp { object }; + swap(tmp); return *this; } @@ -283,7 +197,8 @@ public: { if (this == &other) return is_null(); - return PtrTraits::exchange_if_null(m_bits, other.leak_ref_raw()); + *this = move(other); + return true; } template> @@ -291,27 +206,28 @@ public: { if (this == &other) return is_null(); - return PtrTraits::exchange_if_null(m_bits, PtrTraits::template convert_from(other.leak_ref_raw())); + *this = move(other); + return true; } ALWAYS_INLINE void clear() { - assign_raw(PtrTraits::default_null_value); + unref_if_not_null(m_ptr); + m_ptr = nullptr; } - bool operator!() const { return PtrTraits::is_null(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); } + bool operator!() const { return !m_ptr; } [[nodiscard]] T* leak_ref() { - FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::default_null_value); - return PtrTraits::as_ptr(bits); + return exchange(m_ptr, nullptr); } NonnullRefPtr release_nonnull() { - FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::default_null_value); - VERIFY(!PtrTraits::is_null(bits)); - return NonnullRefPtr(NonnullRefPtr::Adopt, *PtrTraits::as_ptr(bits)); + auto* ptr = leak_ref(); + VERIFY(ptr); + return NonnullRefPtr(NonnullRefPtr::Adopt, *ptr); } ALWAYS_INLINE T* ptr() { return as_ptr(); } @@ -357,88 +273,21 @@ public: bool operator==(T* other) { return as_ptr() == other; } bool operator!=(T* other) { return as_ptr() != other; } - ALWAYS_INLINE bool is_null() const { return PtrTraits::is_null(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); } - - template && !IsNullPointer>::Type* = nullptr> - typename PtrTraits::NullType null_value() const - { - // make sure we are holding a null value - FlatPtr bits = m_bits.load(AK::MemoryOrder::memory_order_relaxed); - VERIFY(PtrTraits::is_null(bits)); - return PtrTraits::to_null_value(bits); - } - template && !IsNullPointer>::Type* = nullptr> - void set_null_value(typename PtrTraits::NullType value) - { - // make sure that new null value would be interpreted as a null value - FlatPtr bits = PtrTraits::from_null_value(value); - VERIFY(PtrTraits::is_null(bits)); - assign_raw(bits); - } + ALWAYS_INLINE bool is_null() const { return !m_ptr; } private: - template - void do_while_locked(F f) const - { -#ifdef KERNEL - // We don't want to be pre-empted while we have the lock bit set - Kernel::ScopedCritical critical; -#endif - FlatPtr bits = PtrTraits::lock(m_bits); - T* ptr = PtrTraits::as_ptr(bits); - f(ptr); - PtrTraits::unlock(m_bits, bits); - } - - [[nodiscard]] ALWAYS_INLINE FlatPtr leak_ref_raw() - { - return PtrTraits::exchange(m_bits, PtrTraits::default_null_value); - } - - [[nodiscard]] ALWAYS_INLINE FlatPtr add_ref_raw() const - { -#ifdef KERNEL - // We don't want to be pre-empted while we have the lock bit set - Kernel::ScopedCritical critical; -#endif - // This prevents a race condition between thread A and B: - // 1. Thread A copies RefPtr, e.g. through assignment or copy constructor, - // gets the pointer from source, but is pre-empted before adding - // another reference - // 2. Thread B calls clear, leak_ref, or release_nonnull on source, and - // then drops the last reference, causing the object to be deleted - // 3. Thread A finishes step #1 by attempting to add a reference to - // the object that was already deleted in step #2 - FlatPtr bits = PtrTraits::lock(m_bits); - if (T* ptr = PtrTraits::as_ptr(bits)) - ptr->ref(); - PtrTraits::unlock(m_bits, bits); - return bits; - } - - ALWAYS_INLINE void assign_raw(FlatPtr bits) - { - FlatPtr prev_bits = PtrTraits::exchange(m_bits, bits); - unref_if_not_null(PtrTraits::as_ptr(prev_bits)); - } - ALWAYS_INLINE T* as_ptr() const { - return PtrTraits::as_ptr(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); + return m_ptr; } ALWAYS_INLINE T* as_nonnull_ptr() const { - return as_nonnull_ptr(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); + VERIFY(m_ptr); + return m_ptr; } - ALWAYS_INLINE T* as_nonnull_ptr(FlatPtr bits) const - { - VERIFY(!PtrTraits::is_null(bits)); - return PtrTraits::as_ptr(bits); - } - - mutable Atomic m_bits { PtrTraits::default_null_value }; + T* m_ptr { nullptr }; }; template @@ -496,17 +345,6 @@ inline RefPtr try_make_ref_counted(Args&&... args) return adopt_ref_if_nonnull(new (nothrow) T { forward(args)... }); } -#ifdef KERNEL -template -inline Kernel::KResultOr> adopt_nonnull_ref_or_enomem(T* object) -{ - auto result = adopt_ref_if_nonnull(object); - if (!result) - return ENOMEM; - return result.release_nonnull(); -} -#endif - } using AK::adopt_ref_if_nonnull; @@ -514,6 +352,4 @@ using AK::RefPtr; using AK::static_ptr_cast; using AK::try_make_ref_counted; -#ifdef KERNEL -using AK::adopt_nonnull_ref_or_enomem; #endif diff --git a/AK/WeakPtr.h b/AK/WeakPtr.h index b89dc480f2..bb3e6bcbb2 100644 --- a/AK/WeakPtr.h +++ b/AK/WeakPtr.h @@ -6,7 +6,11 @@ #pragma once -#include +#ifdef KERNEL +# include +#else + +# include namespace AK { @@ -65,21 +69,16 @@ public: } template>::Type* = nullptr> - WeakPtr(const RefPtr& object) + WeakPtr(RefPtr const& object) { - object.do_while_locked([&](U* obj) { - if (obj) - m_link = obj->template make_weak_ptr().take_link(); - }); + if (object) + m_link = object->template make_weak_ptr().take_link(); } template>::Type* = nullptr> - WeakPtr(const NonnullRefPtr& object) + WeakPtr(NonnullRefPtr const& object) { - object.do_while_locked([&](U* obj) { - if (obj) - m_link = obj->template make_weak_ptr().take_link(); - }); + m_link = object->template make_weak_ptr().take_link(); } template>::Type* = nullptr> @@ -102,61 +101,36 @@ public: template>::Type* = nullptr> WeakPtr& operator=(const RefPtr& object) { - object.do_while_locked([&](U* obj) { - if (obj) - m_link = obj->template make_weak_ptr().take_link(); - else - m_link = nullptr; - }); + if (object) + m_link = object->template make_weak_ptr().take_link(); + else + m_link = nullptr; return *this; } template>::Type* = nullptr> WeakPtr& operator=(const NonnullRefPtr& object) { - object.do_while_locked([&](U* obj) { - if (obj) - m_link = obj->template make_weak_ptr().take_link(); - else - m_link = nullptr; - }); + m_link = object->template make_weak_ptr().take_link(); return *this; } [[nodiscard]] RefPtr strong_ref() const { - // This only works with RefCounted objects, but it is the only - // safe way to get a strong reference from a WeakPtr. Any code - // that uses objects not derived from RefCounted will have to - // use unsafe_ptr(), but as the name suggests, it is not safe... - RefPtr ref; - // Using do_while_locked protects against a race with clear()! - m_link.do_while_locked([&](WeakLink* link) { - if (link) - ref = link->template strong_ref(); - }); - return ref; + return RefPtr { ptr() }; } -#ifndef KERNEL - // A lot of user mode code is single-threaded. But for kernel mode code - // this is generally not true as everything is multi-threaded. So make - // these shortcuts and aliases only available to non-kernel code. T* ptr() const { return unsafe_ptr(); } T* operator->() { return unsafe_ptr(); } const T* operator->() const { return unsafe_ptr(); } operator const T*() const { return unsafe_ptr(); } operator T*() { return unsafe_ptr(); } -#endif [[nodiscard]] T* unsafe_ptr() const { - T* ptr = nullptr; - m_link.do_while_locked([&](WeakLink* link) { - if (link) - ptr = link->unsafe_ptr(); - }); - return ptr; + if (m_link) + return m_link->template unsafe_ptr(); + return nullptr; } operator bool() const { return m_link ? !m_link->is_null() : false; } @@ -219,12 +193,7 @@ template struct Formatter> : Formatter { void format(FormatBuilder& builder, const WeakPtr& value) { -#ifdef KERNEL - auto ref = value.strong_ref(); - Formatter::format(builder, ref.ptr()); -#else Formatter::format(builder, value.ptr()); -#endif } }; @@ -240,3 +209,4 @@ WeakPtr try_make_weak_ptr(const T* ptr) } using AK::WeakPtr; +#endif diff --git a/Kernel/Library/ThreadSafeNonnullRefPtr.h b/Kernel/Library/ThreadSafeNonnullRefPtr.h new file mode 100644 index 0000000000..094d8bba44 --- /dev/null +++ b/Kernel/Library/ThreadSafeNonnullRefPtr.h @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2018-2020, Andreas Kling + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include +#ifdef KERNEL +# include +# include +#endif + +namespace AK { + +template +class OwnPtr; +template +class RefPtr; + +template +ALWAYS_INLINE void ref_if_not_null(T* ptr) +{ + if (ptr) + ptr->ref(); +} + +template +ALWAYS_INLINE void unref_if_not_null(T* ptr) +{ + if (ptr) + ptr->unref(); +} + +template +class NonnullRefPtr { + template + friend class RefPtr; + template + friend class NonnullRefPtr; + template + friend class WeakPtr; + +public: + using ElementType = T; + + enum AdoptTag { Adopt }; + + ALWAYS_INLINE NonnullRefPtr(const T& object) + : m_bits((FlatPtr)&object) + { + VERIFY(!(m_bits & 1)); + const_cast(object).ref(); + } + template + ALWAYS_INLINE NonnullRefPtr(const U& object) requires(IsConvertible) + : m_bits((FlatPtr) static_cast(&object)) + { + VERIFY(!(m_bits & 1)); + const_cast(static_cast(object)).ref(); + } + ALWAYS_INLINE NonnullRefPtr(AdoptTag, T& object) + : m_bits((FlatPtr)&object) + { + VERIFY(!(m_bits & 1)); + } + ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr&& other) + : m_bits((FlatPtr)&other.leak_ref()) + { + VERIFY(!(m_bits & 1)); + } + template + ALWAYS_INLINE NonnullRefPtr(NonnullRefPtr&& other) requires(IsConvertible) + : m_bits((FlatPtr)&other.leak_ref()) + { + VERIFY(!(m_bits & 1)); + } + ALWAYS_INLINE NonnullRefPtr(const NonnullRefPtr& other) + : m_bits((FlatPtr)other.add_ref()) + { + VERIFY(!(m_bits & 1)); + } + template + ALWAYS_INLINE NonnullRefPtr(const NonnullRefPtr& other) requires(IsConvertible) + : m_bits((FlatPtr)other.add_ref()) + { + VERIFY(!(m_bits & 1)); + } + ALWAYS_INLINE ~NonnullRefPtr() + { + assign(nullptr); +#ifdef SANITIZE_PTRS + m_bits.store(explode_byte(0xb0), AK::MemoryOrder::memory_order_relaxed); +#endif + } + + template + NonnullRefPtr(const OwnPtr&) = delete; + template + NonnullRefPtr& operator=(const OwnPtr&) = delete; + + template + NonnullRefPtr(const RefPtr&) = delete; + template + NonnullRefPtr& operator=(const RefPtr&) = delete; + NonnullRefPtr(const RefPtr&) = delete; + NonnullRefPtr& operator=(const RefPtr&) = delete; + + NonnullRefPtr& operator=(const NonnullRefPtr& other) + { + if (this != &other) + assign(other.add_ref()); + return *this; + } + + template + NonnullRefPtr& operator=(const NonnullRefPtr& other) requires(IsConvertible) + { + assign(other.add_ref()); + return *this; + } + + ALWAYS_INLINE NonnullRefPtr& operator=(NonnullRefPtr&& other) + { + if (this != &other) + assign(&other.leak_ref()); + return *this; + } + + template + NonnullRefPtr& operator=(NonnullRefPtr&& other) requires(IsConvertible) + { + assign(&other.leak_ref()); + return *this; + } + + NonnullRefPtr& operator=(const T& object) + { + const_cast(object).ref(); + assign(const_cast(&object)); + return *this; + } + + [[nodiscard]] ALWAYS_INLINE T& leak_ref() + { + T* ptr = exchange(nullptr); + VERIFY(ptr); + return *ptr; + } + + ALWAYS_INLINE RETURNS_NONNULL T* ptr() + { + return as_nonnull_ptr(); + } + ALWAYS_INLINE RETURNS_NONNULL const T* ptr() const + { + return as_nonnull_ptr(); + } + + ALWAYS_INLINE RETURNS_NONNULL T* operator->() + { + return as_nonnull_ptr(); + } + ALWAYS_INLINE RETURNS_NONNULL const T* operator->() const + { + return as_nonnull_ptr(); + } + + ALWAYS_INLINE T& operator*() + { + return *as_nonnull_ptr(); + } + ALWAYS_INLINE const T& operator*() const + { + return *as_nonnull_ptr(); + } + + ALWAYS_INLINE RETURNS_NONNULL operator T*() + { + return as_nonnull_ptr(); + } + ALWAYS_INLINE RETURNS_NONNULL operator const T*() const + { + return as_nonnull_ptr(); + } + + ALWAYS_INLINE operator T&() + { + return *as_nonnull_ptr(); + } + ALWAYS_INLINE operator const T&() const + { + return *as_nonnull_ptr(); + } + + operator bool() const = delete; + bool operator!() const = delete; + + void swap(NonnullRefPtr& other) + { + if (this == &other) + return; + + // NOTE: swap is not atomic! + T* other_ptr = other.exchange(nullptr); + T* ptr = exchange(other_ptr); + other.exchange(ptr); + } + + template + void swap(NonnullRefPtr& other) requires(IsConvertible) + { + // NOTE: swap is not atomic! + U* other_ptr = other.exchange(nullptr); + T* ptr = exchange(other_ptr); + other.exchange(ptr); + } + +private: + NonnullRefPtr() = delete; + + ALWAYS_INLINE T* as_ptr() const + { + return (T*)(m_bits.load(AK::MemoryOrder::memory_order_relaxed) & ~(FlatPtr)1); + } + + ALWAYS_INLINE RETURNS_NONNULL T* as_nonnull_ptr() const + { + T* ptr = (T*)(m_bits.load(AK::MemoryOrder::memory_order_relaxed) & ~(FlatPtr)1); + VERIFY(ptr); + return ptr; + } + + template + void do_while_locked(F f) const + { +#ifdef KERNEL + // We don't want to be pre-empted while we have the lock bit set + Kernel::ScopedCritical critical; +#endif + FlatPtr bits; + for (;;) { + bits = m_bits.fetch_or(1, AK::MemoryOrder::memory_order_acq_rel); + if (!(bits & 1)) + break; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + VERIFY(!(bits & 1)); + f((T*)bits); + m_bits.store(bits, AK::MemoryOrder::memory_order_release); + } + + ALWAYS_INLINE void assign(T* new_ptr) + { + T* prev_ptr = exchange(new_ptr); + unref_if_not_null(prev_ptr); + } + + ALWAYS_INLINE T* exchange(T* new_ptr) + { + VERIFY(!((FlatPtr)new_ptr & 1)); +#ifdef KERNEL + // We don't want to be pre-empted while we have the lock bit set + Kernel::ScopedCritical critical; +#endif + // Only exchange while not locked + FlatPtr expected = m_bits.load(AK::MemoryOrder::memory_order_relaxed); + for (;;) { + expected &= ~(FlatPtr)1; // only if lock bit is not set + if (m_bits.compare_exchange_strong(expected, (FlatPtr)new_ptr, AK::MemoryOrder::memory_order_acq_rel)) + break; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + VERIFY(!(expected & 1)); + return (T*)expected; + } + + T* add_ref() const + { +#ifdef KERNEL + // We don't want to be pre-empted while we have the lock bit set + Kernel::ScopedCritical critical; +#endif + // Lock the pointer + FlatPtr expected = m_bits.load(AK::MemoryOrder::memory_order_relaxed); + for (;;) { + expected &= ~(FlatPtr)1; // only if lock bit is not set + if (m_bits.compare_exchange_strong(expected, expected | 1, AK::MemoryOrder::memory_order_acq_rel)) + break; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + + // Add a reference now that we locked the pointer + ref_if_not_null((T*)expected); + + // Unlock the pointer again + m_bits.store(expected, AK::MemoryOrder::memory_order_release); + return (T*)expected; + } + + mutable Atomic m_bits { 0 }; +}; + +template +inline NonnullRefPtr adopt_ref(T& object) +{ + return NonnullRefPtr(NonnullRefPtr::Adopt, object); +} + +template +struct Formatter> : Formatter { + void format(FormatBuilder& builder, const NonnullRefPtr& value) + { + Formatter::format(builder, value.ptr()); + } +}; + +template +inline void swap(NonnullRefPtr& a, NonnullRefPtr& b) requires(IsConvertible) +{ + a.swap(b); +} + +template +requires(IsConstructible) inline NonnullRefPtr make_ref_counted(Args&&... args) +{ + return NonnullRefPtr(NonnullRefPtr::Adopt, *new T(forward(args)...)); +} + +// FIXME: Remove once P0960R3 is available in Clang. +template +inline NonnullRefPtr make_ref_counted(Args&&... args) +{ + return NonnullRefPtr(NonnullRefPtr::Adopt, *new T { forward(args)... }); +} +} + +template +struct Traits> : public GenericTraits> { + using PeekType = T*; + using ConstPeekType = const T*; + static unsigned hash(const NonnullRefPtr& p) { return ptr_hash(p.ptr()); } + static bool equals(const NonnullRefPtr& a, const NonnullRefPtr& b) { return a.ptr() == b.ptr(); } +}; + +using AK::adopt_ref; +using AK::make_ref_counted; +using AK::NonnullRefPtr; diff --git a/Kernel/Library/ThreadSafeRefCounted.h b/Kernel/Library/ThreadSafeRefCounted.h new file mode 100644 index 0000000000..802dd8574a --- /dev/null +++ b/Kernel/Library/ThreadSafeRefCounted.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2018-2020, Andreas Kling + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace AK { + +template +constexpr auto call_will_be_destroyed_if_present(const T* object) -> decltype(const_cast(object)->will_be_destroyed(), TrueType {}) +{ + const_cast(object)->will_be_destroyed(); + return {}; +} + +constexpr auto call_will_be_destroyed_if_present(...) -> FalseType +{ + return {}; +} + +template +constexpr auto call_one_ref_left_if_present(const T* object) -> decltype(const_cast(object)->one_ref_left(), TrueType {}) +{ + const_cast(object)->one_ref_left(); + return {}; +} + +constexpr auto call_one_ref_left_if_present(...) -> FalseType +{ + return {}; +} + +class RefCountedBase { + AK_MAKE_NONCOPYABLE(RefCountedBase); + AK_MAKE_NONMOVABLE(RefCountedBase); + +public: + using RefCountType = unsigned int; + using AllowOwnPtr = FalseType; + + void ref() const + { + auto old_ref_count = m_ref_count.fetch_add(1, AK::MemoryOrder::memory_order_relaxed); + VERIFY(old_ref_count > 0); + VERIFY(!Checked::addition_would_overflow(old_ref_count, 1)); + } + + [[nodiscard]] bool try_ref() const + { + RefCountType expected = m_ref_count.load(AK::MemoryOrder::memory_order_relaxed); + for (;;) { + if (expected == 0) + return false; + VERIFY(!Checked::addition_would_overflow(expected, 1)); + if (m_ref_count.compare_exchange_strong(expected, expected + 1, AK::MemoryOrder::memory_order_acquire)) + return true; + } + } + + [[nodiscard]] RefCountType ref_count() const + { + return m_ref_count.load(AK::MemoryOrder::memory_order_relaxed); + } + +protected: + RefCountedBase() = default; + ~RefCountedBase() + { + VERIFY(m_ref_count.load(AK::MemoryOrder::memory_order_relaxed) == 0); + } + + RefCountType deref_base() const + { + auto old_ref_count = m_ref_count.fetch_sub(1, AK::MemoryOrder::memory_order_acq_rel); + VERIFY(old_ref_count > 0); + return old_ref_count - 1; + } + + mutable Atomic m_ref_count { 1 }; +}; + +template +class RefCounted : public RefCountedBase { +public: + bool unref() const + { + auto new_ref_count = deref_base(); + if (new_ref_count == 0) { + call_will_be_destroyed_if_present(static_cast(this)); + delete static_cast(this); + return true; + } else if (new_ref_count == 1) { + call_one_ref_left_if_present(static_cast(this)); + } + return false; + } +}; + +} + +using AK::RefCounted; +using AK::RefCountedBase; diff --git a/Kernel/Library/ThreadSafeRefPtr.h b/Kernel/Library/ThreadSafeRefPtr.h new file mode 100644 index 0000000000..e8fa9fc5ec --- /dev/null +++ b/Kernel/Library/ThreadSafeRefPtr.h @@ -0,0 +1,519 @@ +/* + * Copyright (c) 2018-2020, Andreas Kling + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#ifdef KERNEL +# include +# include +# include +#endif + +namespace AK { + +template +class OwnPtr; + +template +struct RefPtrTraits { + ALWAYS_INLINE static T* as_ptr(FlatPtr bits) + { + return (T*)(bits & ~(FlatPtr)1); + } + + ALWAYS_INLINE static FlatPtr as_bits(T* ptr) + { + VERIFY(!((FlatPtr)ptr & 1)); + return (FlatPtr)ptr; + } + + template + ALWAYS_INLINE static FlatPtr convert_from(FlatPtr bits) + { + if (PtrTraits::is_null(bits)) + return default_null_value; + return as_bits(PtrTraits::as_ptr(bits)); + } + + ALWAYS_INLINE static bool is_null(FlatPtr bits) + { + return !(bits & ~(FlatPtr)1); + } + + ALWAYS_INLINE static FlatPtr exchange(Atomic& atomic_var, FlatPtr new_value) + { + // Only exchange when lock is not held + VERIFY(!(new_value & 1)); + FlatPtr expected = atomic_var.load(AK::MemoryOrder::memory_order_relaxed); + for (;;) { + expected &= ~(FlatPtr)1; // only if lock bit is not set + if (atomic_var.compare_exchange_strong(expected, new_value, AK::MemoryOrder::memory_order_acq_rel)) + break; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + return expected; + } + + ALWAYS_INLINE static bool exchange_if_null(Atomic& atomic_var, FlatPtr new_value) + { + // Only exchange when lock is not held + VERIFY(!(new_value & 1)); + for (;;) { + FlatPtr expected = default_null_value; // only if lock bit is not set + if (atomic_var.compare_exchange_strong(expected, new_value, AK::MemoryOrder::memory_order_acq_rel)) + break; + if (!is_null(expected)) + return false; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + return true; + } + + ALWAYS_INLINE static FlatPtr lock(Atomic& atomic_var) + { + // This sets the lock bit atomically, preventing further modifications. + // This is important when e.g. copying a RefPtr where the source + // might be released and freed too quickly. This allows us + // to temporarily lock the pointer so we can add a reference, then + // unlock it + FlatPtr bits; + for (;;) { + bits = atomic_var.fetch_or(1, AK::MemoryOrder::memory_order_acq_rel); + if (!(bits & 1)) + break; +#ifdef KERNEL + Kernel::Processor::wait_check(); +#endif + } + VERIFY(!(bits & 1)); + return bits; + } + + ALWAYS_INLINE static void unlock(Atomic& atomic_var, FlatPtr new_value) + { + VERIFY(!(new_value & 1)); + atomic_var.store(new_value, AK::MemoryOrder::memory_order_release); + } + + static constexpr FlatPtr default_null_value = 0; + + using NullType = std::nullptr_t; +}; + +template +class RefPtr { + template + friend class RefPtr; + template + friend class WeakPtr; + +public: + enum AdoptTag { + Adopt + }; + + RefPtr() = default; + RefPtr(const T* ptr) + : m_bits(PtrTraits::as_bits(const_cast(ptr))) + { + ref_if_not_null(const_cast(ptr)); + } + RefPtr(const T& object) + : m_bits(PtrTraits::as_bits(const_cast(&object))) + { + T* ptr = const_cast(&object); + VERIFY(ptr); + VERIFY(!is_null()); + ptr->ref(); + } + RefPtr(AdoptTag, T& object) + : m_bits(PtrTraits::as_bits(&object)) + { + VERIFY(!is_null()); + } + RefPtr(RefPtr&& other) + : m_bits(other.leak_ref_raw()) + { + } + ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) + : m_bits(PtrTraits::as_bits(const_cast(other.add_ref()))) + { + } + template + ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) requires(IsConvertible) + : m_bits(PtrTraits::as_bits(const_cast(other.add_ref()))) + { + } + template + ALWAYS_INLINE RefPtr(NonnullRefPtr&& other) requires(IsConvertible) + : m_bits(PtrTraits::as_bits(&other.leak_ref())) + { + VERIFY(!is_null()); + } + template> + RefPtr(RefPtr&& other) requires(IsConvertible) + : m_bits(PtrTraits::template convert_from(other.leak_ref_raw())) + { + } + RefPtr(const RefPtr& other) + : m_bits(other.add_ref_raw()) + { + } + template> + RefPtr(const RefPtr& other) requires(IsConvertible) + : m_bits(other.add_ref_raw()) + { + } + ALWAYS_INLINE ~RefPtr() + { + clear(); +#ifdef SANITIZE_PTRS + m_bits.store(explode_byte(0xe0), AK::MemoryOrder::memory_order_relaxed); +#endif + } + + template + RefPtr(const OwnPtr&) = delete; + template + RefPtr& operator=(const OwnPtr&) = delete; + + void swap(RefPtr& other) + { + if (this == &other) + return; + + // NOTE: swap is not atomic! + FlatPtr other_bits = PtrTraits::exchange(other.m_bits, PtrTraits::default_null_value); + FlatPtr bits = PtrTraits::exchange(m_bits, other_bits); + PtrTraits::exchange(other.m_bits, bits); + } + + template> + void swap(RefPtr& other) requires(IsConvertible) + { + // NOTE: swap is not atomic! + FlatPtr other_bits = P::exchange(other.m_bits, P::default_null_value); + FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::template convert_from(other_bits)); + P::exchange(other.m_bits, P::template convert_from(bits)); + } + + ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) + { + if (this != &other) + assign_raw(other.leak_ref_raw()); + return *this; + } + + template> + ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) requires(IsConvertible) + { + assign_raw(PtrTraits::template convert_from(other.leak_ref_raw())); + return *this; + } + + template + ALWAYS_INLINE RefPtr& operator=(NonnullRefPtr&& other) requires(IsConvertible) + { + assign_raw(PtrTraits::as_bits(&other.leak_ref())); + return *this; + } + + ALWAYS_INLINE RefPtr& operator=(const NonnullRefPtr& other) + { + assign_raw(PtrTraits::as_bits(other.add_ref())); + return *this; + } + + template + ALWAYS_INLINE RefPtr& operator=(const NonnullRefPtr& other) requires(IsConvertible) + { + assign_raw(PtrTraits::as_bits(other.add_ref())); + return *this; + } + + ALWAYS_INLINE RefPtr& operator=(const RefPtr& other) + { + if (this != &other) + assign_raw(other.add_ref_raw()); + return *this; + } + + template + ALWAYS_INLINE RefPtr& operator=(const RefPtr& other) requires(IsConvertible) + { + assign_raw(other.add_ref_raw()); + return *this; + } + + ALWAYS_INLINE RefPtr& operator=(const T* ptr) + { + ref_if_not_null(const_cast(ptr)); + assign_raw(PtrTraits::as_bits(const_cast(ptr))); + return *this; + } + + ALWAYS_INLINE RefPtr& operator=(const T& object) + { + const_cast(object).ref(); + assign_raw(PtrTraits::as_bits(const_cast(&object))); + return *this; + } + + RefPtr& operator=(std::nullptr_t) + { + clear(); + return *this; + } + + ALWAYS_INLINE bool assign_if_null(RefPtr&& other) + { + if (this == &other) + return is_null(); + return PtrTraits::exchange_if_null(m_bits, other.leak_ref_raw()); + } + + template> + ALWAYS_INLINE bool assign_if_null(RefPtr&& other) + { + if (this == &other) + return is_null(); + return PtrTraits::exchange_if_null(m_bits, PtrTraits::template convert_from(other.leak_ref_raw())); + } + + ALWAYS_INLINE void clear() + { + assign_raw(PtrTraits::default_null_value); + } + + bool operator!() const { return PtrTraits::is_null(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); } + + [[nodiscard]] T* leak_ref() + { + FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::default_null_value); + return PtrTraits::as_ptr(bits); + } + + NonnullRefPtr release_nonnull() + { + FlatPtr bits = PtrTraits::exchange(m_bits, PtrTraits::default_null_value); + VERIFY(!PtrTraits::is_null(bits)); + return NonnullRefPtr(NonnullRefPtr::Adopt, *PtrTraits::as_ptr(bits)); + } + + ALWAYS_INLINE T* ptr() { return as_ptr(); } + ALWAYS_INLINE const T* ptr() const { return as_ptr(); } + + ALWAYS_INLINE T* operator->() + { + return as_nonnull_ptr(); + } + + ALWAYS_INLINE const T* operator->() const + { + return as_nonnull_ptr(); + } + + ALWAYS_INLINE T& operator*() + { + return *as_nonnull_ptr(); + } + + ALWAYS_INLINE const T& operator*() const + { + return *as_nonnull_ptr(); + } + + ALWAYS_INLINE operator const T*() const { return as_ptr(); } + ALWAYS_INLINE operator T*() { return as_ptr(); } + + ALWAYS_INLINE operator bool() { return !is_null(); } + + bool operator==(std::nullptr_t) const { return is_null(); } + bool operator!=(std::nullptr_t) const { return !is_null(); } + + bool operator==(const RefPtr& other) const { return as_ptr() == other.as_ptr(); } + bool operator!=(const RefPtr& other) const { return as_ptr() != other.as_ptr(); } + + bool operator==(RefPtr& other) { return as_ptr() == other.as_ptr(); } + bool operator!=(RefPtr& other) { return as_ptr() != other.as_ptr(); } + + bool operator==(const T* other) const { return as_ptr() == other; } + bool operator!=(const T* other) const { return as_ptr() != other; } + + bool operator==(T* other) { return as_ptr() == other; } + bool operator!=(T* other) { return as_ptr() != other; } + + ALWAYS_INLINE bool is_null() const { return PtrTraits::is_null(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); } + + template && !IsNullPointer>::Type* = nullptr> + typename PtrTraits::NullType null_value() const + { + // make sure we are holding a null value + FlatPtr bits = m_bits.load(AK::MemoryOrder::memory_order_relaxed); + VERIFY(PtrTraits::is_null(bits)); + return PtrTraits::to_null_value(bits); + } + template && !IsNullPointer>::Type* = nullptr> + void set_null_value(typename PtrTraits::NullType value) + { + // make sure that new null value would be interpreted as a null value + FlatPtr bits = PtrTraits::from_null_value(value); + VERIFY(PtrTraits::is_null(bits)); + assign_raw(bits); + } + +private: + template + void do_while_locked(F f) const + { +#ifdef KERNEL + // We don't want to be pre-empted while we have the lock bit set + Kernel::ScopedCritical critical; +#endif + FlatPtr bits = PtrTraits::lock(m_bits); + T* ptr = PtrTraits::as_ptr(bits); + f(ptr); + PtrTraits::unlock(m_bits, bits); + } + + [[nodiscard]] ALWAYS_INLINE FlatPtr leak_ref_raw() + { + return PtrTraits::exchange(m_bits, PtrTraits::default_null_value); + } + + [[nodiscard]] ALWAYS_INLINE FlatPtr add_ref_raw() const + { +#ifdef KERNEL + // We don't want to be pre-empted while we have the lock bit set + Kernel::ScopedCritical critical; +#endif + // This prevents a race condition between thread A and B: + // 1. Thread A copies RefPtr, e.g. through assignment or copy constructor, + // gets the pointer from source, but is pre-empted before adding + // another reference + // 2. Thread B calls clear, leak_ref, or release_nonnull on source, and + // then drops the last reference, causing the object to be deleted + // 3. Thread A finishes step #1 by attempting to add a reference to + // the object that was already deleted in step #2 + FlatPtr bits = PtrTraits::lock(m_bits); + if (T* ptr = PtrTraits::as_ptr(bits)) + ptr->ref(); + PtrTraits::unlock(m_bits, bits); + return bits; + } + + ALWAYS_INLINE void assign_raw(FlatPtr bits) + { + FlatPtr prev_bits = PtrTraits::exchange(m_bits, bits); + unref_if_not_null(PtrTraits::as_ptr(prev_bits)); + } + + ALWAYS_INLINE T* as_ptr() const + { + return PtrTraits::as_ptr(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); + } + + ALWAYS_INLINE T* as_nonnull_ptr() const + { + return as_nonnull_ptr(m_bits.load(AK::MemoryOrder::memory_order_relaxed)); + } + + ALWAYS_INLINE T* as_nonnull_ptr(FlatPtr bits) const + { + VERIFY(!PtrTraits::is_null(bits)); + return PtrTraits::as_ptr(bits); + } + + mutable Atomic m_bits { PtrTraits::default_null_value }; +}; + +template +struct Formatter> : Formatter { + void format(FormatBuilder& builder, const RefPtr& value) + { + Formatter::format(builder, value.ptr()); + } +}; + +template +struct Traits> : public GenericTraits> { + using PeekType = T*; + using ConstPeekType = const T*; + static unsigned hash(const RefPtr& p) { return ptr_hash(p.ptr()); } + static bool equals(const RefPtr& a, const RefPtr& b) { return a.ptr() == b.ptr(); } +}; + +template +inline NonnullRefPtr static_ptr_cast(const NonnullRefPtr& ptr) +{ + return NonnullRefPtr(static_cast(*ptr)); +} + +template> +inline RefPtr static_ptr_cast(const RefPtr& ptr) +{ + return RefPtr(static_cast(ptr.ptr())); +} + +template +inline void swap(RefPtr& a, RefPtr& b) requires(IsConvertible) +{ + a.swap(b); +} + +template +inline RefPtr adopt_ref_if_nonnull(T* object) +{ + if (object) + return RefPtr(RefPtr::Adopt, *object); + return {}; +} + +template +requires(IsConstructible) inline RefPtr try_make_ref_counted(Args&&... args) +{ + return adopt_ref_if_nonnull(new (nothrow) T(forward(args)...)); +} + +// FIXME: Remove once P0960R3 is available in Clang. +template +inline RefPtr try_make_ref_counted(Args&&... args) +{ + return adopt_ref_if_nonnull(new (nothrow) T { forward(args)... }); +} + +#ifdef KERNEL +template +inline Kernel::KResultOr> adopt_nonnull_ref_or_enomem(T* object) +{ + auto result = adopt_ref_if_nonnull(object); + if (!result) + return ENOMEM; + return result.release_nonnull(); +} +#endif + +} + +using AK::adopt_ref_if_nonnull; +using AK::RefPtr; +using AK::static_ptr_cast; +using AK::try_make_ref_counted; + +#ifdef KERNEL +using AK::adopt_nonnull_ref_or_enomem; +#endif diff --git a/Kernel/Library/ThreadSafeWeakPtr.h b/Kernel/Library/ThreadSafeWeakPtr.h new file mode 100644 index 0000000000..b89dc480f2 --- /dev/null +++ b/Kernel/Library/ThreadSafeWeakPtr.h @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2018-2020, Andreas Kling + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include + +namespace AK { + +template +class WeakPtr { + template + friend class Weakable; + +public: + WeakPtr() = default; + + template>::Type* = nullptr> + WeakPtr(const WeakPtr& other) + : m_link(other.m_link) + { + } + + template>::Type* = nullptr> + WeakPtr(WeakPtr&& other) + : m_link(other.take_link()) + { + } + + template>::Type* = nullptr> + WeakPtr& operator=(WeakPtr&& other) + { + m_link = other.take_link(); + return *this; + } + + template>::Type* = nullptr> + WeakPtr& operator=(const WeakPtr& other) + { + if ((const void*)this != (const void*)&other) + m_link = other.m_link; + return *this; + } + + WeakPtr& operator=(std::nullptr_t) + { + clear(); + return *this; + } + + template>::Type* = nullptr> + WeakPtr(const U& object) + : m_link(object.template make_weak_ptr().take_link()) + { + } + + template>::Type* = nullptr> + WeakPtr(const U* object) + { + if (object) + m_link = object->template make_weak_ptr().take_link(); + } + + template>::Type* = nullptr> + WeakPtr(const RefPtr& object) + { + object.do_while_locked([&](U* obj) { + if (obj) + m_link = obj->template make_weak_ptr().take_link(); + }); + } + + template>::Type* = nullptr> + WeakPtr(const NonnullRefPtr& object) + { + object.do_while_locked([&](U* obj) { + if (obj) + m_link = obj->template make_weak_ptr().take_link(); + }); + } + + template>::Type* = nullptr> + WeakPtr& operator=(const U& object) + { + m_link = object.template make_weak_ptr().take_link(); + return *this; + } + + template>::Type* = nullptr> + WeakPtr& operator=(const U* object) + { + if (object) + m_link = object->template make_weak_ptr().take_link(); + else + m_link = nullptr; + return *this; + } + + template>::Type* = nullptr> + WeakPtr& operator=(const RefPtr& object) + { + object.do_while_locked([&](U* obj) { + if (obj) + m_link = obj->template make_weak_ptr().take_link(); + else + m_link = nullptr; + }); + return *this; + } + + template>::Type* = nullptr> + WeakPtr& operator=(const NonnullRefPtr& object) + { + object.do_while_locked([&](U* obj) { + if (obj) + m_link = obj->template make_weak_ptr().take_link(); + else + m_link = nullptr; + }); + return *this; + } + + [[nodiscard]] RefPtr strong_ref() const + { + // This only works with RefCounted objects, but it is the only + // safe way to get a strong reference from a WeakPtr. Any code + // that uses objects not derived from RefCounted will have to + // use unsafe_ptr(), but as the name suggests, it is not safe... + RefPtr ref; + // Using do_while_locked protects against a race with clear()! + m_link.do_while_locked([&](WeakLink* link) { + if (link) + ref = link->template strong_ref(); + }); + return ref; + } + +#ifndef KERNEL + // A lot of user mode code is single-threaded. But for kernel mode code + // this is generally not true as everything is multi-threaded. So make + // these shortcuts and aliases only available to non-kernel code. + T* ptr() const { return unsafe_ptr(); } + T* operator->() { return unsafe_ptr(); } + const T* operator->() const { return unsafe_ptr(); } + operator const T*() const { return unsafe_ptr(); } + operator T*() { return unsafe_ptr(); } +#endif + + [[nodiscard]] T* unsafe_ptr() const + { + T* ptr = nullptr; + m_link.do_while_locked([&](WeakLink* link) { + if (link) + ptr = link->unsafe_ptr(); + }); + return ptr; + } + + operator bool() const { return m_link ? !m_link->is_null() : false; } + + [[nodiscard]] bool is_null() const { return !m_link || m_link->is_null(); } + void clear() { m_link = nullptr; } + + [[nodiscard]] RefPtr take_link() { return move(m_link); } + +private: + WeakPtr(const RefPtr& link) + : m_link(link) + { + } + + RefPtr m_link; +}; + +template +template +inline WeakPtr Weakable::make_weak_ptr() const +{ + if constexpr (IsBaseOf) { + // Checking m_being_destroyed isn't sufficient when dealing with + // a RefCounted type.The reference count will drop to 0 before the + // destructor is invoked and revoke_weak_ptrs is called. So, try + // to add a ref (which should fail if the ref count is at 0) so + // that we prevent the destructor and revoke_weak_ptrs from being + // triggered until we're done. + if (!static_cast(this)->try_ref()) + return {}; + } else { + // For non-RefCounted types this means a weak reference can be + // obtained until the ~Weakable destructor is invoked! + if (m_being_destroyed.load(AK::MemoryOrder::memory_order_acquire)) + return {}; + } + if (!m_link) { + // There is a small chance that we create a new WeakLink and throw + // it away because another thread beat us to it. But the window is + // pretty small and the overhead isn't terrible. + m_link.assign_if_null(adopt_ref(*new WeakLink(const_cast(static_cast(*this))))); + } + + WeakPtr weak_ptr(m_link); + + if constexpr (IsBaseOf) { + // Now drop the reference we temporarily added + if (static_cast(this)->unref()) { + // We just dropped the last reference, which should have called + // revoke_weak_ptrs, which should have invalidated our weak_ptr + VERIFY(!weak_ptr.strong_ref()); + return {}; + } + } + return weak_ptr; +} + +template +struct Formatter> : Formatter { + void format(FormatBuilder& builder, const WeakPtr& value) + { +#ifdef KERNEL + auto ref = value.strong_ref(); + Formatter::format(builder, ref.ptr()); +#else + Formatter::format(builder, value.ptr()); +#endif + } +}; + +template +WeakPtr try_make_weak_ptr(const T* ptr) +{ + if (ptr) { + return ptr->template make_weak_ptr(); + } + return {}; +} + +} + +using AK::WeakPtr;