diff --git a/Userland/Libraries/LibCore/NetworkJob.cpp b/Userland/Libraries/LibCore/NetworkJob.cpp index c79f3085b4..5742955281 100644 --- a/Userland/Libraries/LibCore/NetworkJob.cpp +++ b/Userland/Libraries/LibCore/NetworkJob.cpp @@ -23,7 +23,7 @@ void NetworkJob::start(NonnullRefPtr) { } -void NetworkJob::shutdown() +void NetworkJob::shutdown(ShutdownMode) { } @@ -40,7 +40,7 @@ void NetworkJob::did_finish(NonnullRefPtr&& response) dbgln_if(CNETWORKJOB_DEBUG, "{} job did_finish", *this); VERIFY(on_finish); on_finish(true); - shutdown(); + shutdown(ShutdownMode::DetachFromSocket); } void NetworkJob::did_fail(Error error) @@ -56,7 +56,7 @@ void NetworkJob::did_fail(Error error) dbgln_if(CNETWORKJOB_DEBUG, "{}{{{:p}}} job did_fail! error: {} ({})", class_name(), this, (unsigned)error, to_string(error)); VERIFY(on_finish); on_finish(false); - shutdown(); + shutdown(ShutdownMode::DetachFromSocket); } void NetworkJob::did_progress(Optional total_size, u32 downloaded) diff --git a/Userland/Libraries/LibCore/NetworkJob.h b/Userland/Libraries/LibCore/NetworkJob.h index 781dc825a8..e2eda0b407 100644 --- a/Userland/Libraries/LibCore/NetworkJob.h +++ b/Userland/Libraries/LibCore/NetworkJob.h @@ -35,12 +35,16 @@ public: NetworkResponse* response() { return m_response.ptr(); } const NetworkResponse* response() const { return m_response.ptr(); } + enum class ShutdownMode { + DetachFromSocket, + CloseSocket, + }; virtual void start(NonnullRefPtr) = 0; - virtual void shutdown() = 0; + virtual void shutdown(ShutdownMode) = 0; void cancel() { - shutdown(); + shutdown(ShutdownMode::DetachFromSocket); m_error = Error::Cancelled; } diff --git a/Userland/Libraries/LibCore/Socket.cpp b/Userland/Libraries/LibCore/Socket.cpp index 11dae78527..b21c487041 100644 --- a/Userland/Libraries/LibCore/Socket.cpp +++ b/Userland/Libraries/LibCore/Socket.cpp @@ -208,6 +208,16 @@ void Socket::did_update_fd(int fd) } } +bool Socket::close() +{ + m_connected = false; + if (m_notifier) + m_notifier->close(); + if (m_read_notifier) + m_read_notifier->close(); + return IODevice::close(); +} + void Socket::ensure_read_notifier() { VERIFY(m_connected); diff --git a/Userland/Libraries/LibCore/Socket.h b/Userland/Libraries/LibCore/Socket.h index 359e8a56c8..3ccbae20e2 100644 --- a/Userland/Libraries/LibCore/Socket.h +++ b/Userland/Libraries/LibCore/Socket.h @@ -42,6 +42,8 @@ public: SocketAddress destination_address() const { return m_destination_address; } int destination_port() const { return m_destination_port; } + virtual bool close() override; + Function on_connected; Function on_error; Function on_ready_to_read; diff --git a/Userland/Libraries/LibGemini/GeminiJob.cpp b/Userland/Libraries/LibGemini/GeminiJob.cpp index 64ff2dbee2..b459d81afd 100644 --- a/Userland/Libraries/LibGemini/GeminiJob.cpp +++ b/Userland/Libraries/LibGemini/GeminiJob.cpp @@ -58,13 +58,17 @@ void GeminiJob::start(NonnullRefPtr socket) } } -void GeminiJob::shutdown() +void GeminiJob::shutdown(ShutdownMode mode) { if (!m_socket) return; - m_socket->on_tls_ready_to_read = nullptr; - m_socket->on_tls_connected = nullptr; - m_socket = nullptr; + if (mode == ShutdownMode::CloseSocket) { + m_socket->close(); + } else { + m_socket->on_tls_ready_to_read = nullptr; + m_socket->on_tls_connected = nullptr; + m_socket = nullptr; + } } void GeminiJob::read_while_data_available(Function read) diff --git a/Userland/Libraries/LibGemini/GeminiJob.h b/Userland/Libraries/LibGemini/GeminiJob.h index 20cc8afcf9..7cad08473d 100644 --- a/Userland/Libraries/LibGemini/GeminiJob.h +++ b/Userland/Libraries/LibGemini/GeminiJob.h @@ -28,7 +28,7 @@ public: } virtual void start(NonnullRefPtr) override; - virtual void shutdown() override; + virtual void shutdown(ShutdownMode) override; void set_certificate(String certificate, String key); Core::Socket const* socket() const { return m_socket; } diff --git a/Userland/Libraries/LibGemini/Job.h b/Userland/Libraries/LibGemini/Job.h index 5bc50216f7..d8e2167385 100644 --- a/Userland/Libraries/LibGemini/Job.h +++ b/Userland/Libraries/LibGemini/Job.h @@ -20,7 +20,7 @@ public: virtual ~Job() override; virtual void start(NonnullRefPtr) override = 0; - virtual void shutdown() override = 0; + virtual void shutdown(ShutdownMode) override = 0; GeminiResponse* response() { return static_cast(Core::NetworkJob::response()); } const GeminiResponse* response() const { return static_cast(Core::NetworkJob::response()); } diff --git a/Userland/Libraries/LibHTTP/HttpJob.cpp b/Userland/Libraries/LibHTTP/HttpJob.cpp index 58ab0e28d6..6e0115c341 100644 --- a/Userland/Libraries/LibHTTP/HttpJob.cpp +++ b/Userland/Libraries/LibHTTP/HttpJob.cpp @@ -43,13 +43,17 @@ void HttpJob::start(NonnullRefPtr socket) }; } -void HttpJob::shutdown() +void HttpJob::shutdown(ShutdownMode mode) { if (!m_socket) return; - m_socket->on_ready_to_read = nullptr; - m_socket->on_connected = nullptr; - m_socket = nullptr; + if (mode == ShutdownMode::CloseSocket) { + m_socket->close(); + } else { + m_socket->on_ready_to_read = nullptr; + m_socket->on_connected = nullptr; + m_socket = nullptr; + } } void HttpJob::register_on_ready_to_read(Function callback) diff --git a/Userland/Libraries/LibHTTP/HttpJob.h b/Userland/Libraries/LibHTTP/HttpJob.h index 16ccb3abfb..08acf9ef79 100644 --- a/Userland/Libraries/LibHTTP/HttpJob.h +++ b/Userland/Libraries/LibHTTP/HttpJob.h @@ -28,7 +28,7 @@ public: } virtual void start(NonnullRefPtr) override; - virtual void shutdown() override; + virtual void shutdown(ShutdownMode) override; Core::Socket const* socket() const { return m_socket; } URL url() const { return m_request.url(); } diff --git a/Userland/Libraries/LibHTTP/HttpsJob.cpp b/Userland/Libraries/LibHTTP/HttpsJob.cpp index abaeb662f6..cadfbea933 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.cpp +++ b/Userland/Libraries/LibHTTP/HttpsJob.cpp @@ -62,14 +62,18 @@ void HttpsJob::start(NonnullRefPtr socket) } } -void HttpsJob::shutdown() +void HttpsJob::shutdown(ShutdownMode mode) { if (!m_socket) return; - m_socket->on_tls_ready_to_read = nullptr; - m_socket->on_tls_connected = nullptr; - m_socket->set_on_tls_ready_to_write(nullptr); - m_socket = nullptr; + if (mode == ShutdownMode::CloseSocket) { + m_socket->close(); + } else { + m_socket->on_tls_ready_to_read = nullptr; + m_socket->on_tls_connected = nullptr; + m_socket->set_on_tls_ready_to_write(nullptr); + m_socket = nullptr; + } } void HttpsJob::set_certificate(String certificate, String private_key) diff --git a/Userland/Libraries/LibHTTP/HttpsJob.h b/Userland/Libraries/LibHTTP/HttpsJob.h index c7c20ea7e6..149861727a 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.h +++ b/Userland/Libraries/LibHTTP/HttpsJob.h @@ -29,7 +29,7 @@ public: } virtual void start(NonnullRefPtr) override; - virtual void shutdown() override; + virtual void shutdown(ShutdownMode) override; void set_certificate(String certificate, String key); Core::Socket const* socket() const { return m_socket; } diff --git a/Userland/Libraries/LibHTTP/Job.cpp b/Userland/Libraries/LibHTTP/Job.cpp index 97c4b317bd..8141a3e869 100644 --- a/Userland/Libraries/LibHTTP/Job.cpp +++ b/Userland/Libraries/LibHTTP/Job.cpp @@ -412,6 +412,10 @@ void Job::finish_up() m_has_scheduled_finish = true; auto response = HttpResponse::create(m_code, move(m_headers)); deferred_invoke([this, response = move(response)] { + // If the server responded with "Connection: close", close the connection + // as the server may or may not want to close the socket. + if (auto result = response->headers().get("Connection"sv); result.has_value() && result.value().equals_ignoring_case("close"sv)) + shutdown(ShutdownMode::CloseSocket); did_finish(response); }); } diff --git a/Userland/Libraries/LibHTTP/Job.h b/Userland/Libraries/LibHTTP/Job.h index 99a6e0607f..7eb85154df 100644 --- a/Userland/Libraries/LibHTTP/Job.h +++ b/Userland/Libraries/LibHTTP/Job.h @@ -22,7 +22,7 @@ public: virtual ~Job() override; virtual void start(NonnullRefPtr) override = 0; - virtual void shutdown() override = 0; + virtual void shutdown(ShutdownMode) override = 0; HttpResponse* response() { return static_cast(Core::NetworkJob::response()); } const HttpResponse* response() const { return static_cast(Core::NetworkJob::response()); }