diff --git a/Tests/LibTLS/TestTLSHandshake.cpp b/Tests/LibTLS/TestTLSHandshake.cpp index 761098c8f9..9494f258e8 100644 --- a/Tests/LibTLS/TestTLSHandshake.cpp +++ b/Tests/LibTLS/TestTLSHandshake.cpp @@ -70,10 +70,11 @@ TEST_CASE(test_TLS_hello_handshake) tls->set_root_certificates(s_root_ca_certificates); bool sent_request = false; ByteBuffer contents; - tls->on_tls_ready_to_write = [&](TLS::TLSv12& tls) { + tls->set_on_tls_ready_to_write([&](TLS::TLSv12& tls) { if (sent_request) return; sent_request = true; + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); if (!tls.write("GET / HTTP/1.1\r\nHost: "_b)) { FAIL("write(0) failed"); loop.quit(0); @@ -87,7 +88,7 @@ TEST_CASE(test_TLS_hello_handshake) FAIL("write(2) failed"); loop.quit(0); } - }; + }); tls->on_tls_ready_to_read = [&](TLS::TLSv12& tls) { auto data = tls.read(); if (!data.has_value()) { diff --git a/Userland/Libraries/LibGemini/GeminiJob.cpp b/Userland/Libraries/LibGemini/GeminiJob.cpp index 340d470644..64ff2dbee2 100644 --- a/Userland/Libraries/LibGemini/GeminiJob.cpp +++ b/Userland/Libraries/LibGemini/GeminiJob.cpp @@ -93,9 +93,10 @@ void GeminiJob::register_on_ready_to_read(Function callback) void GeminiJob::register_on_ready_to_write(Function callback) { - m_socket->on_tls_ready_to_write = [callback = move(callback)](auto&) { + m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); callback(); - }; + }); } bool GeminiJob::can_read_line() const diff --git a/Userland/Libraries/LibHTTP/HttpsJob.cpp b/Userland/Libraries/LibHTTP/HttpsJob.cpp index 0f249fc5fb..abaeb662f6 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.cpp +++ b/Userland/Libraries/LibHTTP/HttpsJob.cpp @@ -68,6 +68,7 @@ void HttpsJob::shutdown() return; m_socket->on_tls_ready_to_read = nullptr; m_socket->on_tls_connected = nullptr; + m_socket->set_on_tls_ready_to_write(nullptr); m_socket = nullptr; } @@ -97,9 +98,10 @@ void HttpsJob::register_on_ready_to_read(Function callback) void HttpsJob::register_on_ready_to_write(Function callback) { - m_socket->on_tls_ready_to_write = [callback = move(callback)](auto&) { + m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); callback(); - }; + }); } bool HttpsJob::can_read_line() const diff --git a/Userland/Libraries/LibTLS/TLSv12.h b/Userland/Libraries/LibTLS/TLSv12.h index e1702c1a1c..1a84f0713e 100644 --- a/Userland/Libraries/LibTLS/TLSv12.h +++ b/Userland/Libraries/LibTLS/TLSv12.h @@ -373,8 +373,16 @@ public: bool can_read() const { return m_context.application_buffer.size() > 0; } String read_line(size_t max_size); + void set_on_tls_ready_to_write(Function function) + { + on_tls_ready_to_write = move(function); + if (on_tls_ready_to_write) { + if (is_established()) + on_tls_ready_to_write(*this); + } + } + Function on_tls_ready_to_read; - Function on_tls_ready_to_write; Function on_tls_error; Function on_tls_connected; Function on_tls_finished; @@ -521,6 +529,7 @@ private: i32 m_max_wait_time_for_handshake_in_seconds { 10 }; RefPtr m_handshake_timeout_timer; + Function on_tls_ready_to_write; }; } diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp index 370f77c4eb..b3313ca9e6 100644 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp +++ b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp @@ -33,9 +33,10 @@ void TLSv12WebSocketConnectionImpl::connect(ConnectionInfo const& connection) m_socket->on_tls_ready_to_read = [this](auto&) { on_ready_to_read(); }; - m_socket->on_tls_ready_to_write = [this](auto&) { + m_socket->set_on_tls_ready_to_write([this](auto& tls) { + tls.set_on_tls_ready_to_write(nullptr); on_connected(); - }; + }); m_socket->on_tls_finished = [this] { on_connection_error(); }; diff --git a/Userland/Utilities/test-crypto.cpp b/Userland/Utilities/test-crypto.cpp index 730ff499aa..1d94219d00 100644 --- a/Userland/Utilities/test-crypto.cpp +++ b/Userland/Utilities/test-crypto.cpp @@ -152,12 +152,12 @@ static void tls(const char* message, size_t len) if (buffer.has_value()) out("{}", StringView { buffer->data(), buffer->size() }); }; - tls->on_tls_ready_to_write = [&](auto&) { + tls->set_on_tls_ready_to_write([&](auto&) { if (write.size()) { tls->write(write); write.clear(); } - }; + }); tls->on_tls_error = [&](auto) { g_loop.quit(1); }; @@ -2013,10 +2013,11 @@ static void tls_test_client_hello() tls->set_root_certificates(s_root_ca_certificates); bool sent_request = false; ByteBuffer contents; - tls->on_tls_ready_to_write = [&](TLS::TLSv12& tls) { + tls->set_on_tls_ready_to_write([&](TLS::TLSv12& tls) { if (sent_request) return; sent_request = true; + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); if (!tls.write("GET / HTTP/1.1\r\nHost: "_b)) { FAIL(write(0) failed); loop.quit(0); @@ -2030,7 +2031,7 @@ static void tls_test_client_hello() FAIL(write(2) failed); loop.quit(0); } - }; + }); tls->on_tls_ready_to_read = [&](TLS::TLSv12& tls) { auto data = tls.read(); if (!data.has_value()) {