diff --git a/AK/Variant.h b/AK/Variant.h index bc0610510b..1bde4ad4b1 100644 --- a/AK/Variant.h +++ b/AK/Variant.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace AK::Detail { @@ -68,24 +69,6 @@ struct Variant { else Variant::copy_(old_id, old_data, new_data); } - - template - static void visit_(IndexType id, void* data, Visitor&& visitor) - { - if (id == current_index) - visitor(*bit_cast(data)); - else - Variant::visit_(id, data, forward(visitor)); - } - - template - static void visit_(IndexType id, const void* data, Visitor&& visitor) - { - if (id == current_index) - visitor(*bit_cast(data)); - else - Variant::visit_(id, data, forward(visitor)); - } }; template @@ -93,10 +76,23 @@ struct Variant { static void delete_(IndexType, void*) { } static void move_(IndexType, void*, void*) { } static void copy_(IndexType, const void*, void*) { } - template - static void visit_(IndexType, void*, Visitor&&) { } - template - static void visit_(IndexType, const void*, Visitor&&) { } +}; + +template +struct VisitImpl { + template + static constexpr inline decltype(auto) visit(IndexType id, const void* data, Visitor&& visitor) requires(CurrentIndex < sizeof...(Ts)) + { + using T = typename TypeList::template Type; + + if (id == CurrentIndex) + return visitor(*bit_cast(data)); + + if constexpr ((CurrentIndex + 1) < sizeof...(Ts)) + return visit(id, data, forward(visitor)); + else + VERIFY_NOT_REACHED(); + } }; struct VariantNoClearTag { @@ -310,17 +306,17 @@ public: } template - void visit(Fs&&... functions) + decltype(auto) visit(Fs&&... functions) { Visitor visitor { forward(functions)... }; - Helper::visit_(m_index, m_data, visitor); + return VisitHelper::visit(m_index, m_data, move(visitor)); } template - void visit(Fs&&... functions) const + decltype(auto) visit(Fs&&... functions) const { Visitor visitor { forward(functions)... }; - Helper::visit_(m_index, m_data, visitor); + return VisitHelper::visit(m_index, m_data, move(visitor)); } template @@ -357,6 +353,7 @@ private: static constexpr auto data_size = integer_sequence_generate_array(0, IntegerSequence()).max(); static constexpr auto data_alignment = integer_sequence_generate_array(0, IntegerSequence()).max(); using Helper = Detail::Variant; + using VisitHelper = Detail::VisitImpl; explicit Variant(IndexType index, Detail::VariantConstructTag) : Detail::MergeAndDeduplicatePacks>...>() @@ -367,7 +364,7 @@ private: template struct Visitor : Fs... { Visitor(Fs&&... args) - : Fs(args)... + : Fs(forward(args))... { } diff --git a/Tests/AK/TestVariant.cpp b/Tests/AK/TestVariant.cpp index 587eae2819..28dc92340c 100644 --- a/Tests/AK/TestVariant.cpp +++ b/Tests/AK/TestVariant.cpp @@ -6,8 +6,16 @@ #include +#include #include +namespace { + +struct Object : public RefCounted { +}; + +} + TEST_CASE(basic) { Variant the_value { 42 }; @@ -117,3 +125,50 @@ TEST_CASE(duplicated_types) EXPECT(its_just_an_int.has()); EXPECT_EQ(its_just_an_int.get(), 42); } + +TEST_CASE(return_values) +{ + using MyVariant = Variant; + { + MyVariant the_value { 42.0f }; + + float value = the_value.visit( + [&](const int&) { return 1.0f; }, + [&](const String&) { return 2.0f; }, + [&](const float& f) { return f; }); + EXPECT_EQ(value, 42.0f); + } + { + MyVariant the_value { 42 }; + + int value = the_value.visit( + [&](int& i) { return i; }, + [&](String&) { return 2; }, + [&](float&) { return 3; }); + EXPECT_EQ(value, 42); + } + { + const MyVariant the_value { "str" }; + + String value = the_value.visit( + [&](const int&) { return String { "wrong" }; }, + [&](const String& s) { return s; }, + [&](const float&) { return String { "wrong" }; }); + EXPECT_EQ(value, "str"); + } +} + +TEST_CASE(return_values_by_reference) +{ + auto ref = adopt_ref_if_nonnull(new Object()); + Variant the_value { 42.0f }; + + auto& value = the_value.visit( + [&](const int&) -> RefPtr& { return ref; }, + [&](const String&) -> RefPtr& { return ref; }, + [&](const float&) -> RefPtr& { return ref; }); + + EXPECT_EQ(ref, value); + EXPECT_EQ(ref->ref_count(), 1u); + EXPECT_EQ(value->ref_count(), 1u); +}