From 3c1ef744f6c55b40c3d89baaa942aac7e0c9110e Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 23 Sep 2020 10:17:43 -0600 Subject: [PATCH] AK: Add RefPtrTraits to allow implementing custom null pointers This adds the ability to implement custom null states that allow storing state in null pointers. --- AK/Forward.h | 3 + AK/NonnullOwnPtr.h | 10 +-- AK/NonnullRefPtr.h | 2 +- AK/RefPtr.h | 182 +++++++++++++++++++++++++++++---------------- AK/StdLibExtras.h | 5 ++ 5 files changed, 131 insertions(+), 71 deletions(-) diff --git a/AK/Forward.h b/AK/Forward.h index cbf3819934..df7c1bbcae 100644 --- a/AK/Forward.h +++ b/AK/Forward.h @@ -113,6 +113,9 @@ template class Optional; template +class RefPtrTraits; + +template> class RefPtr; template diff --git a/AK/NonnullOwnPtr.h b/AK/NonnullOwnPtr.h index 7bf5def3b5..3d430e167d 100644 --- a/AK/NonnullOwnPtr.h +++ b/AK/NonnullOwnPtr.h @@ -35,7 +35,7 @@ namespace AK { -template +template class RefPtr; template class NonnullRefPtr; @@ -85,14 +85,14 @@ public: template NonnullOwnPtr& operator=(const NonnullOwnPtr&) = delete; - template - NonnullOwnPtr(const RefPtr&) = delete; + template> + NonnullOwnPtr(const RefPtr&) = delete; template NonnullOwnPtr(const NonnullRefPtr&) = delete; template NonnullOwnPtr(const WeakPtr&) = delete; - template - NonnullOwnPtr& operator=(const RefPtr&) = delete; + template> + NonnullOwnPtr& operator=(const RefPtr&) = delete; template NonnullOwnPtr& operator=(const NonnullRefPtr&) = delete; template diff --git a/AK/NonnullRefPtr.h b/AK/NonnullRefPtr.h index 22c61b2f86..fab7f227d2 100644 --- a/AK/NonnullRefPtr.h +++ b/AK/NonnullRefPtr.h @@ -35,7 +35,7 @@ namespace AK { template class OwnPtr; -template +template class RefPtr; template diff --git a/AK/RefPtr.h b/AK/RefPtr.h index 98ba2ff352..2ce4ed8161 100644 --- a/AK/RefPtr.h +++ b/AK/RefPtr.h @@ -38,7 +38,31 @@ template class OwnPtr; template +struct RefPtrTraits { + static T* as_ptr(FlatPtr bits) + { + return (T*)bits; + } + + static FlatPtr as_bits(T* ptr) + { + return (FlatPtr)ptr; + } + + static bool is_null(FlatPtr bits) + { + return !bits; + } + + static constexpr FlatPtr default_null_value = 0; + + typedef std::nullptr_t NullType; +}; + +template class RefPtr { + template + friend class RefPtr; public: enum AdoptTag { Adopt @@ -46,66 +70,71 @@ public: RefPtr() { } RefPtr(const T* ptr) - : m_ptr(const_cast(ptr)) + : m_bits(PtrTraits::as_bits(const_cast(ptr))) { - ref_if_not_null(m_ptr); + ref_if_not_null(const_cast(ptr)); } RefPtr(const T& object) - : m_ptr(const_cast(&object)) + : m_bits(PtrTraits::as_bits(const_cast(&object))) { - m_ptr->ref(); + T* ptr = const_cast(&object); + ASSERT(ptr); + ASSERT(!ptr == PtrTraits::is_null(m_bits)); + ptr->ref(); } RefPtr(AdoptTag, T& object) - : m_ptr(&object) + : m_bits(PtrTraits::as_bits(&object)) { + ASSERT(&object); + ASSERT(!PtrTraits::is_null(m_bits)); } RefPtr(RefPtr&& other) - : m_ptr(other.leak_ref()) + : m_bits(other.leak_ref_raw()) { } ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) - : m_ptr(const_cast(other.ptr())) + : m_bits(PtrTraits::as_bits(const_cast(other.ptr()))) { - ASSERT(m_ptr); - m_ptr->ref(); + ASSERT(!PtrTraits::is_null(m_bits)); + PtrTraits::as_ptr(m_bits)->ref(); } template ALWAYS_INLINE RefPtr(const NonnullRefPtr& other) - : m_ptr(const_cast(other.ptr())) + : m_bits(PtrTraits::as_bits(const_cast(other.ptr()))) { - ASSERT(m_ptr); - m_ptr->ref(); + ASSERT(!PtrTraits::is_null(m_bits)); + PtrTraits::as_ptr(m_bits)->ref(); } template ALWAYS_INLINE RefPtr(NonnullRefPtr&& other) - : m_ptr(&other.leak_ref()) + : m_bits(PtrTraits::as_bits(&other.leak_ref())) { - ASSERT(m_ptr); + ASSERT(!PtrTraits::is_null(m_bits)); } - template - RefPtr(RefPtr&& other) - : m_ptr(other.leak_ref()) + template> + RefPtr(RefPtr&& other) + : m_bits(other.leak_ref_raw()) { } RefPtr(const RefPtr& other) - : m_ptr(const_cast(other.ptr())) + : m_bits(PtrTraits::as_bits(const_cast(other.ptr()))) { - ref_if_not_null(m_ptr); + ref_if_not_null(const_cast(other.ptr())); } - template - RefPtr(const RefPtr& other) - : m_ptr(const_cast(other.ptr())) + template> + RefPtr(const RefPtr& other) + : m_bits(PtrTraits::as_bits(const_cast(other.ptr()))) { - ref_if_not_null(m_ptr); + ref_if_not_null(const_cast(other.ptr())); } ALWAYS_INLINE ~RefPtr() { clear(); #ifdef SANITIZE_PTRS if constexpr (sizeof(T*) == 8) - m_ptr = (T*)(0xe0e0e0e0e0e0e0e0); + m_bits = 0xe0e0e0e0e0e0e0e0; else - m_ptr = (T*)(0xe0e0e0e0); + m_bits = 0xe0e0e0e0; #endif } RefPtr(std::nullptr_t) { } @@ -116,9 +145,9 @@ public: RefPtr& operator=(const OwnPtr&) = delete; template - void swap(RefPtr& other) + void swap(RefPtr& other) { - ::swap(m_ptr, other.m_ptr); + ::swap(m_bits, other.m_bits); } ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) @@ -129,7 +158,7 @@ public: } template - ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) + ALWAYS_INLINE RefPtr& operator=(RefPtr&& other) { RefPtr tmp = move(other); swap(tmp); @@ -141,7 +170,7 @@ public: { RefPtr tmp = move(other); swap(tmp); - ASSERT(m_ptr); + ASSERT(!PtrTraits::is_null(m_bits)); return *this; } @@ -149,7 +178,7 @@ public: { RefPtr tmp = other; swap(tmp); - ASSERT(m_ptr); + ASSERT(!PtrTraits::is_null(m_bits)); return *this; } @@ -158,7 +187,7 @@ public: { RefPtr tmp = other; swap(tmp); - ASSERT(m_ptr); + ASSERT(!PtrTraits::is_null(m_bits)); return *this; } @@ -199,78 +228,101 @@ public: ALWAYS_INLINE void clear() { - unref_if_not_null(m_ptr); - m_ptr = nullptr; + unref_if_not_null(PtrTraits::as_ptr(m_bits)); + m_bits = PtrTraits::default_null_value; } - bool operator!() const { return !m_ptr; } + bool operator!() const { return PtrTraits::is_null(m_bits); } [[nodiscard]] T* leak_ref() { - return exchange(m_ptr, nullptr); + FlatPtr bits = exchange(m_bits, PtrTraits::default_null_value); + return !PtrTraits::is_null(bits) ? PtrTraits::as_ptr(bits) : nullptr; } NonnullRefPtr release_nonnull() { - ASSERT(m_ptr); + ASSERT(!PtrTraits::is_null(m_bits)); return NonnullRefPtr(NonnullRefPtr::Adopt, *leak_ref()); } - ALWAYS_INLINE T* ptr() { return m_ptr; } - ALWAYS_INLINE const T* ptr() const { return m_ptr; } + ALWAYS_INLINE T* ptr() { return !PtrTraits::is_null(m_bits) ? PtrTraits::as_ptr(m_bits) : nullptr; } + ALWAYS_INLINE const T* ptr() const { return !PtrTraits::is_null(m_bits) ? PtrTraits::as_ptr(m_bits) : nullptr; } ALWAYS_INLINE T* operator->() { - ASSERT(m_ptr); - return m_ptr; + ASSERT(!PtrTraits::is_null(m_bits)); + return PtrTraits::as_ptr(m_bits); } ALWAYS_INLINE const T* operator->() const { - ASSERT(m_ptr); - return m_ptr; + ASSERT(!PtrTraits::is_null(m_bits)); + return PtrTraits::as_ptr(m_bits); } ALWAYS_INLINE T& operator*() { - ASSERT(m_ptr); - return *m_ptr; + ASSERT(!PtrTraits::is_null(m_bits)); + return *PtrTraits::as_ptr(m_bits); } ALWAYS_INLINE const T& operator*() const { - ASSERT(m_ptr); - return *m_ptr; + ASSERT(!PtrTraits::is_null(m_bits)); + return *PtrTraits::as_ptr(m_bits); } - ALWAYS_INLINE operator const T*() const { return m_ptr; } - ALWAYS_INLINE operator T*() { return m_ptr; } + ALWAYS_INLINE operator const T*() const { return PtrTraits::as_ptr(m_bits); } + ALWAYS_INLINE operator T*() { return PtrTraits::as_ptr(m_bits); } - operator bool() { return !!m_ptr; } + operator bool() { return !PtrTraits::is_null(m_bits); } - bool operator==(std::nullptr_t) const { return !m_ptr; } - bool operator!=(std::nullptr_t) const { return m_ptr; } + bool operator==(std::nullptr_t) const { return PtrTraits::is_null(m_bits); } + bool operator!=(std::nullptr_t) const { return !PtrTraits::is_null(m_bits); } - bool operator==(const RefPtr& other) const { return m_ptr == other.m_ptr; } - bool operator!=(const RefPtr& other) const { return m_ptr != other.m_ptr; } + bool operator==(const RefPtr& other) const { return m_bits == other.m_bits; } + bool operator!=(const RefPtr& other) const { return m_bits != other.m_bits; } - bool operator==(RefPtr& other) { return m_ptr == other.m_ptr; } - bool operator!=(RefPtr& other) { return m_ptr != other.m_ptr; } + bool operator==(RefPtr& other) { return m_bits == other.m_bits; } + bool operator!=(RefPtr& other) { return m_bits != other.m_bits; } - bool operator==(const T* other) const { return m_ptr == other; } - bool operator!=(const T* other) const { return m_ptr != other; } + bool operator==(const T* other) const { return PtrTraits::as_ptr(m_bits) == other; } + bool operator!=(const T* other) const { return PtrTraits::as_ptr(m_bits) != other; } - bool operator==(T* other) { return m_ptr == other; } - bool operator!=(T* other) { return m_ptr != other; } + bool operator==(T* other) { return PtrTraits::as_ptr(m_bits) == other; } + bool operator!=(T* other) { return PtrTraits::as_ptr(m_bits) != other; } - bool is_null() const { return !m_ptr; } + bool is_null() const { return PtrTraits::is_null(m_bits); } + + template::value && !IsNullPointer::value>::Type* = nullptr> + typename PtrTraits::NullType null_value() const + { + // make sure we are holding a null value + ASSERT(PtrTraits::is_null(m_bits)); + return PtrTraits::to_null_value(m_bits); + } + template::value && !IsNullPointer::value>::Type* = nullptr> + void set_null_value(typename PtrTraits::NullType value) + { + // make sure that new null value would be interpreted as a null value + FlatPtr bits = PtrTraits::from_null_value(value); + ASSERT(PtrTraits::is_null(bits)); + clear(); + m_bits = bits; + } private: - T* m_ptr = nullptr; + [[nodiscard]] FlatPtr leak_ref_raw() + { + return exchange(m_bits, PtrTraits::default_null_value); + } + + FlatPtr m_bits { PtrTraits::default_null_value }; }; -template -inline const LogStream& operator<<(const LogStream& stream, const RefPtr& value) +template> +inline const LogStream& operator<<(const LogStream& stream, const RefPtr& value) { return stream << value.ptr(); } @@ -288,10 +340,10 @@ inline NonnullRefPtr static_ptr_cast(const NonnullRefPtr& ptr) return NonnullRefPtr(static_cast(*ptr)); } -template +template> inline RefPtr static_ptr_cast(const RefPtr& ptr) { - return RefPtr(static_cast(ptr.ptr())); + return RefPtr(static_cast(ptr.ptr())); } } diff --git a/AK/StdLibExtras.h b/AK/StdLibExtras.h index 17bd4acfea..0462ecd699 100644 --- a/AK/StdLibExtras.h +++ b/AK/StdLibExtras.h @@ -298,6 +298,11 @@ struct Conditional { typedef FalseType Type; }; +template +struct IsNullPointer : IsSame::Type> { +}; + + template struct RemoveReference { typedef T Type;