From fe672989a90225499c73416a757373c1198f7657 Mon Sep 17 00:00:00 2001 From: Zaggy1024 Date: Tue, 11 Jul 2023 20:48:56 -0500 Subject: [PATCH] LibCore: Add a class for thread-safe promises Since the existing Promise class is designed with deferred tasks on the main thread only, we need a new class that will ensure we can handle promises that are resolved/rejected off the main thread. This new class ensures that the callbacks are only called on the same thread that the promise is fulfilled from. If the callbacks are not set before the thread tries to fulfill the promise, it will spin until they are so that they will run on that thread. --- Meta/Lagom/CMakeLists.txt | 1 + Tests/LibCore/CMakeLists.txt | 1 + Tests/LibCore/TestLibCorePromise.cpp | 108 +++++++++++ Userland/Libraries/LibCore/EventLoop.cpp | 19 +- Userland/Libraries/LibCore/EventLoop.h | 1 + Userland/Libraries/LibCore/Forward.h | 2 + Userland/Libraries/LibCore/ThreadedPromise.h | 193 +++++++++++++++++++ 7 files changed, 321 insertions(+), 4 deletions(-) create mode 100644 Userland/Libraries/LibCore/ThreadedPromise.h diff --git a/Meta/Lagom/CMakeLists.txt b/Meta/Lagom/CMakeLists.txt index 4c8725a4cf..3f3f144362 100644 --- a/Meta/Lagom/CMakeLists.txt +++ b/Meta/Lagom/CMakeLists.txt @@ -662,6 +662,7 @@ if (BUILD_LAGOM) # LibCore if ((LINUX OR APPLE) AND NOT EMSCRIPTEN) lagom_test(../../Tests/LibCore/TestLibCoreFileWatcher.cpp) + lagom_test(../../Tests/LibCore/TestLibCorePromise.cpp LIBS LibThreading) endif() # RegexLibC test POSIX and contains many Serenity extensions diff --git a/Tests/LibCore/CMakeLists.txt b/Tests/LibCore/CMakeLists.txt index afbc6a364e..d0455d2c04 100644 --- a/Tests/LibCore/CMakeLists.txt +++ b/Tests/LibCore/CMakeLists.txt @@ -12,6 +12,7 @@ foreach(source IN LISTS TEST_SOURCES) serenity_test("${source}" LibCore) endforeach() +target_link_libraries(TestLibCorePromise PRIVATE LibThreading) # NOTE: Required because of the LocalServer tests target_link_libraries(TestLibCoreStream PRIVATE LibThreading) target_link_libraries(TestLibCoreSharedSingleProducerCircularQueue PRIVATE LibThreading) diff --git a/Tests/LibCore/TestLibCorePromise.cpp b/Tests/LibCore/TestLibCorePromise.cpp index 6e8f969a88..ac691474e2 100644 --- a/Tests/LibCore/TestLibCorePromise.cpp +++ b/Tests/LibCore/TestLibCorePromise.cpp @@ -6,7 +6,10 @@ #include #include +#include #include +#include +#include TEST_CASE(promise_await_async_event) { @@ -57,3 +60,108 @@ TEST_CASE(promise_chain_handlers) EXPECT(resolved); EXPECT(!rejected); } + +TEST_CASE(threaded_promise_instantly_resolved) +{ + Core::EventLoop loop; + + bool resolved = false; + bool rejected = true; + Optional thread_id; + + auto promise = Core::ThreadedPromise::create(); + + auto thread = Threading::Thread::construct([&, promise] { + thread_id = pthread_self(); + promise->resolve(42); + return 0; + }); + thread->start(); + + promise + ->when_resolved([&](int result) { + EXPECT(thread_id.has_value()); + EXPECT(pthread_equal(thread_id.value(), pthread_self())); + resolved = true; + rejected = false; + EXPECT_EQ(result, 42); + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(resolved); + EXPECT(!rejected); + MUST(thread->join()); +} + +TEST_CASE(threaded_promise_resolved_later) +{ + Core::EventLoop loop; + + bool unblock_thread = false; + bool resolved = false; + bool rejected = true; + Optional thread_id; + + auto promise = Core::ThreadedPromise::create(); + + auto thread = Threading::Thread::construct([&, promise] { + thread_id = pthread_self(); + while (!unblock_thread) + usleep(500); + promise->resolve(42); + return 0; + }); + thread->start(); + + promise + ->when_resolved([&]() { + EXPECT(thread_id.has_value()); + EXPECT(pthread_equal(thread_id.value(), pthread_self())); + EXPECT(unblock_thread); + resolved = true; + rejected = false; + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + Core::EventLoop::current().deferred_invoke([&]() { unblock_thread = true; }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(unblock_thread); + EXPECT(resolved); + EXPECT(!rejected); + MUST(thread->join()); +} + +TEST_CASE(threaded_promise_synchronously_resolved) +{ + Core::EventLoop loop; + + bool resolved = false; + bool rejected = true; + auto thread_id = pthread_self(); + + auto promise = Core::ThreadedPromise::create(); + promise->resolve(1337); + + promise + ->when_resolved([&]() { + EXPECT(pthread_equal(thread_id, pthread_self())); + resolved = true; + rejected = false; + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(resolved); + EXPECT(!rejected); +} diff --git a/Userland/Libraries/LibCore/EventLoop.cpp b/Userland/Libraries/LibCore/EventLoop.cpp index 7d721341ee..ee889a98cf 100644 --- a/Userland/Libraries/LibCore/EventLoop.cpp +++ b/Userland/Libraries/LibCore/EventLoop.cpp @@ -17,12 +17,17 @@ namespace Core { namespace { -Vector& event_loop_stack() +OwnPtr>& event_loop_stack_uninitialized() { thread_local OwnPtr> s_event_loop_stack = nullptr; - if (s_event_loop_stack == nullptr) - s_event_loop_stack = make>(); - return *s_event_loop_stack; + return s_event_loop_stack; +} +Vector& event_loop_stack() +{ + auto& the_stack = event_loop_stack_uninitialized(); + if (the_stack == nullptr) + the_stack = make>(); + return *the_stack; } } @@ -41,6 +46,12 @@ EventLoop::~EventLoop() } } +bool EventLoop::is_running() +{ + auto& stack = event_loop_stack_uninitialized(); + return stack != nullptr && !stack->is_empty(); +} + EventLoop& EventLoop::current() { return event_loop_stack().last(); diff --git a/Userland/Libraries/LibCore/EventLoop.h b/Userland/Libraries/LibCore/EventLoop.h index 4d09340091..e43a22838b 100644 --- a/Userland/Libraries/LibCore/EventLoop.h +++ b/Userland/Libraries/LibCore/EventLoop.h @@ -92,6 +92,7 @@ public: }; static void notify_forked(ForkEvent); + static bool is_running(); static EventLoop& current(); EventLoopImplementation& impl() { return *m_impl; } diff --git a/Userland/Libraries/LibCore/Forward.h b/Userland/Libraries/LibCore/Forward.h index 9272a542d7..f7c36ebd04 100644 --- a/Userland/Libraries/LibCore/Forward.h +++ b/Userland/Libraries/LibCore/Forward.h @@ -36,6 +36,8 @@ class ProcessStatisticsReader; class Socket; template class Promise; +template +class ThreadedPromise; class SocketAddress; class TCPServer; class TCPSocket; diff --git a/Userland/Libraries/LibCore/ThreadedPromise.h b/Userland/Libraries/LibCore/ThreadedPromise.h new file mode 100644 index 0000000000..fe2dc2a26f --- /dev/null +++ b/Userland/Libraries/LibCore/ThreadedPromise.h @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2021, Kyle Pereira + * Copyright (c) 2022, kleines Filmröllchen + * Copyright (c) 2021-2023, Ali Mohammad Pur + * Copyright (c) 2023, Gregory Bertilson + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace Core { + +template +class ThreadedPromise + : public AtomicRefCounted> { +public: + static NonnullRefPtr> create() + { + return adopt_ref(*new ThreadedPromise()); + } + + using ResultType = Conditional, Empty, TResult>; + using ErrorType = TError; + + void resolve(ResultType&& result) + { + when_error_handler_is_ready([self = NonnullRefPtr(*this), result = move(result)]() mutable { + if (self->m_resolution_handler) { + auto handler_result = self->m_resolution_handler(forward(result)); + if (handler_result.is_error()) + self->m_rejection_handler(handler_result.release_error()); + self->m_has_completed = true; + } + }); + } + void resolve() + requires IsSame + { + resolve(Empty()); + } + + void reject(ErrorType&& error) + { + when_error_handler_is_ready([this, error = move(error)]() mutable { + m_rejection_handler(forward(error)); + m_has_completed = true; + }); + } + void reject(ErrorType const& error) + requires IsTriviallyCopyable + { + reject(ErrorType(error)); + } + + bool has_completed() + { + Threading::MutexLocker locker { m_mutex }; + return m_has_completed; + } + + void await() + { + while (!has_completed()) + Core::EventLoop::current().pump(EventLoop::WaitMode::PollForEvents); + } + + // Set the callback to be called when the promise is resolved. A rejection callback + // must also be provided before any callback will be called. + template, ResultType&&> ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + Threading::MutexLocker locker { m_mutex }; + VERIFY(!m_resolution_handler); + m_resolution_handler = move(handler); + return *this; + } + + template ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&& result) -> ErrorOr { + handler(forward(result)); + return {}; + }); + } + + template> ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr { + return handler(); + }); + } + + template ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr { + handler(); + return {}; + }); + } + + // Set the callback to be called when the promise is rejected. Setting this callback + // will cause the promise fulfillment to be ready to be handled. + template RejectedHandler> + ThreadedPromise& when_rejected(RejectedHandler when_rejected = [](ErrorType&) {}) + { + Threading::MutexLocker locker { m_mutex }; + VERIFY(!m_rejection_handler); + m_rejection_handler = move(when_rejected); + return *this; + } + + template>, ResultType&&> ChainedResolution> + NonnullRefPtr> chain_promise(ChainedResolution chained_resolution) + { + auto new_promise = ThreadedPromise::create(); + when_resolved([=, chained_resolution = move(chained_resolution)](ResultType&& result) mutable -> ErrorOr { + chained_resolution(forward(result)) + ->when_resolved([=](auto&& new_result) { new_promise->resolve(move(new_result)); }) + .when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return {}; + }); + when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return new_promise; + } + + template, ResultType&&> MappingFunction> + NonnullRefPtr> map(MappingFunction mapping_function) + { + auto new_promise = ThreadedPromise::create(); + when_resolved([=, mapping_function = move(mapping_function)](ResultType&& result) -> ErrorOr { + new_promise->resolve(TRY(mapping_function(forward(result)))); + return {}; + }); + when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return new_promise; + } + +private: + template + static void deferred_handler_check(NonnullRefPtr self, F&& function) + { + Threading::MutexLocker locker { self->m_mutex }; + if (self->m_rejection_handler) { + function(); + return; + } + EventLoop::current().deferred_invoke([self, function = forward(function)]() mutable { + deferred_handler_check(self, move(function)); + }); + } + + template + void when_error_handler_is_ready(F function) + { + if (EventLoop::is_running()) { + deferred_handler_check(NonnullRefPtr(*this), move(function)); + } else { + // NOTE: Handlers should always be set almost immediately, so we can expect this + // to spin extremely briefly. Therefore, sleeping the thread should not be + // necessary. + while (true) { + Threading::MutexLocker locker { m_mutex }; + if (m_rejection_handler) + break; + } + VERIFY(m_rejection_handler); + function(); + } + } + + ThreadedPromise() = default; + ThreadedPromise(Object* parent) + : Object(parent) + { + } + + Function(ResultType&&)> m_resolution_handler; + Function m_rejection_handler; + Threading::Mutex m_mutex; + bool m_has_completed; +}; + +}