From 436693c0c907b9a15bd6cccc07304e6332bc7b7b Mon Sep 17 00:00:00 2001 From: Ali Mohammad Pur Date: Sun, 19 Sep 2021 02:31:11 +0430 Subject: [PATCH] LibTLS: Use a setter for on_tls_ready_to_write with some more smarts The callback should be called as soon as the connection is established, and if we actually set the callback when it already is, we expect it to be called immediately. --- Tests/LibTLS/TestTLSHandshake.cpp | 5 +++-- Userland/Libraries/LibGemini/GeminiJob.cpp | 5 +++-- Userland/Libraries/LibHTTP/HttpsJob.cpp | 6 ++++-- Userland/Libraries/LibTLS/TLSv12.h | 11 ++++++++++- .../Impl/TLSv12WebSocketConnectionImpl.cpp | 5 +++-- Userland/Utilities/test-crypto.cpp | 9 +++++---- 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/Tests/LibTLS/TestTLSHandshake.cpp b/Tests/LibTLS/TestTLSHandshake.cpp index 761098c8f9..9494f258e8 100644 --- a/Tests/LibTLS/TestTLSHandshake.cpp +++ b/Tests/LibTLS/TestTLSHandshake.cpp @@ -70,10 +70,11 @@ TEST_CASE(test_TLS_hello_handshake) tls->set_root_certificates(s_root_ca_certificates); bool sent_request = false; ByteBuffer contents; - tls->on_tls_ready_to_write = [&](TLS::TLSv12& tls) { + tls->set_on_tls_ready_to_write([&](TLS::TLSv12& tls) { if (sent_request) return; sent_request = true; + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); if (!tls.write("GET / HTTP/1.1\r\nHost: "_b)) { FAIL("write(0) failed"); loop.quit(0); @@ -87,7 +88,7 @@ TEST_CASE(test_TLS_hello_handshake) FAIL("write(2) failed"); loop.quit(0); } - }; + }); tls->on_tls_ready_to_read = [&](TLS::TLSv12& tls) { auto data = tls.read(); if (!data.has_value()) { diff --git a/Userland/Libraries/LibGemini/GeminiJob.cpp b/Userland/Libraries/LibGemini/GeminiJob.cpp index 340d470644..64ff2dbee2 100644 --- a/Userland/Libraries/LibGemini/GeminiJob.cpp +++ b/Userland/Libraries/LibGemini/GeminiJob.cpp @@ -93,9 +93,10 @@ void GeminiJob::register_on_ready_to_read(Function callback) void GeminiJob::register_on_ready_to_write(Function callback) { - m_socket->on_tls_ready_to_write = [callback = move(callback)](auto&) { + m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); callback(); - }; + }); } bool GeminiJob::can_read_line() const diff --git a/Userland/Libraries/LibHTTP/HttpsJob.cpp b/Userland/Libraries/LibHTTP/HttpsJob.cpp index 0f249fc5fb..abaeb662f6 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.cpp +++ b/Userland/Libraries/LibHTTP/HttpsJob.cpp @@ -68,6 +68,7 @@ void HttpsJob::shutdown() return; m_socket->on_tls_ready_to_read = nullptr; m_socket->on_tls_connected = nullptr; + m_socket->set_on_tls_ready_to_write(nullptr); m_socket = nullptr; } @@ -97,9 +98,10 @@ void HttpsJob::register_on_ready_to_read(Function callback) void HttpsJob::register_on_ready_to_write(Function callback) { - m_socket->on_tls_ready_to_write = [callback = move(callback)](auto&) { + m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); callback(); - }; + }); } bool HttpsJob::can_read_line() const diff --git a/Userland/Libraries/LibTLS/TLSv12.h b/Userland/Libraries/LibTLS/TLSv12.h index e1702c1a1c..1a84f0713e 100644 --- a/Userland/Libraries/LibTLS/TLSv12.h +++ b/Userland/Libraries/LibTLS/TLSv12.h @@ -373,8 +373,16 @@ public: bool can_read() const { return m_context.application_buffer.size() > 0; } String read_line(size_t max_size); + void set_on_tls_ready_to_write(Function function) + { + on_tls_ready_to_write = move(function); + if (on_tls_ready_to_write) { + if (is_established()) + on_tls_ready_to_write(*this); + } + } + Function on_tls_ready_to_read; - Function on_tls_ready_to_write; Function on_tls_error; Function on_tls_connected; Function on_tls_finished; @@ -521,6 +529,7 @@ private: i32 m_max_wait_time_for_handshake_in_seconds { 10 }; RefPtr m_handshake_timeout_timer; + Function on_tls_ready_to_write; }; } diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp index 370f77c4eb..b3313ca9e6 100644 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp +++ b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp @@ -33,9 +33,10 @@ void TLSv12WebSocketConnectionImpl::connect(ConnectionInfo const& connection) m_socket->on_tls_ready_to_read = [this](auto&) { on_ready_to_read(); }; - m_socket->on_tls_ready_to_write = [this](auto&) { + m_socket->set_on_tls_ready_to_write([this](auto& tls) { + tls.set_on_tls_ready_to_write(nullptr); on_connected(); - }; + }); m_socket->on_tls_finished = [this] { on_connection_error(); }; diff --git a/Userland/Utilities/test-crypto.cpp b/Userland/Utilities/test-crypto.cpp index 730ff499aa..1d94219d00 100644 --- a/Userland/Utilities/test-crypto.cpp +++ b/Userland/Utilities/test-crypto.cpp @@ -152,12 +152,12 @@ static void tls(const char* message, size_t len) if (buffer.has_value()) out("{}", StringView { buffer->data(), buffer->size() }); }; - tls->on_tls_ready_to_write = [&](auto&) { + tls->set_on_tls_ready_to_write([&](auto&) { if (write.size()) { tls->write(write); write.clear(); } - }; + }); tls->on_tls_error = [&](auto) { g_loop.quit(1); }; @@ -2013,10 +2013,11 @@ static void tls_test_client_hello() tls->set_root_certificates(s_root_ca_certificates); bool sent_request = false; ByteBuffer contents; - tls->on_tls_ready_to_write = [&](TLS::TLSv12& tls) { + tls->set_on_tls_ready_to_write([&](TLS::TLSv12& tls) { if (sent_request) return; sent_request = true; + Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); if (!tls.write("GET / HTTP/1.1\r\nHost: "_b)) { FAIL(write(0) failed); loop.quit(0); @@ -2030,7 +2031,7 @@ static void tls_test_client_hello() FAIL(write(2) failed); loop.quit(0); } - }; + }); tls->on_tls_ready_to_read = [&](TLS::TLSv12& tls) { auto data = tls.read(); if (!data.has_value()) {