mirror of
				https://github.com/RGBCube/serenity
				synced 2025-10-31 19:22:45 +00:00 
			
		
		
		
	ProtocolServer: Stream the downloaded data if possible
This patchset makes ProtocolServer stream the downloads to its client
(LibProtocol), and as such changes the download API; a possible
download lifecycle could be as such:
notation = client->server:'>', server->client:'<', pipe activity:'*'
```
> StartDownload(GET, url, headers, {})
< Response(0, fd 8)
* {data, 1024b}
< HeadersBecameAvailable(0, response_headers, 200)
< DownloadProgress(0, 4K, 1024)
* {data, 1024b}
* {data, 1024b}
< DownloadProgress(0, 4K, 2048)
* {data, 1024b}
< DownloadProgress(0, 4K, 1024)
< DownloadFinished(0, true, 4K)
```
Since managing the received file descriptor is a pain, LibProtocol
implements `Download::stream_into(OutputStream)`, which can be used to
stream the download into any given output stream (be it a file, or
memory, or writing stuff with a delay, etc.).
Also, as some of the users of this API require all the downloaded data
upfront, LibProtocol also implements `set_should_buffer_all_input()`,
which causes the download instance to buffer all the data until the
download is complete, and to call the `on_buffered_download_finish`
hook.
			
			
This commit is contained in:
		
							parent
							
								
									36d642ee75
								
							
						
					
					
						commit
						4a2da10e38
					
				
					 55 changed files with 528 additions and 235 deletions
				
			
		|  | @ -29,6 +29,7 @@ | ||||||
| #include <AK/SharedBuffer.h> | #include <AK/SharedBuffer.h> | ||||||
| #include <AK/StringBuilder.h> | #include <AK/StringBuilder.h> | ||||||
| #include <LibCore/File.h> | #include <LibCore/File.h> | ||||||
|  | #include <LibCore/FileStream.h> | ||||||
| #include <LibCore/StandardPaths.h> | #include <LibCore/StandardPaths.h> | ||||||
| #include <LibDesktop/Launcher.h> | #include <LibDesktop/Launcher.h> | ||||||
| #include <LibGUI/BoxLayout.h> | #include <LibGUI/BoxLayout.h> | ||||||
|  | @ -61,9 +62,19 @@ DownloadWidget::DownloadWidget(const URL& url) | ||||||
|     m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) { |     m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) { | ||||||
|         did_progress(total_size.value(), downloaded_size); |         did_progress(total_size.value(), downloaded_size); | ||||||
|     }; |     }; | ||||||
|     m_download->on_finish = [this](bool success, auto payload, auto payload_storage, auto& response_headers, auto) { | 
 | ||||||
|         did_finish(success, payload, payload_storage, response_headers); |     { | ||||||
|     }; |         auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly); | ||||||
|  |         if (file_or_error.is_error()) { | ||||||
|  |             GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error); | ||||||
|  |             window()->close(); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         m_output_file_stream = make<Core::OutputFileStream>(*file_or_error.value()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     m_download->on_finish = [this](bool success, auto) { did_finish(success); }; | ||||||
|  |     m_download->stream_into(*m_output_file_stream); | ||||||
| 
 | 
 | ||||||
|     set_fill_with_background_color(true); |     set_fill_with_background_color(true); | ||||||
|     auto& layout = set_layout<GUI::VerticalBoxLayout>(); |     auto& layout = set_layout<GUI::VerticalBoxLayout>(); | ||||||
|  | @ -149,7 +160,7 @@ void DownloadWidget::did_progress(Optional<u32> total_size, u32 downloaded_size) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes payload, [[maybe_unused]] RefPtr<SharedBuffer> payload_storage, [[maybe_unused]] const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) | void DownloadWidget::did_finish(bool success) | ||||||
| { | { | ||||||
|     dbg() << "did_finish, success=" << success; |     dbg() << "did_finish, success=" << success; | ||||||
| 
 | 
 | ||||||
|  | @ -166,17 +177,6 @@ void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes pay | ||||||
|         window()->close(); |         window()->close(); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
| 
 |  | ||||||
|     auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly); |  | ||||||
|     if (file_or_error.is_error()) { |  | ||||||
|         GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error); |  | ||||||
|         window()->close(); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     auto& file = *file_or_error.value(); |  | ||||||
|     bool write_success = file.write(payload.data(), payload.size()); |  | ||||||
|     ASSERT(write_success); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -28,6 +28,7 @@ | ||||||
| 
 | 
 | ||||||
| #include <AK/URL.h> | #include <AK/URL.h> | ||||||
| #include <LibCore/ElapsedTimer.h> | #include <LibCore/ElapsedTimer.h> | ||||||
|  | #include <LibCore/FileStream.h> | ||||||
| #include <LibGUI/ProgressBar.h> | #include <LibGUI/ProgressBar.h> | ||||||
| #include <LibGUI/Widget.h> | #include <LibGUI/Widget.h> | ||||||
| #include <LibProtocol/Download.h> | #include <LibProtocol/Download.h> | ||||||
|  | @ -44,7 +45,7 @@ private: | ||||||
|     explicit DownloadWidget(const URL&); |     explicit DownloadWidget(const URL&); | ||||||
| 
 | 
 | ||||||
|     void did_progress(Optional<u32> total_size, u32 downloaded_size); |     void did_progress(Optional<u32> total_size, u32 downloaded_size); | ||||||
|     void did_finish(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers); |     void did_finish(bool success); | ||||||
| 
 | 
 | ||||||
|     URL m_url; |     URL m_url; | ||||||
|     String m_destination_path; |     String m_destination_path; | ||||||
|  | @ -53,6 +54,7 @@ private: | ||||||
|     RefPtr<GUI::Label> m_progress_label; |     RefPtr<GUI::Label> m_progress_label; | ||||||
|     RefPtr<GUI::Button> m_cancel_button; |     RefPtr<GUI::Button> m_cancel_button; | ||||||
|     RefPtr<GUI::Button> m_close_button; |     RefPtr<GUI::Button> m_close_button; | ||||||
|  |     OwnPtr<Core::OutputFileStream> m_output_file_stream; | ||||||
|     Core::ElapsedTimer m_elapsed_timer; |     Core::ElapsedTimer m_elapsed_timer; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -68,7 +68,7 @@ int main(int argc, char** argv) | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr", nullptr) < 0) { |     if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr sendfd recvfd", nullptr) < 0) { | ||||||
|         perror("pledge"); |         perror("pledge"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  | @ -86,7 +86,7 @@ int main(int argc, char** argv) | ||||||
|     Web::ResourceLoader::the(); |     Web::ResourceLoader::the(); | ||||||
| 
 | 
 | ||||||
|     // FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge.
 |     // FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge.
 | ||||||
|     if (pledge("stdio shared_buffer accept unix cpath rpath wpath", nullptr) < 0) { |     if (pledge("stdio shared_buffer accept unix cpath rpath wpath sendfd recvfd", nullptr) < 0) { | ||||||
|         perror("pledge"); |         perror("pledge"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -32,7 +32,8 @@ | ||||||
| 
 | 
 | ||||||
| namespace Core { | namespace Core { | ||||||
| 
 | 
 | ||||||
| NetworkJob::NetworkJob() | NetworkJob::NetworkJob(OutputStream& output_stream) | ||||||
|  |     : m_output_stream(output_stream) | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -27,6 +27,7 @@ | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
| #include <AK/Function.h> | #include <AK/Function.h> | ||||||
|  | #include <AK/Stream.h> | ||||||
| #include <LibCore/Object.h> | #include <LibCore/Object.h> | ||||||
| 
 | 
 | ||||||
| namespace Core { | namespace Core { | ||||||
|  | @ -43,6 +44,8 @@ public: | ||||||
|     }; |     }; | ||||||
|     virtual ~NetworkJob() override; |     virtual ~NetworkJob() override; | ||||||
| 
 | 
 | ||||||
|  |     // Could fire twice, after Headers and after Trailers!
 | ||||||
|  |     Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received; | ||||||
|     Function<void(bool success)> on_finish; |     Function<void(bool success)> on_finish; | ||||||
|     Function<void(Optional<u32>, u32)> on_progress; |     Function<void(Optional<u32>, u32)> on_progress; | ||||||
| 
 | 
 | ||||||
|  | @ -62,13 +65,16 @@ public: | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| protected: | protected: | ||||||
|     NetworkJob(); |     NetworkJob(OutputStream&); | ||||||
|     void did_finish(NonnullRefPtr<NetworkResponse>&&); |     void did_finish(NonnullRefPtr<NetworkResponse>&&); | ||||||
|     void did_fail(Error); |     void did_fail(Error); | ||||||
|     void did_progress(Optional<u32> total_size, u32 downloaded); |     void did_progress(Optional<u32> total_size, u32 downloaded); | ||||||
| 
 | 
 | ||||||
|  |     size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); } | ||||||
|  | 
 | ||||||
| private: | private: | ||||||
|     RefPtr<NetworkResponse> m_response; |     RefPtr<NetworkResponse> m_response; | ||||||
|  |     OutputStream& m_output_stream; | ||||||
|     Error m_error { Error::None }; |     Error m_error { Error::None }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -28,8 +28,7 @@ | ||||||
| 
 | 
 | ||||||
| namespace Core { | namespace Core { | ||||||
| 
 | 
 | ||||||
| NetworkResponse::NetworkResponse(ByteBuffer&& payload) | NetworkResponse::NetworkResponse() | ||||||
|     : m_payload(payload) |  | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -36,13 +36,11 @@ public: | ||||||
|     virtual ~NetworkResponse(); |     virtual ~NetworkResponse(); | ||||||
| 
 | 
 | ||||||
|     bool is_error() const { return m_error; } |     bool is_error() const { return m_error; } | ||||||
|     const ByteBuffer& payload() const { return m_payload; } |  | ||||||
| 
 | 
 | ||||||
| protected: | protected: | ||||||
|     explicit NetworkResponse(ByteBuffer&&); |     explicit NetworkResponse(); | ||||||
| 
 | 
 | ||||||
|     bool m_error { false }; |     bool m_error { false }; | ||||||
|     ByteBuffer m_payload; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -142,9 +142,9 @@ bool GeminiJob::eof() const | ||||||
|     return m_socket->eof(); |     return m_socket->eof(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool GeminiJob::write(const ByteBuffer& data) | bool GeminiJob::write(ReadonlyBytes bytes) | ||||||
| { | { | ||||||
|     return m_socket->write(data); |     return m_socket->write(bytes); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -37,8 +37,8 @@ namespace Gemini { | ||||||
| class GeminiJob final : public Job { | class GeminiJob final : public Job { | ||||||
|     C_OBJECT(GeminiJob) |     C_OBJECT(GeminiJob) | ||||||
| public: | public: | ||||||
|     explicit GeminiJob(const GeminiRequest& request, const Vector<Certificate>* override_certificates = nullptr) |     explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certificates = nullptr) | ||||||
|         : Job(request) |         : Job(request, output_stream) | ||||||
|         , m_override_ca_certificates(override_certificates) |         , m_override_ca_certificates(override_certificates) | ||||||
|     { |     { | ||||||
|     } |     } | ||||||
|  | @ -61,7 +61,7 @@ protected: | ||||||
|     virtual bool can_read() const override; |     virtual bool can_read() const override; | ||||||
|     virtual ByteBuffer receive(size_t) override; |     virtual ByteBuffer receive(size_t) override; | ||||||
|     virtual bool eof() const override; |     virtual bool eof() const override; | ||||||
|     virtual bool write(const ByteBuffer&) override; |     virtual bool write(ReadonlyBytes) override; | ||||||
|     virtual bool is_established() const override { return m_socket->is_established(); } |     virtual bool is_established() const override { return m_socket->is_established(); } | ||||||
|     virtual bool should_fail_on_empty_payload() const override { return false; } |     virtual bool should_fail_on_empty_payload() const override { return false; } | ||||||
|     virtual void read_while_data_available(Function<IterationDecision()>) override; |     virtual void read_while_data_available(Function<IterationDecision()>) override; | ||||||
|  |  | ||||||
|  | @ -28,9 +28,8 @@ | ||||||
| 
 | 
 | ||||||
| namespace Gemini { | namespace Gemini { | ||||||
| 
 | 
 | ||||||
| GeminiResponse::GeminiResponse(int status, String meta, ByteBuffer&& payload) | GeminiResponse::GeminiResponse(int status, String meta) | ||||||
|     : Core::NetworkResponse(move(payload)) |     : m_status(status) | ||||||
|     , m_status(status) |  | ||||||
|     , m_meta(meta) |     , m_meta(meta) | ||||||
| { | { | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -34,16 +34,16 @@ namespace Gemini { | ||||||
| class GeminiResponse : public Core::NetworkResponse { | class GeminiResponse : public Core::NetworkResponse { | ||||||
| public: | public: | ||||||
|     virtual ~GeminiResponse() override; |     virtual ~GeminiResponse() override; | ||||||
|     static NonnullRefPtr<GeminiResponse> create(int status, String meta, ByteBuffer&& payload) |     static NonnullRefPtr<GeminiResponse> create(int status, String meta) | ||||||
|     { |     { | ||||||
|         return adopt(*new GeminiResponse(status, meta, move(payload))); |         return adopt(*new GeminiResponse(status, meta)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int status() const { return m_status; } |     int status() const { return m_status; } | ||||||
|     String meta() const { return m_meta; } |     String meta() const { return m_meta; } | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     GeminiResponse(int status, String, ByteBuffer&&); |     GeminiResponse(int status, String); | ||||||
| 
 | 
 | ||||||
|     int m_status { 0 }; |     int m_status { 0 }; | ||||||
|     String m_meta; |     String m_meta; | ||||||
|  |  | ||||||
|  | @ -33,8 +33,9 @@ | ||||||
| 
 | 
 | ||||||
| namespace Gemini { | namespace Gemini { | ||||||
| 
 | 
 | ||||||
| Job::Job(const GeminiRequest& request) | Job::Job(const GeminiRequest& request, OutputStream& output_stream) | ||||||
|     : m_request(request) |     : Core::NetworkJob(output_stream) | ||||||
|  |     , m_request(request) | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -42,6 +43,23 @@ Job::~Job() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void Job::flush_received_buffers() | ||||||
|  | { | ||||||
|  |     for (size_t i = 0; i < m_received_buffers.size(); ++i) { | ||||||
|  |         auto& payload = m_received_buffers[i]; | ||||||
|  |         auto written = do_write(payload); | ||||||
|  |         m_received_size -= written; | ||||||
|  |         if (written == payload.size()) { | ||||||
|  |             // FIXME: Make this a take-first-friendly object?
 | ||||||
|  |             m_received_buffers.take_first(); | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  |         ASSERT(written < payload.size()); | ||||||
|  |         payload = payload.slice(written, payload.size() - written); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void Job::on_socket_connected() | void Job::on_socket_connected() | ||||||
| { | { | ||||||
|     register_on_ready_to_write([this] { |     register_on_ready_to_write([this] { | ||||||
|  | @ -126,6 +144,7 @@ void Job::on_socket_connected() | ||||||
| 
 | 
 | ||||||
|             m_received_buffers.append(payload); |             m_received_buffers.append(payload); | ||||||
|             m_received_size += payload.size(); |             m_received_size += payload.size(); | ||||||
|  |             flush_received_buffers(); | ||||||
| 
 | 
 | ||||||
|             deferred_invoke([this](auto&) { did_progress({}, m_received_size); }); |             deferred_invoke([this](auto&) { did_progress({}, m_received_size); }); | ||||||
| 
 | 
 | ||||||
|  | @ -144,15 +163,17 @@ void Job::on_socket_connected() | ||||||
| void Job::finish_up() | void Job::finish_up() | ||||||
| { | { | ||||||
|     m_state = State::Finished; |     m_state = State::Finished; | ||||||
|     auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); |     flush_received_buffers(); | ||||||
|     u8* flat_ptr = flattened_buffer.data(); |     if (m_received_size != 0) { | ||||||
|     for (auto& received_buffer : m_received_buffers) { |         // FIXME: What do we do? ignore it?
 | ||||||
|         memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); |         //        "Transmission failed" is not strictly correct, but let's roll with it for now.
 | ||||||
|         flat_ptr += received_buffer.size(); |         deferred_invoke([this](auto&) { | ||||||
|  |             did_fail(Error::TransmissionFailed); | ||||||
|  |         }); | ||||||
|  |         return; | ||||||
|     } |     } | ||||||
|     m_received_buffers.clear(); |  | ||||||
| 
 | 
 | ||||||
|     auto response = GeminiResponse::create(m_status, m_meta, move(flattened_buffer)); |     auto response = GeminiResponse::create(m_status, m_meta); | ||||||
|     deferred_invoke([this, response](auto&) { |     deferred_invoke([this, response](auto&) { | ||||||
|         did_finish(move(response)); |         did_finish(move(response)); | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|  | @ -36,7 +36,7 @@ namespace Gemini { | ||||||
| 
 | 
 | ||||||
| class Job : public Core::NetworkJob { | class Job : public Core::NetworkJob { | ||||||
| public: | public: | ||||||
|     explicit Job(const GeminiRequest&); |     explicit Job(const GeminiRequest&, OutputStream&); | ||||||
|     virtual ~Job() override; |     virtual ~Job() override; | ||||||
| 
 | 
 | ||||||
|     virtual void start() override = 0; |     virtual void start() override = 0; | ||||||
|  | @ -48,6 +48,7 @@ public: | ||||||
| protected: | protected: | ||||||
|     void finish_up(); |     void finish_up(); | ||||||
|     void on_socket_connected(); |     void on_socket_connected(); | ||||||
|  |     void flush_received_buffers(); | ||||||
|     virtual void register_on_ready_to_read(Function<void()>) = 0; |     virtual void register_on_ready_to_read(Function<void()>) = 0; | ||||||
|     virtual void register_on_ready_to_write(Function<void()>) = 0; |     virtual void register_on_ready_to_write(Function<void()>) = 0; | ||||||
|     virtual bool can_read_line() const = 0; |     virtual bool can_read_line() const = 0; | ||||||
|  | @ -55,7 +56,7 @@ protected: | ||||||
|     virtual bool can_read() const = 0; |     virtual bool can_read() const = 0; | ||||||
|     virtual ByteBuffer receive(size_t) = 0; |     virtual ByteBuffer receive(size_t) = 0; | ||||||
|     virtual bool eof() const = 0; |     virtual bool eof() const = 0; | ||||||
|     virtual bool write(const ByteBuffer&) = 0; |     virtual bool write(ReadonlyBytes) = 0; | ||||||
|     virtual bool is_established() const = 0; |     virtual bool is_established() const = 0; | ||||||
|     virtual bool should_fail_on_empty_payload() const { return false; } |     virtual bool should_fail_on_empty_payload() const { return false; } | ||||||
|     virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; |     virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; | ||||||
|  | @ -70,7 +71,7 @@ protected: | ||||||
|     State m_state { State::InStatus }; |     State m_state { State::InStatus }; | ||||||
|     int m_status { -1 }; |     int m_status { -1 }; | ||||||
|     String m_meta; |     String m_meta; | ||||||
|     Vector<ByteBuffer> m_received_buffers; |     Vector<ByteBuffer, 2> m_received_buffers; | ||||||
|     size_t m_received_size { 0 }; |     size_t m_received_size { 0 }; | ||||||
|     bool m_sent_data { false }; |     bool m_sent_data { false }; | ||||||
|     bool m_should_have_payload { false }; |     bool m_should_have_payload { false }; | ||||||
|  |  | ||||||
|  | @ -98,9 +98,9 @@ bool HttpJob::eof() const | ||||||
|     return m_socket->eof(); |     return m_socket->eof(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool HttpJob::write(const ByteBuffer& data) | bool HttpJob::write(ReadonlyBytes bytes) | ||||||
| { | { | ||||||
|     return m_socket->write(data); |     return m_socket->write(bytes); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -38,8 +38,8 @@ namespace HTTP { | ||||||
| class HttpJob final : public Job { | class HttpJob final : public Job { | ||||||
|     C_OBJECT(HttpJob) |     C_OBJECT(HttpJob) | ||||||
| public: | public: | ||||||
|     explicit HttpJob(const HttpRequest& request) |     explicit HttpJob(const HttpRequest& request, OutputStream& output_stream) | ||||||
|         : Job(request) |         : Job(request, output_stream) | ||||||
|     { |     { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -59,7 +59,7 @@ protected: | ||||||
|     virtual bool can_read() const override; |     virtual bool can_read() const override; | ||||||
|     virtual ByteBuffer receive(size_t) override; |     virtual ByteBuffer receive(size_t) override; | ||||||
|     virtual bool eof() const override; |     virtual bool eof() const override; | ||||||
|     virtual bool write(const ByteBuffer&) override; |     virtual bool write(ReadonlyBytes) override; | ||||||
|     virtual bool is_established() const override { return true; } |     virtual bool is_established() const override { return true; } | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|  |  | ||||||
|  | @ -71,11 +71,12 @@ ByteBuffer HttpRequest::to_raw_request() const | ||||||
|         builder.append(header.value); |         builder.append(header.value); | ||||||
|         builder.append("\r\n"); |         builder.append("\r\n"); | ||||||
|     } |     } | ||||||
|     builder.append("Connection: close\r\n\r\n"); |     builder.append("Connection: close\r\n"); | ||||||
|     if (!m_body.is_empty()) { |     if (!m_body.is_empty()) { | ||||||
|  |         builder.appendff("Content-Length: {}\r\n\r\n", m_body.size()); | ||||||
|         builder.append((const char*)m_body.data(), m_body.size()); |         builder.append((const char*)m_body.data(), m_body.size()); | ||||||
|         builder.append("\r\n"); |  | ||||||
|     } |     } | ||||||
|  |     builder.append("\r\n"); | ||||||
|     return builder.to_byte_buffer(); |     return builder.to_byte_buffer(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -62,7 +62,8 @@ public: | ||||||
|     void set_method(Method method) { m_method = method; } |     void set_method(Method method) { m_method = method; } | ||||||
| 
 | 
 | ||||||
|     const ByteBuffer& body() const { return m_body; } |     const ByteBuffer& body() const { return m_body; } | ||||||
|     void set_body(const ByteBuffer& body) { m_body = body; } |     void set_body(ReadonlyBytes body) { m_body = ByteBuffer::copy(body); } | ||||||
|  |     void set_body(ByteBuffer&& body) { m_body = move(body); } | ||||||
| 
 | 
 | ||||||
|     String method_name() const; |     String method_name() const; | ||||||
|     ByteBuffer to_raw_request() const; |     ByteBuffer to_raw_request() const; | ||||||
|  |  | ||||||
|  | @ -28,9 +28,8 @@ | ||||||
| 
 | 
 | ||||||
| namespace HTTP { | namespace HTTP { | ||||||
| 
 | 
 | ||||||
| HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload) | HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers) | ||||||
|     : Core::NetworkResponse(move(payload)) |     : m_code(code) | ||||||
|     , m_code(code) |  | ||||||
|     , m_headers(move(headers)) |     , m_headers(move(headers)) | ||||||
| { | { | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,16 +35,16 @@ namespace HTTP { | ||||||
| class HttpResponse : public Core::NetworkResponse { | class HttpResponse : public Core::NetworkResponse { | ||||||
| public: | public: | ||||||
|     virtual ~HttpResponse() override; |     virtual ~HttpResponse() override; | ||||||
|     static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload) |     static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers) | ||||||
|     { |     { | ||||||
|         return adopt(*new HttpResponse(code, move(headers), move(payload))); |         return adopt(*new HttpResponse(code, move(headers))); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int code() const { return m_code; } |     int code() const { return m_code; } | ||||||
|     const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; } |     const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; } | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&, ByteBuffer&&); |     HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&); | ||||||
| 
 | 
 | ||||||
|     int m_code { 0 }; |     int m_code { 0 }; | ||||||
|     HashMap<String, String, CaseInsensitiveStringTraits> m_headers; |     HashMap<String, String, CaseInsensitiveStringTraits> m_headers; | ||||||
|  |  | ||||||
|  | @ -143,7 +143,7 @@ bool HttpsJob::eof() const | ||||||
|     return m_socket->eof(); |     return m_socket->eof(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool HttpsJob::write(const ByteBuffer& data) | bool HttpsJob::write(ReadonlyBytes data) | ||||||
| { | { | ||||||
|     return m_socket->write(data); |     return m_socket->write(data); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -38,8 +38,8 @@ namespace HTTP { | ||||||
| class HttpsJob final : public Job { | class HttpsJob final : public Job { | ||||||
|     C_OBJECT(HttpsJob) |     C_OBJECT(HttpsJob) | ||||||
| public: | public: | ||||||
|     explicit HttpsJob(const HttpRequest& request, const Vector<Certificate>* override_certs = nullptr) |     explicit HttpsJob(const HttpRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certs = nullptr) | ||||||
|         : Job(request) |         : Job(request, output_stream) | ||||||
|         , m_override_ca_certificates(override_certs) |         , m_override_ca_certificates(override_certs) | ||||||
|     { |     { | ||||||
|     } |     } | ||||||
|  | @ -62,7 +62,7 @@ protected: | ||||||
|     virtual bool can_read() const override; |     virtual bool can_read() const override; | ||||||
|     virtual ByteBuffer receive(size_t) override; |     virtual ByteBuffer receive(size_t) override; | ||||||
|     virtual bool eof() const override; |     virtual bool eof() const override; | ||||||
|     virtual bool write(const ByteBuffer&) override; |     virtual bool write(ReadonlyBytes) override; | ||||||
|     virtual bool is_established() const override { return m_socket->is_established(); } |     virtual bool is_established() const override { return m_socket->is_established(); } | ||||||
|     virtual bool should_fail_on_empty_payload() const override { return false; } |     virtual bool should_fail_on_empty_payload() const override { return false; } | ||||||
|     virtual void read_while_data_available(Function<IterationDecision()>) override; |     virtual void read_while_data_available(Function<IterationDecision()>) override; | ||||||
|  |  | ||||||
|  | @ -68,8 +68,9 @@ static ByteBuffer handle_content_encoding(const ByteBuffer& buf, const String& c | ||||||
|     return buf; |     return buf; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Job::Job(const HttpRequest& request) | Job::Job(const HttpRequest& request, OutputStream& output_stream) | ||||||
|     : m_request(request) |     : Core::NetworkJob(output_stream) | ||||||
|  |     , m_request(request) | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -77,6 +78,35 @@ Job::~Job() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void Job::flush_received_buffers() | ||||||
|  | { | ||||||
|  |     if (!m_can_stream_response || m_buffered_size == 0) | ||||||
|  |         return; | ||||||
|  | #ifdef JOB_DEBUG | ||||||
|  |     dbg() << "Job: Flushing received buffers: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; | ||||||
|  | #endif | ||||||
|  |     for (size_t i = 0; i < m_received_buffers.size(); ++i) { | ||||||
|  |         auto& payload = m_received_buffers[i]; | ||||||
|  |         auto written = do_write(payload); | ||||||
|  |         m_buffered_size -= written; | ||||||
|  |         if (written == payload.size()) { | ||||||
|  |             // FIXME: Make this a take-first-friendly object?
 | ||||||
|  |             m_received_buffers.take_first(); | ||||||
|  |             --i; | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  |         ASSERT(written < payload.size()); | ||||||
|  |         payload = payload.slice(written, payload.size() - written); | ||||||
|  | #ifdef JOB_DEBUG | ||||||
|  |         dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; | ||||||
|  | #endif | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | #ifdef JOB_DEBUG | ||||||
|  |     dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void Job::on_socket_connected() | void Job::on_socket_connected() | ||||||
| { | { | ||||||
|     register_on_ready_to_write([&] { |     register_on_ready_to_write([&] { | ||||||
|  | @ -135,6 +165,8 @@ void Job::on_socket_connected() | ||||||
|                 if (m_state == State::Trailers) { |                 if (m_state == State::Trailers) { | ||||||
|                     return finish_up(); |                     return finish_up(); | ||||||
|                 } else { |                 } else { | ||||||
|  |                     if (on_headers_received) | ||||||
|  |                         on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {}); | ||||||
|                     m_state = State::InBody; |                     m_state = State::InBody; | ||||||
|                 } |                 } | ||||||
|                 return; |                 return; | ||||||
|  | @ -163,6 +195,13 @@ void Job::on_socket_connected() | ||||||
|             } |             } | ||||||
|             auto value = line.substring(name.length() + 2, line.length() - name.length() - 2); |             auto value = line.substring(name.length() + 2, line.length() - name.length() - 2); | ||||||
|             m_headers.set(name, value); |             m_headers.set(name, value); | ||||||
|  |             if (name.equals_ignoring_case("Content-Encoding")) { | ||||||
|  |                 // Assume that any content-encoding means that we can't decode it as a stream :(
 | ||||||
|  | #ifdef JOB_DEBUG | ||||||
|  |                 dbg() << "Content-Encoding " << value << " detected, cannot stream output :("; | ||||||
|  | #endif | ||||||
|  |                 m_can_stream_response = false; | ||||||
|  |             } | ||||||
| #ifdef JOB_DEBUG | #ifdef JOB_DEBUG | ||||||
|             dbg() << "Job: [" << name << "] = '" << value << "'"; |             dbg() << "Job: [" << name << "] = '" << value << "'"; | ||||||
| #endif | #endif | ||||||
|  | @ -252,7 +291,9 @@ void Job::on_socket_connected() | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             m_received_buffers.append(payload); |             m_received_buffers.append(payload); | ||||||
|  |             m_buffered_size += payload.size(); | ||||||
|             m_received_size += payload.size(); |             m_received_size += payload.size(); | ||||||
|  |             flush_received_buffers(); | ||||||
| 
 | 
 | ||||||
|             if (m_current_chunk_remaining_size.has_value()) { |             if (m_current_chunk_remaining_size.has_value()) { | ||||||
|                 auto size = m_current_chunk_remaining_size.value() - payload.size(); |                 auto size = m_current_chunk_remaining_size.value() - payload.size(); | ||||||
|  | @ -313,20 +354,37 @@ void Job::on_socket_connected() | ||||||
| void Job::finish_up() | void Job::finish_up() | ||||||
| { | { | ||||||
|     m_state = State::Finished; |     m_state = State::Finished; | ||||||
|     auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); |     if (!m_can_stream_response) { | ||||||
|     u8* flat_ptr = flattened_buffer.data(); |         auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); | ||||||
|     for (auto& received_buffer : m_received_buffers) { |         u8* flat_ptr = flattened_buffer.data(); | ||||||
|         memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); |         for (auto& received_buffer : m_received_buffers) { | ||||||
|         flat_ptr += received_buffer.size(); |             memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); | ||||||
|     } |             flat_ptr += received_buffer.size(); | ||||||
|     m_received_buffers.clear(); |         } | ||||||
|  |         m_received_buffers.clear(); | ||||||
| 
 | 
 | ||||||
|     auto content_encoding = m_headers.get("Content-Encoding"); |         // For the time being, we cannot stream stuff with content-encoding set to _anything_.
 | ||||||
|     if (content_encoding.has_value()) { |         auto content_encoding = m_headers.get("Content-Encoding"); | ||||||
|         flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value()); |         if (content_encoding.has_value()) { | ||||||
|  |             flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value()); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         m_buffered_size = flattened_buffer.size(); | ||||||
|  |         m_received_buffers.append(move(flattened_buffer)); | ||||||
|  |         m_can_stream_response = true; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     auto response = HttpResponse::create(m_code, move(m_headers), move(flattened_buffer)); |     flush_received_buffers(); | ||||||
|  |     if (m_buffered_size != 0) { | ||||||
|  |         // FIXME: What do we do? ignore it?
 | ||||||
|  |         //        "Transmission failed" is not strictly correct, but let's roll with it for now.
 | ||||||
|  |         deferred_invoke([this](auto&) { | ||||||
|  |             did_fail(Error::TransmissionFailed); | ||||||
|  |         }); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto response = HttpResponse::create(m_code, move(m_headers)); | ||||||
|     deferred_invoke([this, response](auto&) { |     deferred_invoke([this, response](auto&) { | ||||||
|         did_finish(move(response)); |         did_finish(move(response)); | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|  | @ -26,6 +26,7 @@ | ||||||
| 
 | 
 | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
|  | #include <AK/FileStream.h> | ||||||
| #include <AK/HashMap.h> | #include <AK/HashMap.h> | ||||||
| #include <AK/Optional.h> | #include <AK/Optional.h> | ||||||
| #include <LibCore/NetworkJob.h> | #include <LibCore/NetworkJob.h> | ||||||
|  | @ -37,7 +38,7 @@ namespace HTTP { | ||||||
| 
 | 
 | ||||||
| class Job : public Core::NetworkJob { | class Job : public Core::NetworkJob { | ||||||
| public: | public: | ||||||
|     explicit Job(const HttpRequest&); |     explicit Job(const HttpRequest&, OutputStream&); | ||||||
|     virtual ~Job() override; |     virtual ~Job() override; | ||||||
| 
 | 
 | ||||||
|     virtual void start() override = 0; |     virtual void start() override = 0; | ||||||
|  | @ -49,6 +50,7 @@ public: | ||||||
| protected: | protected: | ||||||
|     void finish_up(); |     void finish_up(); | ||||||
|     void on_socket_connected(); |     void on_socket_connected(); | ||||||
|  |     void flush_received_buffers(); | ||||||
|     virtual void register_on_ready_to_read(Function<void()>) = 0; |     virtual void register_on_ready_to_read(Function<void()>) = 0; | ||||||
|     virtual void register_on_ready_to_write(Function<void()>) = 0; |     virtual void register_on_ready_to_write(Function<void()>) = 0; | ||||||
|     virtual bool can_read_line() const = 0; |     virtual bool can_read_line() const = 0; | ||||||
|  | @ -56,7 +58,7 @@ protected: | ||||||
|     virtual bool can_read() const = 0; |     virtual bool can_read() const = 0; | ||||||
|     virtual ByteBuffer receive(size_t) = 0; |     virtual ByteBuffer receive(size_t) = 0; | ||||||
|     virtual bool eof() const = 0; |     virtual bool eof() const = 0; | ||||||
|     virtual bool write(const ByteBuffer&) = 0; |     virtual bool write(ReadonlyBytes) = 0; | ||||||
|     virtual bool is_established() const = 0; |     virtual bool is_established() const = 0; | ||||||
|     virtual bool should_fail_on_empty_payload() const { return true; } |     virtual bool should_fail_on_empty_payload() const { return true; } | ||||||
|     virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; |     virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; | ||||||
|  | @ -73,11 +75,13 @@ protected: | ||||||
|     State m_state { State::InStatus }; |     State m_state { State::InStatus }; | ||||||
|     int m_code { -1 }; |     int m_code { -1 }; | ||||||
|     HashMap<String, String, CaseInsensitiveStringTraits> m_headers; |     HashMap<String, String, CaseInsensitiveStringTraits> m_headers; | ||||||
|     Vector<ByteBuffer> m_received_buffers; |     Vector<ByteBuffer, 2> m_received_buffers; | ||||||
|  |     size_t m_buffered_size { 0 }; | ||||||
|     size_t m_received_size { 0 }; |     size_t m_received_size { 0 }; | ||||||
|     bool m_sent_data { 0 }; |     bool m_sent_data { 0 }; | ||||||
|     Optional<ssize_t> m_current_chunk_remaining_size; |     Optional<ssize_t> m_current_chunk_remaining_size; | ||||||
|     Optional<size_t> m_current_chunk_total_size; |     Optional<size_t> m_current_chunk_total_size; | ||||||
|  |     bool m_can_stream_response { true }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ | ||||||
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
|  | #include <AK/FileStream.h> | ||||||
| #include <AK/SharedBuffer.h> | #include <AK/SharedBuffer.h> | ||||||
| #include <LibProtocol/Client.h> | #include <LibProtocol/Client.h> | ||||||
| #include <LibProtocol/Download.h> | #include <LibProtocol/Download.h> | ||||||
|  | @ -47,16 +48,20 @@ bool Client::is_supported_protocol(const String& protocol) | ||||||
|     return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported(); |     return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, const ByteBuffer& request_body) | template<typename RequestHashMapTraits> | ||||||
|  | RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers, ReadonlyBytes request_body) | ||||||
| { | { | ||||||
|     IPC::Dictionary header_dictionary; |     IPC::Dictionary header_dictionary; | ||||||
|     for (auto& it : request_headers) |     for (auto& it : request_headers) | ||||||
|         header_dictionary.add(it.key, it.value); |         header_dictionary.add(it.key, it.value); | ||||||
| 
 | 
 | ||||||
|     i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, String::copy(request_body))->download_id(); |     auto response = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, ByteBuffer::copy(request_body)); | ||||||
|     if (download_id < 0) |     auto download_id = response->download_id(); | ||||||
|  |     auto response_fd = response->response_fd().fd(); | ||||||
|  |     if (download_id < 0 || response_fd < 0) | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     auto download = Download::create_from_id({}, *this, download_id); |     auto download = Download::create_from_id({}, *this, download_id); | ||||||
|  |     download->set_download_fd({}, response_fd); | ||||||
|     m_downloads.set(download_id, download); |     m_downloads.set(download_id, download); | ||||||
|     return download; |     return download; | ||||||
| } | } | ||||||
|  | @ -79,9 +84,8 @@ void Client::handle(const Messages::ProtocolClient::DownloadFinished& message) | ||||||
| { | { | ||||||
|     RefPtr<Download> download; |     RefPtr<Download> download; | ||||||
|     if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) { |     if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) { | ||||||
|         download->did_finish({}, message.success(), message.status_code(), message.total_size(), message.shbuf_id(), message.response_headers()); |         download->did_finish({}, message.success(), message.total_size()); | ||||||
|     } |     } | ||||||
|     send_sync<Messages::ProtocolServer::DisownSharedBuffer>(message.shbuf_id()); |  | ||||||
|     m_downloads.remove(message.download_id()); |     m_downloads.remove(message.download_id()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -92,6 +96,15 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void Client::handle(const Messages::ProtocolClient::HeadersBecameAvailable& message) | ||||||
|  | { | ||||||
|  |     if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) { | ||||||
|  |         HashMap<String, String, CaseInsensitiveStringTraits> headers; | ||||||
|  |         message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); }); | ||||||
|  |         download->did_receive_headers({}, headers, message.status_code()); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message) | OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message) | ||||||
| { | { | ||||||
|     if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) { |     if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) { | ||||||
|  | @ -102,3 +115,6 @@ OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(co | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, ReadonlyBytes request_body); | ||||||
|  | template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String, CaseInsensitiveStringTraits>& request_headers, ReadonlyBytes request_body); | ||||||
|  |  | ||||||
|  | @ -44,7 +44,8 @@ public: | ||||||
|     virtual void handshake() override; |     virtual void handshake() override; | ||||||
| 
 | 
 | ||||||
|     bool is_supported_protocol(const String&); |     bool is_supported_protocol(const String&); | ||||||
|     RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String>& request_headers = {}, const ByteBuffer& request_body = {}); |     template<typename RequestHashMapTraits = Traits<String>> | ||||||
|  |     RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers = {}, ReadonlyBytes request_body = {}); | ||||||
| 
 | 
 | ||||||
|     bool stop_download(Badge<Download>, Download&); |     bool stop_download(Badge<Download>, Download&); | ||||||
|     bool set_certificate(Badge<Download>, Download&, String, String); |     bool set_certificate(Badge<Download>, Download&, String, String); | ||||||
|  | @ -55,6 +56,7 @@ private: | ||||||
|     virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override; |     virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override; | ||||||
|     virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override; |     virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override; | ||||||
|     virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override; |     virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override; | ||||||
|  |     virtual void handle(const Messages::ProtocolClient::HeadersBecameAvailable&) override; | ||||||
| 
 | 
 | ||||||
|     HashMap<i32, RefPtr<Download>> m_downloads; |     HashMap<i32, RefPtr<Download>> m_downloads; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | @ -41,25 +41,81 @@ bool Download::stop() | ||||||
|     return m_client->stop_download({}, *this); |     return m_client->stop_download({}, *this); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Download::did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers) | void Download::stream_into(OutputStream& stream) | ||||||
|  | { | ||||||
|  |     ASSERT(!m_internal_stream_data); | ||||||
|  | 
 | ||||||
|  |     auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read); | ||||||
|  | 
 | ||||||
|  |     m_internal_stream_data = make<InternalStreamData>(fd()); | ||||||
|  |     m_internal_stream_data->read_notifier = notifier; | ||||||
|  | 
 | ||||||
|  |     auto user_on_finish = move(on_finish); | ||||||
|  |     on_finish = [this](auto success, auto total_size) { | ||||||
|  |         m_internal_stream_data->success = success; | ||||||
|  |         m_internal_stream_data->total_size = total_size; | ||||||
|  |         m_internal_stream_data->download_done = true; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] { | ||||||
|  |         constexpr size_t buffer_size = 1 * KiB; | ||||||
|  |         static char buf[buffer_size]; | ||||||
|  |         auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size }); | ||||||
|  |         if (!stream.write_or_error({ buf, nread })) { | ||||||
|  |             // FIXME: What do we do here?
 | ||||||
|  |             TODO(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (m_internal_stream_data->read_stream.eof() || (m_internal_stream_data->download_done && !m_internal_stream_data->success)) { | ||||||
|  |             m_internal_stream_data->read_notifier->close(); | ||||||
|  |             user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size); | ||||||
|  |         } else { | ||||||
|  |             m_internal_stream_data->read_stream.handle_any_error(); | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void Download::set_should_buffer_all_input(bool value) | ||||||
|  | { | ||||||
|  |     if (m_should_buffer_all_input == value) | ||||||
|  |         return; | ||||||
|  | 
 | ||||||
|  |     if (m_internal_buffered_data && !value) { | ||||||
|  |         m_internal_buffered_data = nullptr; | ||||||
|  |         m_should_buffer_all_input = false; | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ASSERT(!m_internal_stream_data); | ||||||
|  |     ASSERT(!m_internal_buffered_data); | ||||||
|  |     ASSERT(on_buffered_download_finish); // Not having this set makes no sense.
 | ||||||
|  |     m_internal_buffered_data = make<InternalBufferedData>(fd()); | ||||||
|  |     m_should_buffer_all_input = true; | ||||||
|  | 
 | ||||||
|  |     on_headers_received = [this](auto& headers, auto response_code) { | ||||||
|  |         m_internal_buffered_data->response_headers = headers; | ||||||
|  |         m_internal_buffered_data->response_code = move(response_code); | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     on_finish = [this](auto success, u32 total_size) { | ||||||
|  |         auto output_buffer = m_internal_buffered_data->payload_stream.copy_into_contiguous_buffer(); | ||||||
|  |         on_buffered_download_finish( | ||||||
|  |             success, | ||||||
|  |             total_size, | ||||||
|  |             m_internal_buffered_data->response_headers, | ||||||
|  |             m_internal_buffered_data->response_code, | ||||||
|  |             output_buffer); | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     stream_into(m_internal_buffered_data->payload_stream); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void Download::did_finish(Badge<Client>, bool success, u32 total_size) | ||||||
| { | { | ||||||
|     if (!on_finish) |     if (!on_finish) | ||||||
|         return; |         return; | ||||||
| 
 | 
 | ||||||
|     ReadonlyBytes payload; |     on_finish(success, total_size); | ||||||
|     RefPtr<SharedBuffer> shared_buffer; |  | ||||||
|     if (success && shbuf_id != -1) { |  | ||||||
|         shared_buffer = SharedBuffer::create_from_shbuf_id(shbuf_id); |  | ||||||
|         payload = { shared_buffer->data<void>(), total_size }; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // FIXME: It's a bit silly that we copy the response headers here just so we can move them into a HashMap with different traits.
 |  | ||||||
|     HashMap<String, String, CaseInsensitiveStringTraits> caseless_response_headers; |  | ||||||
|     response_headers.for_each_entry([&](auto& name, auto& value) { |  | ||||||
|         caseless_response_headers.set(name, value); |  | ||||||
|     }); |  | ||||||
| 
 |  | ||||||
|     on_finish(success, payload, move(shared_buffer), caseless_response_headers, status_code); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size) | void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size) | ||||||
|  | @ -68,6 +124,12 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa | ||||||
|         on_progress(total_size, downloaded_size); |         on_progress(total_size, downloaded_size); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void Download::did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code) | ||||||
|  | { | ||||||
|  |     if (on_headers_received) | ||||||
|  |         on_headers_received(response_headers, response_code); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void Download::did_request_certificates(Badge<Client>) | void Download::did_request_certificates(Badge<Client>) | ||||||
| { | { | ||||||
|     if (on_certificate_requested) { |     if (on_certificate_requested) { | ||||||
|  |  | ||||||
|  | @ -28,10 +28,13 @@ | ||||||
| 
 | 
 | ||||||
| #include <AK/Badge.h> | #include <AK/Badge.h> | ||||||
| #include <AK/ByteBuffer.h> | #include <AK/ByteBuffer.h> | ||||||
|  | #include <AK/FileStream.h> | ||||||
| #include <AK/Function.h> | #include <AK/Function.h> | ||||||
|  | #include <AK/MemoryStream.h> | ||||||
| #include <AK/RefCounted.h> | #include <AK/RefCounted.h> | ||||||
| #include <AK/String.h> | #include <AK/String.h> | ||||||
| #include <AK/WeakPtr.h> | #include <AK/WeakPtr.h> | ||||||
|  | #include <LibCore/Notifier.h> | ||||||
| #include <LibIPC/Forward.h> | #include <LibIPC/Forward.h> | ||||||
| 
 | 
 | ||||||
| namespace Protocol { | namespace Protocol { | ||||||
|  | @ -51,20 +54,65 @@ public: | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int id() const { return m_download_id; } |     int id() const { return m_download_id; } | ||||||
|  |     int fd() const { return m_fd; } | ||||||
|     bool stop(); |     bool stop(); | ||||||
| 
 | 
 | ||||||
|     Function<void(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish; |     void stream_into(OutputStream&); | ||||||
|  | 
 | ||||||
|  |     bool should_buffer_all_input() const { return m_should_buffer_all_input; } | ||||||
|  |     /// Note: Will override `on_finish', and `on_headers_received', and expects `on_buffered_download_finish' to be set!
 | ||||||
|  |     void set_should_buffer_all_input(bool); | ||||||
|  | 
 | ||||||
|  |     /// Note: Must be set before `set_should_buffer_all_input(true)`.
 | ||||||
|  |     Function<void(bool success, u32 total_size, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code, ReadonlyBytes payload)> on_buffered_download_finish; | ||||||
|  |     Function<void(bool success, u32 total_size)> on_finish; | ||||||
|     Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress; |     Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress; | ||||||
|  |     Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received; | ||||||
|     Function<CertificateAndKey()> on_certificate_requested; |     Function<CertificateAndKey()> on_certificate_requested; | ||||||
| 
 | 
 | ||||||
|     void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers); |     void did_finish(Badge<Client>, bool success, u32 total_size); | ||||||
|     void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size); |     void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size); | ||||||
|  |     void did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code); | ||||||
|     void did_request_certificates(Badge<Client>); |     void did_request_certificates(Badge<Client>); | ||||||
| 
 | 
 | ||||||
|  |     RefPtr<Core::Notifier>& write_notifier(Badge<Client>) { return m_write_notifier; } | ||||||
|  |     void set_download_fd(Badge<Client>, int fd) { m_fd = fd; } | ||||||
|  | 
 | ||||||
| private: | private: | ||||||
|     explicit Download(Client&, i32 download_id); |     explicit Download(Client&, i32 download_id); | ||||||
|     WeakPtr<Client> m_client; |     WeakPtr<Client> m_client; | ||||||
|     int m_download_id { -1 }; |     int m_download_id { -1 }; | ||||||
|  |     RefPtr<Core::Notifier> m_write_notifier; | ||||||
|  |     int m_fd { -1 }; | ||||||
|  |     bool m_should_buffer_all_input { false }; | ||||||
|  | 
 | ||||||
|  |     struct InternalBufferedData { | ||||||
|  |         InternalBufferedData(int fd) | ||||||
|  |             : read_stream(fd) | ||||||
|  |         { | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         InputFileStream read_stream; | ||||||
|  |         DuplexMemoryStream payload_stream; | ||||||
|  |         HashMap<String, String, CaseInsensitiveStringTraits> response_headers; | ||||||
|  |         Optional<u32> response_code; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     struct InternalStreamData { | ||||||
|  |         InternalStreamData(int fd) | ||||||
|  |             : read_stream(fd) | ||||||
|  |         { | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         InputFileStream read_stream; | ||||||
|  |         RefPtr<Core::Notifier> read_notifier; | ||||||
|  |         bool success; | ||||||
|  |         u32 total_size { 0 }; | ||||||
|  |         bool download_done { false }; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     OwnPtr<InternalBufferedData> m_internal_buffered_data; | ||||||
|  |     OwnPtr<InternalStreamData> m_internal_stream_data; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -92,10 +92,10 @@ void XMLHttpRequest::send() | ||||||
|     // we need to make ResourceLoader give us more detailed updates than just "done" and "error".
 |     // we need to make ResourceLoader give us more detailed updates than just "done" and "error".
 | ||||||
|     ResourceLoader::the().load( |     ResourceLoader::the().load( | ||||||
|         m_window->document().complete_url(m_url), |         m_window->document().complete_url(m_url), | ||||||
|         [weak_this = make_weak_ptr()](auto& data, auto&) { |         [weak_this = make_weak_ptr()](auto data, auto&) { | ||||||
|             if (!weak_this) |             if (!weak_this) | ||||||
|                 return; |                 return; | ||||||
|             const_cast<XMLHttpRequest&>(*weak_this).m_response = data; |             const_cast<XMLHttpRequest&>(*weak_this).m_response = ByteBuffer::copy(data); | ||||||
|             const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done); |             const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done); | ||||||
|             const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load)); |             const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load)); | ||||||
|         }, |         }, | ||||||
|  |  | ||||||
|  | @ -128,7 +128,7 @@ void HTMLScriptElement::prepare_script(Badge<HTMLDocumentParser>) | ||||||
|         // FIXME: This load should be made asynchronous and the parser should spin an event loop etc.
 |         // FIXME: This load should be made asynchronous and the parser should spin an event loop etc.
 | ||||||
|         ResourceLoader::the().load_sync( |         ResourceLoader::the().load_sync( | ||||||
|             url, |             url, | ||||||
|             [this, url](auto& data, auto&) { |             [this, url](auto data, auto&) { | ||||||
|                 if (data.is_null()) { |                 if (data.is_null()) { | ||||||
|                     dbg() << "HTMLScriptElement: Failed to load " << url; |                     dbg() << "HTMLScriptElement: Failed to load " << url; | ||||||
|                     return; |                     return; | ||||||
|  |  | ||||||
|  | @ -171,6 +171,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type) | ||||||
|         return true; |         return true; | ||||||
| 
 | 
 | ||||||
|     if (url.protocol() == "http" || url.protocol() == "https") { |     if (url.protocol() == "http" || url.protocol() == "https") { | ||||||
|  | #if 0 | ||||||
|         URL favicon_url; |         URL favicon_url; | ||||||
|         favicon_url.set_protocol(url.protocol()); |         favicon_url.set_protocol(url.protocol()); | ||||||
|         favicon_url.set_host(url.host()); |         favicon_url.set_host(url.host()); | ||||||
|  | @ -191,6 +192,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type) | ||||||
|                 if (auto* page = frame().page()) |                 if (auto* page = frame().page()) | ||||||
|                     page->client().page_did_change_favicon(*bitmap); |                     page->client().page_did_change_favicon(*bitmap); | ||||||
|             }); |             }); | ||||||
|  | #endif | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return true; |     return true; | ||||||
|  |  | ||||||
|  | @ -84,10 +84,10 @@ static String mime_type_from_content_type(const String& content_type) | ||||||
|     return content_type; |     return content_type; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Resource::did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers) | void Resource::did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers) | ||||||
| { | { | ||||||
|     ASSERT(!m_loaded); |     ASSERT(!m_loaded); | ||||||
|     m_encoded_data = data; |     m_encoded_data = ByteBuffer::copy(data); | ||||||
|     m_response_headers = headers; |     m_response_headers = headers; | ||||||
|     m_loaded = true; |     m_loaded = true; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -77,7 +77,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     void for_each_client(Function<void(ResourceClient&)>); |     void for_each_client(Function<void(ResourceClient&)>); | ||||||
| 
 | 
 | ||||||
|     void did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers); |     void did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers); | ||||||
|     void did_fail(Badge<ResourceLoader>, const String& error); |     void did_fail(Badge<ResourceLoader>, const String& error); | ||||||
| 
 | 
 | ||||||
| protected: | protected: | ||||||
|  |  | ||||||
|  | @ -53,13 +53,13 @@ ResourceLoader::ResourceLoader() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ResourceLoader::load_sync(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | void ResourceLoader::load_sync(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | ||||||
| { | { | ||||||
|     Core::EventLoop loop; |     Core::EventLoop loop; | ||||||
| 
 | 
 | ||||||
|     load( |     load( | ||||||
|         url, |         url, | ||||||
|         [&](auto& data, auto& response_headers) { |         [&](auto data, auto& response_headers) { | ||||||
|             success_callback(data, response_headers); |             success_callback(data, response_headers); | ||||||
|             loop.quit(0); |             loop.quit(0); | ||||||
|         }, |         }, | ||||||
|  | @ -97,7 +97,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe | ||||||
| 
 | 
 | ||||||
|     load( |     load( | ||||||
|         request, |         request, | ||||||
|         [=](auto& data, auto& headers) { |         [=](auto data, auto& headers) { | ||||||
|             const_cast<Resource&>(*resource).did_load({}, data, headers); |             const_cast<Resource&>(*resource).did_load({}, data, headers); | ||||||
|         }, |         }, | ||||||
|         [=](auto& error) { |         [=](auto& error) { | ||||||
|  | @ -107,7 +107,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe | ||||||
|     return resource; |     return resource; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | void ResourceLoader::load(const LoadRequest& request, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | ||||||
| { | { | ||||||
|     auto& url = request.url(); |     auto& url = request.url(); | ||||||
|     if (is_port_blocked(url.port())) { |     if (is_port_blocked(url.port())) { | ||||||
|  | @ -170,7 +170,12 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu | ||||||
|                 error_callback("Failed to initiate load"); |                 error_callback("Failed to initiate load"); | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
|         download->on_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback)](bool success, ReadonlyBytes payload, auto, auto& response_headers, auto status_code) { |         download->on_buffered_download_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback), download](bool success, auto, auto& response_headers, auto status_code, ReadonlyBytes payload) { | ||||||
|  |             if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) { | ||||||
|  |                 if (error_callback) | ||||||
|  |                     error_callback(String::format("HTTP error (%u)", status_code.value())); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|             --m_pending_loads; |             --m_pending_loads; | ||||||
|             if (on_load_counter_change) |             if (on_load_counter_change) | ||||||
|                 on_load_counter_change(); |                 on_load_counter_change(); | ||||||
|  | @ -179,13 +184,9 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu | ||||||
|                     error_callback("HTTP load failed"); |                     error_callback("HTTP load failed"); | ||||||
|                 return; |                 return; | ||||||
|             } |             } | ||||||
|             if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) { |             success_callback(payload, response_headers); | ||||||
|                 if (error_callback) |  | ||||||
|                     error_callback(String::format("HTTP error (%u)", status_code.value())); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|             success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers); |  | ||||||
|         }; |         }; | ||||||
|  |         download->set_should_buffer_all_input(true); | ||||||
|         download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey { |         download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey { | ||||||
|             return {}; |             return {}; | ||||||
|         }; |         }; | ||||||
|  | @ -199,7 +200,7 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu | ||||||
|         error_callback(String::format("Protocol not implemented: %s", url.protocol().characters())); |         error_callback(String::format("Protocol not implemented: %s", url.protocol().characters())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | void ResourceLoader::load(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) | ||||||
| { | { | ||||||
|     LoadRequest request; |     LoadRequest request; | ||||||
|     request.set_url(url); |     request.set_url(url); | ||||||
|  |  | ||||||
|  | @ -44,9 +44,9 @@ public: | ||||||
| 
 | 
 | ||||||
|     RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&); |     RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&); | ||||||
| 
 | 
 | ||||||
|     void load(const LoadRequest&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); |     void load(const LoadRequest&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); | ||||||
|     void load(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); |     void load(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); | ||||||
|     void load_sync(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); |     void load_sync(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); | ||||||
| 
 | 
 | ||||||
|     Function<void()> on_load_counter_change; |     Function<void()> on_load_counter_change; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -62,16 +62,17 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle | ||||||
| { | { | ||||||
|     URL url(message.url()); |     URL url(message.url()); | ||||||
|     if (!url.is_valid()) |     if (!url.is_valid()) | ||||||
|         return make<Messages::ProtocolServer::StartDownloadResponse>(-1); |         return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); | ||||||
|     auto* protocol = Protocol::find_by_name(url.protocol()); |     auto* protocol = Protocol::find_by_name(url.protocol()); | ||||||
|     if (!protocol) |     if (!protocol) | ||||||
|         return make<Messages::ProtocolServer::StartDownloadResponse>(-1); |         return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); | ||||||
|     auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer()); |     auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body()); | ||||||
|     if (!download) |     if (!download) | ||||||
|         return make<Messages::ProtocolServer::StartDownloadResponse>(-1); |         return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); | ||||||
|     auto id = download->id(); |     auto id = download->id(); | ||||||
|  |     auto fd = download->download_fd(); | ||||||
|     m_downloads.set(id, move(download)); |     m_downloads.set(id, move(download)); | ||||||
|     return make<Messages::ProtocolServer::StartDownloadResponse>(id); |     return make<Messages::ProtocolServer::StartDownloadResponse>(id, fd); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message) | OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message) | ||||||
|  | @ -86,22 +87,20 @@ OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle( | ||||||
|     return make<Messages::ProtocolServer::StopDownloadResponse>(success); |     return make<Messages::ProtocolServer::StopDownloadResponse>(success); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) | void ClientConnection::did_receive_headers(Badge<Download>, Download& download) | ||||||
| { | { | ||||||
|     RefPtr<SharedBuffer> buffer; |  | ||||||
|     if (success && download.payload().size() > 0 && !download.payload().is_null()) { |  | ||||||
|         buffer = SharedBuffer::create_with_size(download.payload().size()); |  | ||||||
|         memcpy(buffer->data<void>(), download.payload().data(), download.payload().size()); |  | ||||||
|         buffer->seal(); |  | ||||||
|         buffer->share_with(client_pid()); |  | ||||||
|         m_shared_buffers.set(buffer->shbuf_id(), buffer); |  | ||||||
|     } |  | ||||||
|     ASSERT(download.total_size().has_value()); |  | ||||||
| 
 |  | ||||||
|     IPC::Dictionary response_headers; |     IPC::Dictionary response_headers; | ||||||
|     for (auto& it : download.response_headers()) |     for (auto& it : download.response_headers()) | ||||||
|         response_headers.add(it.key, it.value); |         response_headers.add(it.key, it.value); | ||||||
|     post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers)); | 
 | ||||||
|  |     post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) | ||||||
|  | { | ||||||
|  |     ASSERT(download.total_size().has_value()); | ||||||
|  | 
 | ||||||
|  |     post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value())); | ||||||
| 
 | 
 | ||||||
|     m_downloads.remove(download.id()); |     m_downloads.remove(download.id()); | ||||||
| } | } | ||||||
|  | @ -121,12 +120,6 @@ OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const M | ||||||
|     return make<Messages::ProtocolServer::GreetResponse>(client_id()); |     return make<Messages::ProtocolServer::GreetResponse>(client_id()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message) |  | ||||||
| { |  | ||||||
|     m_shared_buffers.remove(message.shbuf_id()); |  | ||||||
|     return make<Messages::ProtocolServer::DisownSharedBufferResponse>(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) | OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) | ||||||
| { | { | ||||||
|     auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr)); |     auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr)); | ||||||
|  |  | ||||||
|  | @ -45,6 +45,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     virtual void die() override; |     virtual void die() override; | ||||||
| 
 | 
 | ||||||
|  |     void did_receive_headers(Badge<Download>, Download&); | ||||||
|     void did_finish_download(Badge<Download>, Download&, bool success); |     void did_finish_download(Badge<Download>, Download&, bool success); | ||||||
|     void did_progress_download(Badge<Download>, Download&); |     void did_progress_download(Badge<Download>, Download&); | ||||||
|     void did_request_certificates(Badge<Download>, Download&); |     void did_request_certificates(Badge<Download>, Download&); | ||||||
|  | @ -54,11 +55,9 @@ private: | ||||||
|     virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override; |     virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override; | ||||||
|     virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override; |     virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override; | ||||||
|     virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override; |     virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override; | ||||||
|     virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override; |     virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&) override; | ||||||
|     virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&); |  | ||||||
| 
 | 
 | ||||||
|     HashMap<i32, OwnPtr<Download>> m_downloads; |     HashMap<i32, OwnPtr<Download>> m_downloads; | ||||||
|     HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -33,9 +33,10 @@ namespace ProtocolServer { | ||||||
| // FIXME: What about rollover?
 | // FIXME: What about rollover?
 | ||||||
| static i32 s_next_id = 1; | static i32 s_next_id = 1; | ||||||
| 
 | 
 | ||||||
| Download::Download(ClientConnection& client) | Download::Download(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
|     : m_client(client) |     : m_client(client) | ||||||
|     , m_id(s_next_id++) |     , m_id(s_next_id++) | ||||||
|  |     , m_output_stream(move(output_stream)) | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -48,15 +49,10 @@ void Download::stop() | ||||||
|     m_client.did_finish_download({}, *this, false); |     m_client.did_finish_download({}, *this, false); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Download::set_payload(const ByteBuffer& payload) |  | ||||||
| { |  | ||||||
|     m_payload = payload; |  | ||||||
|     m_total_size = payload.size(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) | void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) | ||||||
| { | { | ||||||
|     m_response_headers = response_headers; |     m_response_headers = response_headers; | ||||||
|  |     m_client.did_receive_headers({}, *this); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Download::set_certificate(String, String) | void Download::set_certificate(String, String) | ||||||
|  |  | ||||||
|  | @ -26,8 +26,9 @@ | ||||||
| 
 | 
 | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
| #include <AK/ByteBuffer.h> | #include <AK/FileStream.h> | ||||||
| #include <AK/HashMap.h> | #include <AK/HashMap.h> | ||||||
|  | #include <AK/NonnullOwnPtr.h> | ||||||
| #include <AK/Optional.h> | #include <AK/Optional.h> | ||||||
| #include <AK/RefCounted.h> | #include <AK/RefCounted.h> | ||||||
| #include <AK/URL.h> | #include <AK/URL.h> | ||||||
|  | @ -45,30 +46,35 @@ public: | ||||||
|     Optional<u32> status_code() const { return m_status_code; } |     Optional<u32> status_code() const { return m_status_code; } | ||||||
|     Optional<u32> total_size() const { return m_total_size; } |     Optional<u32> total_size() const { return m_total_size; } | ||||||
|     size_t downloaded_size() const { return m_downloaded_size; } |     size_t downloaded_size() const { return m_downloaded_size; } | ||||||
|     const ByteBuffer& payload() const { return m_payload; } |  | ||||||
|     const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; } |     const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; } | ||||||
| 
 | 
 | ||||||
|     void stop(); |     void stop(); | ||||||
|     virtual void set_certificate(String, String); |     virtual void set_certificate(String, String); | ||||||
| 
 | 
 | ||||||
|  |     // FIXME: Want Badge<Protocol>, but can't make one from HttpProtocol, etc.
 | ||||||
|  |     void set_download_fd(int fd) { m_download_fd = fd; } | ||||||
|  |     int download_fd() const { return m_download_fd; } | ||||||
|  | 
 | ||||||
| protected: | protected: | ||||||
|     explicit Download(ClientConnection&); |     explicit Download(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
|     void did_finish(bool success); |     void did_finish(bool success); | ||||||
|     void did_progress(Optional<u32> total_size, u32 downloaded_size); |     void did_progress(Optional<u32> total_size, u32 downloaded_size); | ||||||
|     void set_status_code(u32 status_code) { m_status_code = status_code; } |     void set_status_code(u32 status_code) { m_status_code = status_code; } | ||||||
|     void did_request_certificates(); |     void did_request_certificates(); | ||||||
|     void set_payload(const ByteBuffer&); |  | ||||||
|     void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&); |     void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&); | ||||||
|  |     void set_downloaded_size(size_t size) { m_downloaded_size = size; } | ||||||
|  |     const OutputFileStream& output_stream() const { return *m_output_stream; } | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     ClientConnection& m_client; |     ClientConnection& m_client; | ||||||
|     i32 m_id { 0 }; |     i32 m_id { 0 }; | ||||||
|  |     int m_download_fd { -1 }; // Passed to client.
 | ||||||
|     URL m_url; |     URL m_url; | ||||||
|     Optional<u32> m_status_code; |     Optional<u32> m_status_code; | ||||||
|     Optional<u32> m_total_size {}; |     Optional<u32> m_total_size {}; | ||||||
|     size_t m_downloaded_size { 0 }; |     size_t m_downloaded_size { 0 }; | ||||||
|     ByteBuffer m_payload; |     NonnullOwnPtr<OutputFileStream> m_output_stream; | ||||||
|     HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers; |     HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -30,13 +30,13 @@ | ||||||
| 
 | 
 | ||||||
| namespace ProtocolServer { | namespace ProtocolServer { | ||||||
| 
 | 
 | ||||||
| GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) | GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
|     : Download(client) |     : Download(client, move(output_stream)) | ||||||
|     , m_job(job) |     , m_job(job) | ||||||
| { | { | ||||||
|     m_job->on_finish = [this](bool success) { |     m_job->on_finish = [this](bool success) { | ||||||
|         if (auto* response = m_job->response()) { |         if (auto* response = m_job->response()) { | ||||||
|             set_payload(response->payload()); |             set_downloaded_size(this->output_stream().size()); | ||||||
|             if (!response->meta().is_empty()) { |             if (!response->meta().is_empty()) { | ||||||
|                 HashMap<String, String, CaseInsensitiveStringTraits> headers; |                 HashMap<String, String, CaseInsensitiveStringTraits> headers; | ||||||
|                 headers.set("meta", response->meta()); |                 headers.set("meta", response->meta()); | ||||||
|  | @ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload() | ||||||
|     m_job->shutdown(); |     m_job->shutdown(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) | NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
| { | { | ||||||
|     return adopt_own(*new GeminiDownload(client, move(job))); |     return adopt_own(*new GeminiDownload(client, move(job), move(output_stream))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -36,10 +36,10 @@ namespace ProtocolServer { | ||||||
| class GeminiDownload final : public Download { | class GeminiDownload final : public Download { | ||||||
| public: | public: | ||||||
|     virtual ~GeminiDownload() override; |     virtual ~GeminiDownload() override; | ||||||
|     static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); |     static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); |     explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
|     virtual void set_certificate(String certificate, String key) override; |     virtual void set_certificate(String certificate, String key) override; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, const ByteBuffer&) | OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, ReadonlyBytes) | ||||||
| { | { | ||||||
|     Gemini::GeminiRequest request; |     Gemini::GeminiRequest request; | ||||||
|     request.set_url(url); |     request.set_url(url); | ||||||
|     auto job = Gemini::GeminiJob::construct(request); | 
 | ||||||
|     auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job); |     int fd_pair[2] { 0 }; | ||||||
|  |     if (pipe(fd_pair) != 0) { | ||||||
|  |         auto saved_errno = errno; | ||||||
|  |         dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  |     auto output_stream = make<OutputFileStream>(fd_pair[1]); | ||||||
|  |     output_stream->make_unbuffered(); | ||||||
|  |     auto job = Gemini::GeminiJob::construct(request, *output_stream); | ||||||
|  |     auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream)); | ||||||
|  |     download->set_download_fd(fd_pair[0]); | ||||||
|     job->start(); |     job->start(); | ||||||
|     return download; |     return download; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ public: | ||||||
|     GeminiProtocol(); |     GeminiProtocol(); | ||||||
|     virtual ~GeminiProtocol() override; |     virtual ~GeminiProtocol() override; | ||||||
| 
 | 
 | ||||||
|     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, const ByteBuffer& request_body) override; |     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, ReadonlyBytes body) override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -30,15 +30,21 @@ | ||||||
| 
 | 
 | ||||||
| namespace ProtocolServer { | namespace ProtocolServer { | ||||||
| 
 | 
 | ||||||
| HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) | HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
|     : Download(client) |     : Download(client, move(output_stream)) | ||||||
|     , m_job(job) |     , m_job(job) | ||||||
| { | { | ||||||
|  |     m_job->on_headers_received = [this](auto& headers, auto response_code) { | ||||||
|  |         if (response_code.has_value()) | ||||||
|  |             set_status_code(response_code.value()); | ||||||
|  |         set_response_headers(headers); | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|     m_job->on_finish = [this](bool success) { |     m_job->on_finish = [this](bool success) { | ||||||
|         if (auto* response = m_job->response()) { |         if (auto* response = m_job->response()) { | ||||||
|             set_status_code(response->code()); |             set_status_code(response->code()); | ||||||
|             set_payload(response->payload()); |  | ||||||
|             set_response_headers(response->headers()); |             set_response_headers(response->headers()); | ||||||
|  |             set_downloaded_size(this->output_stream().size()); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // if we didn't know the total size, pretend that the download finished successfully
 |         // if we didn't know the total size, pretend that the download finished successfully
 | ||||||
|  | @ -60,9 +66,9 @@ HttpDownload::~HttpDownload() | ||||||
|     m_job->shutdown(); |     m_job->shutdown(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) | NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
| { | { | ||||||
|     return adopt_own(*new HttpDownload(client, move(job))); |     return adopt_own(*new HttpDownload(client, move(job), move(output_stream))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -36,10 +36,10 @@ namespace ProtocolServer { | ||||||
| class HttpDownload final : public Download { | class HttpDownload final : public Download { | ||||||
| public: | public: | ||||||
|     virtual ~HttpDownload() override; |     virtual ~HttpDownload() override; | ||||||
|     static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); |     static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); |     explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
|     NonnullRefPtr<HTTP::HttpJob> m_job; |     NonnullRefPtr<HTTP::HttpJob> m_job; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | @ -28,6 +28,7 @@ | ||||||
| #include <LibHTTP/HttpRequest.h> | #include <LibHTTP/HttpRequest.h> | ||||||
| #include <ProtocolServer/HttpDownload.h> | #include <ProtocolServer/HttpDownload.h> | ||||||
| #include <ProtocolServer/HttpProtocol.h> | #include <ProtocolServer/HttpProtocol.h> | ||||||
|  | #include <fcntl.h> | ||||||
| 
 | 
 | ||||||
| namespace ProtocolServer { | namespace ProtocolServer { | ||||||
| 
 | 
 | ||||||
|  | @ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) | OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body) | ||||||
| { | { | ||||||
|     HTTP::HttpRequest request; |     HTTP::HttpRequest request; | ||||||
|     if (method.equals_ignoring_case("post")) |     if (method.equals_ignoring_case("post")) | ||||||
|  | @ -49,9 +50,20 @@ OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const St | ||||||
|         request.set_method(HTTP::HttpRequest::Method::GET); |         request.set_method(HTTP::HttpRequest::Method::GET); | ||||||
|     request.set_url(url); |     request.set_url(url); | ||||||
|     request.set_headers(headers); |     request.set_headers(headers); | ||||||
|     request.set_body(request_body); |     request.set_body(body); | ||||||
|     auto job = HTTP::HttpJob::construct(request); | 
 | ||||||
|     auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job); |     int fd_pair[2] { 0 }; | ||||||
|  |     if (pipe(fd_pair) != 0) { | ||||||
|  |         auto saved_errno = errno; | ||||||
|  |         dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto output_stream = make<OutputFileStream>(fd_pair[1]); | ||||||
|  |     output_stream->make_unbuffered(); | ||||||
|  |     auto job = HTTP::HttpJob::construct(request, *output_stream); | ||||||
|  |     auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream)); | ||||||
|  |     download->set_download_fd(fd_pair[0]); | ||||||
|     job->start(); |     job->start(); | ||||||
|     return download; |     return download; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ public: | ||||||
|     HttpProtocol(); |     HttpProtocol(); | ||||||
|     virtual ~HttpProtocol() override; |     virtual ~HttpProtocol() override; | ||||||
| 
 | 
 | ||||||
|     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; |     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -30,15 +30,21 @@ | ||||||
| 
 | 
 | ||||||
| namespace ProtocolServer { | namespace ProtocolServer { | ||||||
| 
 | 
 | ||||||
| HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) | HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
|     : Download(client) |     : Download(client, move(output_stream)) | ||||||
|     , m_job(job) |     , m_job(job) | ||||||
| { | { | ||||||
|  |     m_job->on_headers_received = [this](auto& headers, auto response_code) { | ||||||
|  |         if (response_code.has_value()) | ||||||
|  |             set_status_code(response_code.value()); | ||||||
|  |         set_response_headers(headers); | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|     m_job->on_finish = [this](bool success) { |     m_job->on_finish = [this](bool success) { | ||||||
|         if (auto* response = m_job->response()) { |         if (auto* response = m_job->response()) { | ||||||
|             set_status_code(response->code()); |             set_status_code(response->code()); | ||||||
|             set_payload(response->payload()); |  | ||||||
|             set_response_headers(response->headers()); |             set_response_headers(response->headers()); | ||||||
|  |             set_downloaded_size(this->output_stream().size()); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // if we didn't know the total size, pretend that the download finished successfully
 |         // if we didn't know the total size, pretend that the download finished successfully
 | ||||||
|  | @ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload() | ||||||
|     m_job->shutdown(); |     m_job->shutdown(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) | NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) | ||||||
| { | { | ||||||
|     return adopt_own(*new HttpsDownload(client, move(job))); |     return adopt_own(*new HttpsDownload(client, move(job), move(output_stream))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -36,10 +36,10 @@ namespace ProtocolServer { | ||||||
| class HttpsDownload final : public Download { | class HttpsDownload final : public Download { | ||||||
| public: | public: | ||||||
|     virtual ~HttpsDownload() override; |     virtual ~HttpsDownload() override; | ||||||
|     static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); |     static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
| private: | private: | ||||||
|     explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); |     explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); | ||||||
| 
 | 
 | ||||||
|     virtual void set_certificate(String certificate, String key) override; |     virtual void set_certificate(String certificate, String key) override; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol() | ||||||
| { | { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) | OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body) | ||||||
| { | { | ||||||
|     HTTP::HttpRequest request; |     HTTP::HttpRequest request; | ||||||
|     if (method.equals_ignoring_case("post")) |     if (method.equals_ignoring_case("post")) | ||||||
|  | @ -49,9 +49,19 @@ OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const S | ||||||
|         request.set_method(HTTP::HttpRequest::Method::GET); |         request.set_method(HTTP::HttpRequest::Method::GET); | ||||||
|     request.set_url(url); |     request.set_url(url); | ||||||
|     request.set_headers(headers); |     request.set_headers(headers); | ||||||
|     request.set_body(request_body); |     request.set_body(body); | ||||||
|     auto job = HTTP::HttpsJob::construct(request); | 
 | ||||||
|     auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job); |     int fd_pair[2] { 0 }; | ||||||
|  |     if (pipe(fd_pair) != 0) { | ||||||
|  |         auto saved_errno = errno; | ||||||
|  |         dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  |     auto output_stream = make<OutputFileStream>(fd_pair[1]); | ||||||
|  |     output_stream->make_unbuffered(); | ||||||
|  |     auto job = HTTP::HttpsJob::construct(request, *output_stream); | ||||||
|  |     auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream)); | ||||||
|  |     download->set_download_fd(fd_pair[0]); | ||||||
|     job->start(); |     job->start(); | ||||||
|     return download; |     return download; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ public: | ||||||
|     HttpsProtocol(); |     HttpsProtocol(); | ||||||
|     virtual ~HttpsProtocol() override; |     virtual ~HttpsProtocol() override; | ||||||
| 
 | 
 | ||||||
|     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; |     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -37,7 +37,7 @@ public: | ||||||
|     virtual ~Protocol(); |     virtual ~Protocol(); | ||||||
| 
 | 
 | ||||||
|     const String& name() const { return m_name; } |     const String& name() const { return m_name; } | ||||||
|     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) = 0; |     virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) = 0; | ||||||
| 
 | 
 | ||||||
|     static Protocol* find_by_name(const String&); |     static Protocol* find_by_name(const String&); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,7 +2,8 @@ endpoint ProtocolClient = 13 | ||||||
| { | { | ||||||
|     // Download notifications |     // Download notifications | ||||||
|     DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =| |     DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =| | ||||||
|     DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =| |     DownloadFinished(i32 download_id, bool success, u32 total_size) =| | ||||||
|  |     HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional<u32> status_code) =| | ||||||
| 
 | 
 | ||||||
|     // Certificate requests |     // Certificate requests | ||||||
|     CertificateRequested(i32 download_id) => () |     CertificateRequested(i32 download_id) => () | ||||||
|  |  | ||||||
|  | @ -3,14 +3,11 @@ endpoint ProtocolServer = 9 | ||||||
|     // Basic protocol |     // Basic protocol | ||||||
|     Greet() => (i32 client_id) |     Greet() => (i32 client_id) | ||||||
| 
 | 
 | ||||||
|     // FIXME: It would be nice if the kernel provided a way to avoid this |  | ||||||
|     DisownSharedBuffer(i32 shbuf_id) => () |  | ||||||
| 
 |  | ||||||
|     // Test if a specific protocol is supported, e.g "http" |     // Test if a specific protocol is supported, e.g "http" | ||||||
|     IsSupportedProtocol(String protocol) => (bool supported) |     IsSupportedProtocol(String protocol) => (bool supported) | ||||||
| 
 | 
 | ||||||
|     // Download API |     // Download API | ||||||
|     StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id) |     StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd) | ||||||
|     StopDownload(i32 download_id) => (bool success) |     StopDownload(i32 download_id) => (bool success) | ||||||
|     SetCertificate(i32 download_id, String certificate, String key) => (bool success) |     SetCertificate(i32 download_id, String certificate, String key) => (bool success) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ | ||||||
| 
 | 
 | ||||||
| int main(int, char**) | int main(int, char**) | ||||||
| { | { | ||||||
|     if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) { |     if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) { | ||||||
|         perror("pledge"); |         perror("pledge"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  | @ -45,7 +45,7 @@ int main(int, char**) | ||||||
| 
 | 
 | ||||||
|     Core::EventLoop event_loop; |     Core::EventLoop event_loop; | ||||||
|     // FIXME: Establish a connection to LookupServer and then drop "unix"?
 |     // FIXME: Establish a connection to LookupServer and then drop "unix"?
 | ||||||
|     if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) { |     if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) { | ||||||
|         perror("pledge"); |         perror("pledge"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ | ||||||
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
|  | #include <AK/FileStream.h> | ||||||
| #include <AK/GenericLexer.h> | #include <AK/GenericLexer.h> | ||||||
| #include <AK/LexicalPath.h> | #include <AK/LexicalPath.h> | ||||||
| #include <AK/NumberFormat.h> | #include <AK/NumberFormat.h> | ||||||
|  | @ -116,29 +117,50 @@ private: | ||||||
|     bool m_might_be_wrong { false }; |     bool m_might_be_wrong { false }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| static void do_write(ReadonlyBytes payload) | template<typename ConditionT> | ||||||
| { | class ConditionalOutputFileStream final : public OutputFileStream { | ||||||
|     size_t length_remaining = payload.size(); | public: | ||||||
|     size_t length_written = 0; |     template<typename... Args> | ||||||
|     while (length_remaining > 0) { |     ConditionalOutputFileStream(ConditionT&& condition, Args... args) | ||||||
|         auto nwritten = fwrite(payload.data() + length_written, sizeof(char), length_remaining, stdout); |         : OutputFileStream(args...) | ||||||
|         if (nwritten > 0) { |         , m_condition(condition) | ||||||
|             length_remaining -= nwritten; |     { | ||||||
|             length_written += nwritten; |     } | ||||||
|             continue; |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         if (feof(stdout)) { |     ~ConditionalOutputFileStream() | ||||||
|             fprintf(stderr, "pro: unexpected eof while writing\n"); |     { | ||||||
|  |         if (!m_condition()) | ||||||
|             return; |             return; | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         if (ferror(stdout)) { |         if (!m_buffer.is_empty()) { | ||||||
|             fprintf(stderr, "pro: error while writing\n"); |             OutputFileStream::write(m_buffer); | ||||||
|             return; |             m_buffer.clear(); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | 
 | ||||||
|  | private: | ||||||
|  |     size_t write(ReadonlyBytes bytes) override | ||||||
|  |     { | ||||||
|  |         if (!m_condition()) { | ||||||
|  |         write_to_buffer:; | ||||||
|  |             m_buffer.append(bytes.data(), bytes.size()); | ||||||
|  |             return bytes.size(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (!m_buffer.is_empty()) { | ||||||
|  |             auto size = OutputFileStream::write(m_buffer); | ||||||
|  |             m_buffer = m_buffer.slice(size, m_buffer.size() - size); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (!m_buffer.is_empty()) | ||||||
|  |             goto write_to_buffer; | ||||||
|  | 
 | ||||||
|  |         return OutputFileStream::write(bytes); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ConditionT m_condition; | ||||||
|  |     ByteBuffer m_buffer; | ||||||
|  | }; | ||||||
| 
 | 
 | ||||||
| int main(int argc, char** argv) | int main(int argc, char** argv) | ||||||
| { | { | ||||||
|  | @ -195,6 +217,8 @@ int main(int argc, char** argv) | ||||||
|     timeval prev_time, current_time, time_diff; |     timeval prev_time, current_time, time_diff; | ||||||
|     gettimeofday(&prev_time, nullptr); |     gettimeofday(&prev_time, nullptr); | ||||||
| 
 | 
 | ||||||
|  |     bool received_actual_headers = false; | ||||||
|  | 
 | ||||||
|     download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) { |     download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) { | ||||||
|         fprintf(stderr, "\r\033[2K"); |         fprintf(stderr, "\r\033[2K"); | ||||||
|         if (maybe_total_size.has_value()) { |         if (maybe_total_size.has_value()) { | ||||||
|  | @ -215,10 +239,13 @@ int main(int argc, char** argv) | ||||||
|         previous_downloaded_size = downloaded_size; |         previous_downloaded_size = downloaded_size; | ||||||
|         prev_time = current_time; |         prev_time = current_time; | ||||||
|     }; |     }; | ||||||
|     download->on_finish = [&](bool success, auto payload, auto, auto& response_headers, auto) { | 
 | ||||||
|         fprintf(stderr, "\033]9;-1;\033\\"); |     if (save_at_provided_name) { | ||||||
|         fprintf(stderr, "\n"); |         download->on_headers_received = [&](auto& response_headers, auto status_code) { | ||||||
|         if (success && save_at_provided_name) { |             if (received_actual_headers) | ||||||
|  |                 return; | ||||||
|  |             dbg() << "Received headers! response code = " << status_code.value_or(0); | ||||||
|  |             received_actual_headers = true; // And not trailers!
 | ||||||
|             String output_name; |             String output_name; | ||||||
|             if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) { |             if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) { | ||||||
|                 auto& value = content_disposition.value(); |                 auto& value = content_disposition.value(); | ||||||
|  | @ -245,17 +272,26 @@ int main(int argc, char** argv) | ||||||
| 
 | 
 | ||||||
|             if (freopen(output_name.characters(), "w", stdout) == nullptr) { |             if (freopen(output_name.characters(), "w", stdout) == nullptr) { | ||||||
|                 perror("freopen"); |                 perror("freopen"); | ||||||
|                 success = false; // oops!
 |  | ||||||
|                 loop.quit(1); |                 loop.quit(1); | ||||||
|  |                 return; | ||||||
|             } |             } | ||||||
|         } |         }; | ||||||
|         if (success) |     } | ||||||
|             do_write(payload); |     download->on_finish = [&](bool success, auto) { | ||||||
|         else |         fprintf(stderr, "\033]9;-1;\033\\"); | ||||||
|  |         fprintf(stderr, "\n"); | ||||||
|  |         if (!success) | ||||||
|             fprintf(stderr, "Download failed :(\n"); |             fprintf(stderr, "Download failed :(\n"); | ||||||
|         loop.quit(0); |         loop.quit(0); | ||||||
|     }; |     }; | ||||||
|  | 
 | ||||||
|  |     auto output_stream = ConditionalOutputFileStream { [&] { return save_at_provided_name ? received_actual_headers : true; }, stdout }; | ||||||
|  |     download->stream_into(output_stream); | ||||||
|  | 
 | ||||||
|     dbgprintf("started download with id %d\n", download->id()); |     dbgprintf("started download with id %d\n", download->id()); | ||||||
| 
 | 
 | ||||||
|     return loop.exec(); |     auto rc = loop.exec(); | ||||||
|  |     // FIXME: This shouldn't be needed.
 | ||||||
|  |     fclose(stdout); | ||||||
|  |     return rc; | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 AnotherTest
						AnotherTest