diff --git a/AK/Retainable.h b/AK/Retainable.h index e8be166af3..1d22428e90 100644 --- a/AK/Retainable.h +++ b/AK/Retainable.h @@ -7,8 +7,6 @@ namespace AK { template class Retainable { public: - Retainable() { } - void retain() { ASSERT(m_retainCount); @@ -28,6 +26,7 @@ public: } protected: + Retainable() { } ~Retainable() { ASSERT(!m_retainCount); diff --git a/AK/WeakPtr.h b/AK/WeakPtr.h new file mode 100644 index 0000000000..eab6c09d31 --- /dev/null +++ b/AK/WeakPtr.h @@ -0,0 +1,37 @@ +#pragma once + +#include "Weakable.h" + +namespace AK { + +template +class WeakPtr { + friend class Weakable; +public: + WeakPtr() { } + WeakPtr(std::nullptr_t) { } + + operator bool() const { return ptr(); } + + T* ptr() { return m_link ? m_link->ptr() : nullptr; } + const T* ptr() const { return m_link ? m_link->ptr() : nullptr; } + bool isNull() const { return !m_link || !m_link->ptr(); } + +private: + WeakPtr(RetainPtr>&& link) : m_link(std::move(link)) { } + + RetainPtr> m_link; +}; + +template +inline WeakPtr Weakable::makeWeakPtr() +{ + if (!m_link) + m_link = adopt(*new WeakLink(*this)); + return WeakPtr(m_link.copyRef()); +} + +} + +using AK::WeakPtr; + diff --git a/AK/Weakable.h b/AK/Weakable.h new file mode 100644 index 0000000000..bd24704961 --- /dev/null +++ b/AK/Weakable.h @@ -0,0 +1,45 @@ +#pragma once + +#include "Assertions.h" +#include "Retainable.h" + +namespace AK { + +template class Weakable; +template class WeakPtr; + +template +class WeakLink : public Retainable> { + friend class Weakable; +public: + T* ptr() { return static_cast(m_ptr); } + const T* ptr() const { return static_cast(m_ptr); } + +private: + explicit WeakLink(Weakable& weakable) : m_ptr(&weakable) { } + Weakable* m_ptr; +}; + +template +class Weakable { +private: + class Link; +public: + WeakPtr makeWeakPtr(); + +protected: + Weakable() { } + + ~Weakable() + { + if (m_link) + m_link->m_ptr = nullptr; + } + +private: + RetainPtr> m_link; +}; + +} + +using AK::Weakable; diff --git a/AK/test.cpp b/AK/test.cpp index 8840c86e5a..1f049aae8d 100644 --- a/AK/test.cpp +++ b/AK/test.cpp @@ -7,6 +7,10 @@ #include "HashMap.h" #include "TemporaryFile.h" #include "Buffer.h" +#include "Weakable.h" +#include "WeakPtr.h" + +static void testWeakPtr(); int main(int, char**) { @@ -183,5 +187,27 @@ int main(int, char**) printInts(h); } + testWeakPtr(); + return 0; } + +class TestWeakable : public Weakable { +public: + TestWeakable() { } + ~TestWeakable() { } +}; + +void testWeakPtr() +{ + auto* weakable = new TestWeakable; + + auto weakPtr = weakable->makeWeakPtr(); + ASSERT(weakPtr); + ASSERT(weakPtr.ptr() == weakable); + + delete weakable; + + ASSERT(!weakPtr); + ASSERT(weakPtr.ptr() == nullptr); +}