From 145e246a5e9f229656bc1182bdc6c397ca8efe8a Mon Sep 17 00:00:00 2001 From: Timothy Flynn Date: Wed, 19 May 2021 11:28:27 -0400 Subject: [PATCH] AK: Allow AK::Variant::visit to return a value This changes Variant::visit() to forward the value returned by the selected visitor invocation. By perfectly forwarding the returned value, this allows for the visitor to return by value or reference. Note that all provided visitors must return the same type - the compiler will otherwise fail with the message: "inconsistent deduction for auto return type". --- AK/Variant.h | 51 ++++++++++++++++++------------------- Tests/AK/TestVariant.cpp | 55 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 27 deletions(-) diff --git a/AK/Variant.h b/AK/Variant.h index bc0610510b..1bde4ad4b1 100644 --- a/AK/Variant.h +++ b/AK/Variant.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace AK::Detail { @@ -68,24 +69,6 @@ struct Variant { else Variant::copy_(old_id, old_data, new_data); } - - template - static void visit_(IndexType id, void* data, Visitor&& visitor) - { - if (id == current_index) - visitor(*bit_cast(data)); - else - Variant::visit_(id, data, forward(visitor)); - } - - template - static void visit_(IndexType id, const void* data, Visitor&& visitor) - { - if (id == current_index) - visitor(*bit_cast(data)); - else - Variant::visit_(id, data, forward(visitor)); - } }; template @@ -93,10 +76,23 @@ struct Variant { static void delete_(IndexType, void*) { } static void move_(IndexType, void*, void*) { } static void copy_(IndexType, const void*, void*) { } - template - static void visit_(IndexType, void*, Visitor&&) { } - template - static void visit_(IndexType, const void*, Visitor&&) { } +}; + +template +struct VisitImpl { + template + static constexpr inline decltype(auto) visit(IndexType id, const void* data, Visitor&& visitor) requires(CurrentIndex < sizeof...(Ts)) + { + using T = typename TypeList::template Type; + + if (id == CurrentIndex) + return visitor(*bit_cast(data)); + + if constexpr ((CurrentIndex + 1) < sizeof...(Ts)) + return visit(id, data, forward(visitor)); + else + VERIFY_NOT_REACHED(); + } }; struct VariantNoClearTag { @@ -310,17 +306,17 @@ public: } template - void visit(Fs&&... functions) + decltype(auto) visit(Fs&&... functions) { Visitor visitor { forward(functions)... }; - Helper::visit_(m_index, m_data, visitor); + return VisitHelper::visit(m_index, m_data, move(visitor)); } template - void visit(Fs&&... functions) const + decltype(auto) visit(Fs&&... functions) const { Visitor visitor { forward(functions)... }; - Helper::visit_(m_index, m_data, visitor); + return VisitHelper::visit(m_index, m_data, move(visitor)); } template @@ -357,6 +353,7 @@ private: static constexpr auto data_size = integer_sequence_generate_array(0, IntegerSequence()).max(); static constexpr auto data_alignment = integer_sequence_generate_array(0, IntegerSequence()).max(); using Helper = Detail::Variant; + using VisitHelper = Detail::VisitImpl; explicit Variant(IndexType index, Detail::VariantConstructTag) : Detail::MergeAndDeduplicatePacks>...>() @@ -367,7 +364,7 @@ private: template struct Visitor : Fs... { Visitor(Fs&&... args) - : Fs(args)... + : Fs(forward(args))... { } diff --git a/Tests/AK/TestVariant.cpp b/Tests/AK/TestVariant.cpp index 587eae2819..28dc92340c 100644 --- a/Tests/AK/TestVariant.cpp +++ b/Tests/AK/TestVariant.cpp @@ -6,8 +6,16 @@ #include +#include #include +namespace { + +struct Object : public RefCounted { +}; + +} + TEST_CASE(basic) { Variant the_value { 42 }; @@ -117,3 +125,50 @@ TEST_CASE(duplicated_types) EXPECT(its_just_an_int.has()); EXPECT_EQ(its_just_an_int.get(), 42); } + +TEST_CASE(return_values) +{ + using MyVariant = Variant; + { + MyVariant the_value { 42.0f }; + + float value = the_value.visit( + [&](const int&) { return 1.0f; }, + [&](const String&) { return 2.0f; }, + [&](const float& f) { return f; }); + EXPECT_EQ(value, 42.0f); + } + { + MyVariant the_value { 42 }; + + int value = the_value.visit( + [&](int& i) { return i; }, + [&](String&) { return 2; }, + [&](float&) { return 3; }); + EXPECT_EQ(value, 42); + } + { + const MyVariant the_value { "str" }; + + String value = the_value.visit( + [&](const int&) { return String { "wrong" }; }, + [&](const String& s) { return s; }, + [&](const float&) { return String { "wrong" }; }); + EXPECT_EQ(value, "str"); + } +} + +TEST_CASE(return_values_by_reference) +{ + auto ref = adopt_ref_if_nonnull(new Object()); + Variant the_value { 42.0f }; + + auto& value = the_value.visit( + [&](const int&) -> RefPtr& { return ref; }, + [&](const String&) -> RefPtr& { return ref; }, + [&](const float&) -> RefPtr& { return ref; }); + + EXPECT_EQ(ref, value); + EXPECT_EQ(ref->ref_count(), 1u); + EXPECT_EQ(value->ref_count(), 1u); +}