1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-31 09:38:11 +00:00

ProtocolServer: Stream the downloaded data if possible

This patchset makes ProtocolServer stream the downloads to its client
(LibProtocol), and as such changes the download API; a possible
download lifecycle could be as such:
notation = client->server:'>', server->client:'<', pipe activity:'*'
```
> StartDownload(GET, url, headers, {})
< Response(0, fd 8)
* {data, 1024b}
< HeadersBecameAvailable(0, response_headers, 200)
< DownloadProgress(0, 4K, 1024)
* {data, 1024b}
* {data, 1024b}
< DownloadProgress(0, 4K, 2048)
* {data, 1024b}
< DownloadProgress(0, 4K, 1024)
< DownloadFinished(0, true, 4K)
```

Since managing the received file descriptor is a pain, LibProtocol
implements `Download::stream_into(OutputStream)`, which can be used to
stream the download into any given output stream (be it a file, or
memory, or writing stuff with a delay, etc.).
Also, as some of the users of this API require all the downloaded data
upfront, LibProtocol also implements `set_should_buffer_all_input()`,
which causes the download instance to buffer all the data until the
download is complete, and to call the `on_buffered_download_finish`
hook.
This commit is contained in:
AnotherTest 2020-12-26 17:14:12 +03:30 committed by Andreas Kling
parent 36d642ee75
commit 4a2da10e38
55 changed files with 528 additions and 235 deletions

View file

@ -29,6 +29,7 @@
#include <AK/SharedBuffer.h> #include <AK/SharedBuffer.h>
#include <AK/StringBuilder.h> #include <AK/StringBuilder.h>
#include <LibCore/File.h> #include <LibCore/File.h>
#include <LibCore/FileStream.h>
#include <LibCore/StandardPaths.h> #include <LibCore/StandardPaths.h>
#include <LibDesktop/Launcher.h> #include <LibDesktop/Launcher.h>
#include <LibGUI/BoxLayout.h> #include <LibGUI/BoxLayout.h>
@ -61,9 +62,19 @@ DownloadWidget::DownloadWidget(const URL& url)
m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) { m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) {
did_progress(total_size.value(), downloaded_size); did_progress(total_size.value(), downloaded_size);
}; };
m_download->on_finish = [this](bool success, auto payload, auto payload_storage, auto& response_headers, auto) {
did_finish(success, payload, payload_storage, response_headers); {
}; auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
if (file_or_error.is_error()) {
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
window()->close();
return;
}
m_output_file_stream = make<Core::OutputFileStream>(*file_or_error.value());
}
m_download->on_finish = [this](bool success, auto) { did_finish(success); };
m_download->stream_into(*m_output_file_stream);
set_fill_with_background_color(true); set_fill_with_background_color(true);
auto& layout = set_layout<GUI::VerticalBoxLayout>(); auto& layout = set_layout<GUI::VerticalBoxLayout>();
@ -149,7 +160,7 @@ void DownloadWidget::did_progress(Optional<u32> total_size, u32 downloaded_size)
} }
} }
void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes payload, [[maybe_unused]] RefPtr<SharedBuffer> payload_storage, [[maybe_unused]] const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) void DownloadWidget::did_finish(bool success)
{ {
dbg() << "did_finish, success=" << success; dbg() << "did_finish, success=" << success;
@ -166,17 +177,6 @@ void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes pay
window()->close(); window()->close();
return; return;
} }
auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
if (file_or_error.is_error()) {
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
window()->close();
return;
}
auto& file = *file_or_error.value();
bool write_success = file.write(payload.data(), payload.size());
ASSERT(write_success);
} }
} }

View file

@ -28,6 +28,7 @@
#include <AK/URL.h> #include <AK/URL.h>
#include <LibCore/ElapsedTimer.h> #include <LibCore/ElapsedTimer.h>
#include <LibCore/FileStream.h>
#include <LibGUI/ProgressBar.h> #include <LibGUI/ProgressBar.h>
#include <LibGUI/Widget.h> #include <LibGUI/Widget.h>
#include <LibProtocol/Download.h> #include <LibProtocol/Download.h>
@ -44,7 +45,7 @@ private:
explicit DownloadWidget(const URL&); explicit DownloadWidget(const URL&);
void did_progress(Optional<u32> total_size, u32 downloaded_size); void did_progress(Optional<u32> total_size, u32 downloaded_size);
void did_finish(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers); void did_finish(bool success);
URL m_url; URL m_url;
String m_destination_path; String m_destination_path;
@ -53,6 +54,7 @@ private:
RefPtr<GUI::Label> m_progress_label; RefPtr<GUI::Label> m_progress_label;
RefPtr<GUI::Button> m_cancel_button; RefPtr<GUI::Button> m_cancel_button;
RefPtr<GUI::Button> m_close_button; RefPtr<GUI::Button> m_close_button;
OwnPtr<Core::OutputFileStream> m_output_file_stream;
Core::ElapsedTimer m_elapsed_timer; Core::ElapsedTimer m_elapsed_timer;
}; };

View file

@ -68,7 +68,7 @@ int main(int argc, char** argv)
return 1; return 1;
} }
if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr", nullptr) < 0) { if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr sendfd recvfd", nullptr) < 0) {
perror("pledge"); perror("pledge");
return 1; return 1;
} }
@ -86,7 +86,7 @@ int main(int argc, char** argv)
Web::ResourceLoader::the(); Web::ResourceLoader::the();
// FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge. // FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge.
if (pledge("stdio shared_buffer accept unix cpath rpath wpath", nullptr) < 0) { if (pledge("stdio shared_buffer accept unix cpath rpath wpath sendfd recvfd", nullptr) < 0) {
perror("pledge"); perror("pledge");
return 1; return 1;
} }

View file

@ -32,7 +32,8 @@
namespace Core { namespace Core {
NetworkJob::NetworkJob() NetworkJob::NetworkJob(OutputStream& output_stream)
: m_output_stream(output_stream)
{ {
} }

View file

@ -27,6 +27,7 @@
#pragma once #pragma once
#include <AK/Function.h> #include <AK/Function.h>
#include <AK/Stream.h>
#include <LibCore/Object.h> #include <LibCore/Object.h>
namespace Core { namespace Core {
@ -43,6 +44,8 @@ public:
}; };
virtual ~NetworkJob() override; virtual ~NetworkJob() override;
// Could fire twice, after Headers and after Trailers!
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
Function<void(bool success)> on_finish; Function<void(bool success)> on_finish;
Function<void(Optional<u32>, u32)> on_progress; Function<void(Optional<u32>, u32)> on_progress;
@ -62,13 +65,16 @@ public:
} }
protected: protected:
NetworkJob(); NetworkJob(OutputStream&);
void did_finish(NonnullRefPtr<NetworkResponse>&&); void did_finish(NonnullRefPtr<NetworkResponse>&&);
void did_fail(Error); void did_fail(Error);
void did_progress(Optional<u32> total_size, u32 downloaded); void did_progress(Optional<u32> total_size, u32 downloaded);
size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); }
private: private:
RefPtr<NetworkResponse> m_response; RefPtr<NetworkResponse> m_response;
OutputStream& m_output_stream;
Error m_error { Error::None }; Error m_error { Error::None };
}; };

View file

@ -28,8 +28,7 @@
namespace Core { namespace Core {
NetworkResponse::NetworkResponse(ByteBuffer&& payload) NetworkResponse::NetworkResponse()
: m_payload(payload)
{ {
} }

View file

@ -36,13 +36,11 @@ public:
virtual ~NetworkResponse(); virtual ~NetworkResponse();
bool is_error() const { return m_error; } bool is_error() const { return m_error; }
const ByteBuffer& payload() const { return m_payload; }
protected: protected:
explicit NetworkResponse(ByteBuffer&&); explicit NetworkResponse();
bool m_error { false }; bool m_error { false };
ByteBuffer m_payload;
}; };
} }

View file

@ -142,9 +142,9 @@ bool GeminiJob::eof() const
return m_socket->eof(); return m_socket->eof();
} }
bool GeminiJob::write(const ByteBuffer& data) bool GeminiJob::write(ReadonlyBytes bytes)
{ {
return m_socket->write(data); return m_socket->write(bytes);
} }
} }

View file

@ -37,8 +37,8 @@ namespace Gemini {
class GeminiJob final : public Job { class GeminiJob final : public Job {
C_OBJECT(GeminiJob) C_OBJECT(GeminiJob)
public: public:
explicit GeminiJob(const GeminiRequest& request, const Vector<Certificate>* override_certificates = nullptr) explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certificates = nullptr)
: Job(request) : Job(request, output_stream)
, m_override_ca_certificates(override_certificates) , m_override_ca_certificates(override_certificates)
{ {
} }
@ -61,7 +61,7 @@ protected:
virtual bool can_read() const override; virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override; virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override; virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override; virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return m_socket->is_established(); } virtual bool is_established() const override { return m_socket->is_established(); }
virtual bool should_fail_on_empty_payload() const override { return false; } virtual bool should_fail_on_empty_payload() const override { return false; }
virtual void read_while_data_available(Function<IterationDecision()>) override; virtual void read_while_data_available(Function<IterationDecision()>) override;

View file

@ -28,9 +28,8 @@
namespace Gemini { namespace Gemini {
GeminiResponse::GeminiResponse(int status, String meta, ByteBuffer&& payload) GeminiResponse::GeminiResponse(int status, String meta)
: Core::NetworkResponse(move(payload)) : m_status(status)
, m_status(status)
, m_meta(meta) , m_meta(meta)
{ {
} }

View file

@ -34,16 +34,16 @@ namespace Gemini {
class GeminiResponse : public Core::NetworkResponse { class GeminiResponse : public Core::NetworkResponse {
public: public:
virtual ~GeminiResponse() override; virtual ~GeminiResponse() override;
static NonnullRefPtr<GeminiResponse> create(int status, String meta, ByteBuffer&& payload) static NonnullRefPtr<GeminiResponse> create(int status, String meta)
{ {
return adopt(*new GeminiResponse(status, meta, move(payload))); return adopt(*new GeminiResponse(status, meta));
} }
int status() const { return m_status; } int status() const { return m_status; }
String meta() const { return m_meta; } String meta() const { return m_meta; }
private: private:
GeminiResponse(int status, String, ByteBuffer&&); GeminiResponse(int status, String);
int m_status { 0 }; int m_status { 0 };
String m_meta; String m_meta;

View file

@ -33,8 +33,9 @@
namespace Gemini { namespace Gemini {
Job::Job(const GeminiRequest& request) Job::Job(const GeminiRequest& request, OutputStream& output_stream)
: m_request(request) : Core::NetworkJob(output_stream)
, m_request(request)
{ {
} }
@ -42,6 +43,23 @@ Job::~Job()
{ {
} }
void Job::flush_received_buffers()
{
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
auto& payload = m_received_buffers[i];
auto written = do_write(payload);
m_received_size -= written;
if (written == payload.size()) {
// FIXME: Make this a take-first-friendly object?
m_received_buffers.take_first();
continue;
}
ASSERT(written < payload.size());
payload = payload.slice(written, payload.size() - written);
return;
}
}
void Job::on_socket_connected() void Job::on_socket_connected()
{ {
register_on_ready_to_write([this] { register_on_ready_to_write([this] {
@ -126,6 +144,7 @@ void Job::on_socket_connected()
m_received_buffers.append(payload); m_received_buffers.append(payload);
m_received_size += payload.size(); m_received_size += payload.size();
flush_received_buffers();
deferred_invoke([this](auto&) { did_progress({}, m_received_size); }); deferred_invoke([this](auto&) { did_progress({}, m_received_size); });
@ -144,15 +163,17 @@ void Job::on_socket_connected()
void Job::finish_up() void Job::finish_up()
{ {
m_state = State::Finished; m_state = State::Finished;
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); flush_received_buffers();
u8* flat_ptr = flattened_buffer.data(); if (m_received_size != 0) {
for (auto& received_buffer : m_received_buffers) { // FIXME: What do we do? ignore it?
memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); // "Transmission failed" is not strictly correct, but let's roll with it for now.
flat_ptr += received_buffer.size(); deferred_invoke([this](auto&) {
did_fail(Error::TransmissionFailed);
});
return;
} }
m_received_buffers.clear();
auto response = GeminiResponse::create(m_status, m_meta, move(flattened_buffer)); auto response = GeminiResponse::create(m_status, m_meta);
deferred_invoke([this, response](auto&) { deferred_invoke([this, response](auto&) {
did_finish(move(response)); did_finish(move(response));
}); });

View file

@ -36,7 +36,7 @@ namespace Gemini {
class Job : public Core::NetworkJob { class Job : public Core::NetworkJob {
public: public:
explicit Job(const GeminiRequest&); explicit Job(const GeminiRequest&, OutputStream&);
virtual ~Job() override; virtual ~Job() override;
virtual void start() override = 0; virtual void start() override = 0;
@ -48,6 +48,7 @@ public:
protected: protected:
void finish_up(); void finish_up();
void on_socket_connected(); void on_socket_connected();
void flush_received_buffers();
virtual void register_on_ready_to_read(Function<void()>) = 0; virtual void register_on_ready_to_read(Function<void()>) = 0;
virtual void register_on_ready_to_write(Function<void()>) = 0; virtual void register_on_ready_to_write(Function<void()>) = 0;
virtual bool can_read_line() const = 0; virtual bool can_read_line() const = 0;
@ -55,7 +56,7 @@ protected:
virtual bool can_read() const = 0; virtual bool can_read() const = 0;
virtual ByteBuffer receive(size_t) = 0; virtual ByteBuffer receive(size_t) = 0;
virtual bool eof() const = 0; virtual bool eof() const = 0;
virtual bool write(const ByteBuffer&) = 0; virtual bool write(ReadonlyBytes) = 0;
virtual bool is_established() const = 0; virtual bool is_established() const = 0;
virtual bool should_fail_on_empty_payload() const { return false; } virtual bool should_fail_on_empty_payload() const { return false; }
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
@ -70,7 +71,7 @@ protected:
State m_state { State::InStatus }; State m_state { State::InStatus };
int m_status { -1 }; int m_status { -1 };
String m_meta; String m_meta;
Vector<ByteBuffer> m_received_buffers; Vector<ByteBuffer, 2> m_received_buffers;
size_t m_received_size { 0 }; size_t m_received_size { 0 };
bool m_sent_data { false }; bool m_sent_data { false };
bool m_should_have_payload { false }; bool m_should_have_payload { false };

View file

@ -98,9 +98,9 @@ bool HttpJob::eof() const
return m_socket->eof(); return m_socket->eof();
} }
bool HttpJob::write(const ByteBuffer& data) bool HttpJob::write(ReadonlyBytes bytes)
{ {
return m_socket->write(data); return m_socket->write(bytes);
} }
} }

View file

@ -38,8 +38,8 @@ namespace HTTP {
class HttpJob final : public Job { class HttpJob final : public Job {
C_OBJECT(HttpJob) C_OBJECT(HttpJob)
public: public:
explicit HttpJob(const HttpRequest& request) explicit HttpJob(const HttpRequest& request, OutputStream& output_stream)
: Job(request) : Job(request, output_stream)
{ {
} }
@ -59,7 +59,7 @@ protected:
virtual bool can_read() const override; virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override; virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override; virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override; virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return true; } virtual bool is_established() const override { return true; }
private: private:

View file

@ -71,11 +71,12 @@ ByteBuffer HttpRequest::to_raw_request() const
builder.append(header.value); builder.append(header.value);
builder.append("\r\n"); builder.append("\r\n");
} }
builder.append("Connection: close\r\n\r\n"); builder.append("Connection: close\r\n");
if (!m_body.is_empty()) { if (!m_body.is_empty()) {
builder.appendff("Content-Length: {}\r\n\r\n", m_body.size());
builder.append((const char*)m_body.data(), m_body.size()); builder.append((const char*)m_body.data(), m_body.size());
builder.append("\r\n");
} }
builder.append("\r\n");
return builder.to_byte_buffer(); return builder.to_byte_buffer();
} }

View file

@ -62,7 +62,8 @@ public:
void set_method(Method method) { m_method = method; } void set_method(Method method) { m_method = method; }
const ByteBuffer& body() const { return m_body; } const ByteBuffer& body() const { return m_body; }
void set_body(const ByteBuffer& body) { m_body = body; } void set_body(ReadonlyBytes body) { m_body = ByteBuffer::copy(body); }
void set_body(ByteBuffer&& body) { m_body = move(body); }
String method_name() const; String method_name() const;
ByteBuffer to_raw_request() const; ByteBuffer to_raw_request() const;

View file

@ -28,9 +28,8 @@
namespace HTTP { namespace HTTP {
HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload) HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
: Core::NetworkResponse(move(payload)) : m_code(code)
, m_code(code)
, m_headers(move(headers)) , m_headers(move(headers))
{ {
} }

View file

@ -35,16 +35,16 @@ namespace HTTP {
class HttpResponse : public Core::NetworkResponse { class HttpResponse : public Core::NetworkResponse {
public: public:
virtual ~HttpResponse() override; virtual ~HttpResponse() override;
static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload) static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
{ {
return adopt(*new HttpResponse(code, move(headers), move(payload))); return adopt(*new HttpResponse(code, move(headers)));
} }
int code() const { return m_code; } int code() const { return m_code; }
const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; } const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; }
private: private:
HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&, ByteBuffer&&); HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&);
int m_code { 0 }; int m_code { 0 };
HashMap<String, String, CaseInsensitiveStringTraits> m_headers; HashMap<String, String, CaseInsensitiveStringTraits> m_headers;

View file

@ -143,7 +143,7 @@ bool HttpsJob::eof() const
return m_socket->eof(); return m_socket->eof();
} }
bool HttpsJob::write(const ByteBuffer& data) bool HttpsJob::write(ReadonlyBytes data)
{ {
return m_socket->write(data); return m_socket->write(data);
} }

View file

@ -38,8 +38,8 @@ namespace HTTP {
class HttpsJob final : public Job { class HttpsJob final : public Job {
C_OBJECT(HttpsJob) C_OBJECT(HttpsJob)
public: public:
explicit HttpsJob(const HttpRequest& request, const Vector<Certificate>* override_certs = nullptr) explicit HttpsJob(const HttpRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certs = nullptr)
: Job(request) : Job(request, output_stream)
, m_override_ca_certificates(override_certs) , m_override_ca_certificates(override_certs)
{ {
} }
@ -62,7 +62,7 @@ protected:
virtual bool can_read() const override; virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override; virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override; virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override; virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return m_socket->is_established(); } virtual bool is_established() const override { return m_socket->is_established(); }
virtual bool should_fail_on_empty_payload() const override { return false; } virtual bool should_fail_on_empty_payload() const override { return false; }
virtual void read_while_data_available(Function<IterationDecision()>) override; virtual void read_while_data_available(Function<IterationDecision()>) override;

View file

@ -68,8 +68,9 @@ static ByteBuffer handle_content_encoding(const ByteBuffer& buf, const String& c
return buf; return buf;
} }
Job::Job(const HttpRequest& request) Job::Job(const HttpRequest& request, OutputStream& output_stream)
: m_request(request) : Core::NetworkJob(output_stream)
, m_request(request)
{ {
} }
@ -77,6 +78,35 @@ Job::~Job()
{ {
} }
void Job::flush_received_buffers()
{
if (!m_can_stream_response || m_buffered_size == 0)
return;
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
auto& payload = m_received_buffers[i];
auto written = do_write(payload);
m_buffered_size -= written;
if (written == payload.size()) {
// FIXME: Make this a take-first-friendly object?
m_received_buffers.take_first();
--i;
continue;
}
ASSERT(written < payload.size());
payload = payload.slice(written, payload.size() - written);
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
return;
}
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
}
void Job::on_socket_connected() void Job::on_socket_connected()
{ {
register_on_ready_to_write([&] { register_on_ready_to_write([&] {
@ -135,6 +165,8 @@ void Job::on_socket_connected()
if (m_state == State::Trailers) { if (m_state == State::Trailers) {
return finish_up(); return finish_up();
} else { } else {
if (on_headers_received)
on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {});
m_state = State::InBody; m_state = State::InBody;
} }
return; return;
@ -163,6 +195,13 @@ void Job::on_socket_connected()
} }
auto value = line.substring(name.length() + 2, line.length() - name.length() - 2); auto value = line.substring(name.length() + 2, line.length() - name.length() - 2);
m_headers.set(name, value); m_headers.set(name, value);
if (name.equals_ignoring_case("Content-Encoding")) {
// Assume that any content-encoding means that we can't decode it as a stream :(
#ifdef JOB_DEBUG
dbg() << "Content-Encoding " << value << " detected, cannot stream output :(";
#endif
m_can_stream_response = false;
}
#ifdef JOB_DEBUG #ifdef JOB_DEBUG
dbg() << "Job: [" << name << "] = '" << value << "'"; dbg() << "Job: [" << name << "] = '" << value << "'";
#endif #endif
@ -252,7 +291,9 @@ void Job::on_socket_connected()
} }
m_received_buffers.append(payload); m_received_buffers.append(payload);
m_buffered_size += payload.size();
m_received_size += payload.size(); m_received_size += payload.size();
flush_received_buffers();
if (m_current_chunk_remaining_size.has_value()) { if (m_current_chunk_remaining_size.has_value()) {
auto size = m_current_chunk_remaining_size.value() - payload.size(); auto size = m_current_chunk_remaining_size.value() - payload.size();
@ -313,20 +354,37 @@ void Job::on_socket_connected()
void Job::finish_up() void Job::finish_up()
{ {
m_state = State::Finished; m_state = State::Finished;
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); if (!m_can_stream_response) {
u8* flat_ptr = flattened_buffer.data(); auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
for (auto& received_buffer : m_received_buffers) { u8* flat_ptr = flattened_buffer.data();
memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); for (auto& received_buffer : m_received_buffers) {
flat_ptr += received_buffer.size(); memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
} flat_ptr += received_buffer.size();
m_received_buffers.clear(); }
m_received_buffers.clear();
auto content_encoding = m_headers.get("Content-Encoding"); // For the time being, we cannot stream stuff with content-encoding set to _anything_.
if (content_encoding.has_value()) { auto content_encoding = m_headers.get("Content-Encoding");
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value()); if (content_encoding.has_value()) {
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value());
}
m_buffered_size = flattened_buffer.size();
m_received_buffers.append(move(flattened_buffer));
m_can_stream_response = true;
} }
auto response = HttpResponse::create(m_code, move(m_headers), move(flattened_buffer)); flush_received_buffers();
if (m_buffered_size != 0) {
// FIXME: What do we do? ignore it?
// "Transmission failed" is not strictly correct, but let's roll with it for now.
deferred_invoke([this](auto&) {
did_fail(Error::TransmissionFailed);
});
return;
}
auto response = HttpResponse::create(m_code, move(m_headers));
deferred_invoke([this, response](auto&) { deferred_invoke([this, response](auto&) {
did_finish(move(response)); did_finish(move(response));
}); });

View file

@ -26,6 +26,7 @@
#pragma once #pragma once
#include <AK/FileStream.h>
#include <AK/HashMap.h> #include <AK/HashMap.h>
#include <AK/Optional.h> #include <AK/Optional.h>
#include <LibCore/NetworkJob.h> #include <LibCore/NetworkJob.h>
@ -37,7 +38,7 @@ namespace HTTP {
class Job : public Core::NetworkJob { class Job : public Core::NetworkJob {
public: public:
explicit Job(const HttpRequest&); explicit Job(const HttpRequest&, OutputStream&);
virtual ~Job() override; virtual ~Job() override;
virtual void start() override = 0; virtual void start() override = 0;
@ -49,6 +50,7 @@ public:
protected: protected:
void finish_up(); void finish_up();
void on_socket_connected(); void on_socket_connected();
void flush_received_buffers();
virtual void register_on_ready_to_read(Function<void()>) = 0; virtual void register_on_ready_to_read(Function<void()>) = 0;
virtual void register_on_ready_to_write(Function<void()>) = 0; virtual void register_on_ready_to_write(Function<void()>) = 0;
virtual bool can_read_line() const = 0; virtual bool can_read_line() const = 0;
@ -56,7 +58,7 @@ protected:
virtual bool can_read() const = 0; virtual bool can_read() const = 0;
virtual ByteBuffer receive(size_t) = 0; virtual ByteBuffer receive(size_t) = 0;
virtual bool eof() const = 0; virtual bool eof() const = 0;
virtual bool write(const ByteBuffer&) = 0; virtual bool write(ReadonlyBytes) = 0;
virtual bool is_established() const = 0; virtual bool is_established() const = 0;
virtual bool should_fail_on_empty_payload() const { return true; } virtual bool should_fail_on_empty_payload() const { return true; }
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
@ -73,11 +75,13 @@ protected:
State m_state { State::InStatus }; State m_state { State::InStatus };
int m_code { -1 }; int m_code { -1 };
HashMap<String, String, CaseInsensitiveStringTraits> m_headers; HashMap<String, String, CaseInsensitiveStringTraits> m_headers;
Vector<ByteBuffer> m_received_buffers; Vector<ByteBuffer, 2> m_received_buffers;
size_t m_buffered_size { 0 };
size_t m_received_size { 0 }; size_t m_received_size { 0 };
bool m_sent_data { 0 }; bool m_sent_data { 0 };
Optional<ssize_t> m_current_chunk_remaining_size; Optional<ssize_t> m_current_chunk_remaining_size;
Optional<size_t> m_current_chunk_total_size; Optional<size_t> m_current_chunk_total_size;
bool m_can_stream_response { true };
}; };
} }

View file

@ -24,6 +24,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/ */
#include <AK/FileStream.h>
#include <AK/SharedBuffer.h> #include <AK/SharedBuffer.h>
#include <LibProtocol/Client.h> #include <LibProtocol/Client.h>
#include <LibProtocol/Download.h> #include <LibProtocol/Download.h>
@ -47,16 +48,20 @@ bool Client::is_supported_protocol(const String& protocol)
return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported(); return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported();
} }
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, const ByteBuffer& request_body) template<typename RequestHashMapTraits>
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers, ReadonlyBytes request_body)
{ {
IPC::Dictionary header_dictionary; IPC::Dictionary header_dictionary;
for (auto& it : request_headers) for (auto& it : request_headers)
header_dictionary.add(it.key, it.value); header_dictionary.add(it.key, it.value);
i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, String::copy(request_body))->download_id(); auto response = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, ByteBuffer::copy(request_body));
if (download_id < 0) auto download_id = response->download_id();
auto response_fd = response->response_fd().fd();
if (download_id < 0 || response_fd < 0)
return nullptr; return nullptr;
auto download = Download::create_from_id({}, *this, download_id); auto download = Download::create_from_id({}, *this, download_id);
download->set_download_fd({}, response_fd);
m_downloads.set(download_id, download); m_downloads.set(download_id, download);
return download; return download;
} }
@ -79,9 +84,8 @@ void Client::handle(const Messages::ProtocolClient::DownloadFinished& message)
{ {
RefPtr<Download> download; RefPtr<Download> download;
if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) { if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) {
download->did_finish({}, message.success(), message.status_code(), message.total_size(), message.shbuf_id(), message.response_headers()); download->did_finish({}, message.success(), message.total_size());
} }
send_sync<Messages::ProtocolServer::DisownSharedBuffer>(message.shbuf_id());
m_downloads.remove(message.download_id()); m_downloads.remove(message.download_id());
} }
@ -92,6 +96,15 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message)
} }
} }
void Client::handle(const Messages::ProtocolClient::HeadersBecameAvailable& message)
{
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
HashMap<String, String, CaseInsensitiveStringTraits> headers;
message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); });
download->did_receive_headers({}, headers, message.status_code());
}
}
OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message) OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message)
{ {
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) { if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
@ -102,3 +115,6 @@ OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(co
} }
} }
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, ReadonlyBytes request_body);
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String, CaseInsensitiveStringTraits>& request_headers, ReadonlyBytes request_body);

View file

@ -44,7 +44,8 @@ public:
virtual void handshake() override; virtual void handshake() override;
bool is_supported_protocol(const String&); bool is_supported_protocol(const String&);
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String>& request_headers = {}, const ByteBuffer& request_body = {}); template<typename RequestHashMapTraits = Traits<String>>
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers = {}, ReadonlyBytes request_body = {});
bool stop_download(Badge<Download>, Download&); bool stop_download(Badge<Download>, Download&);
bool set_certificate(Badge<Download>, Download&, String, String); bool set_certificate(Badge<Download>, Download&, String, String);
@ -55,6 +56,7 @@ private:
virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override; virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override;
virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override; virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override;
virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override; virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override;
virtual void handle(const Messages::ProtocolClient::HeadersBecameAvailable&) override;
HashMap<i32, RefPtr<Download>> m_downloads; HashMap<i32, RefPtr<Download>> m_downloads;
}; };

View file

@ -41,25 +41,81 @@ bool Download::stop()
return m_client->stop_download({}, *this); return m_client->stop_download({}, *this);
} }
void Download::did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers) void Download::stream_into(OutputStream& stream)
{
ASSERT(!m_internal_stream_data);
auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read);
m_internal_stream_data = make<InternalStreamData>(fd());
m_internal_stream_data->read_notifier = notifier;
auto user_on_finish = move(on_finish);
on_finish = [this](auto success, auto total_size) {
m_internal_stream_data->success = success;
m_internal_stream_data->total_size = total_size;
m_internal_stream_data->download_done = true;
};
notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] {
constexpr size_t buffer_size = 1 * KiB;
static char buf[buffer_size];
auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size });
if (!stream.write_or_error({ buf, nread })) {
// FIXME: What do we do here?
TODO();
}
if (m_internal_stream_data->read_stream.eof() || (m_internal_stream_data->download_done && !m_internal_stream_data->success)) {
m_internal_stream_data->read_notifier->close();
user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size);
} else {
m_internal_stream_data->read_stream.handle_any_error();
}
};
}
void Download::set_should_buffer_all_input(bool value)
{
if (m_should_buffer_all_input == value)
return;
if (m_internal_buffered_data && !value) {
m_internal_buffered_data = nullptr;
m_should_buffer_all_input = false;
return;
}
ASSERT(!m_internal_stream_data);
ASSERT(!m_internal_buffered_data);
ASSERT(on_buffered_download_finish); // Not having this set makes no sense.
m_internal_buffered_data = make<InternalBufferedData>(fd());
m_should_buffer_all_input = true;
on_headers_received = [this](auto& headers, auto response_code) {
m_internal_buffered_data->response_headers = headers;
m_internal_buffered_data->response_code = move(response_code);
};
on_finish = [this](auto success, u32 total_size) {
auto output_buffer = m_internal_buffered_data->payload_stream.copy_into_contiguous_buffer();
on_buffered_download_finish(
success,
total_size,
m_internal_buffered_data->response_headers,
m_internal_buffered_data->response_code,
output_buffer);
};
stream_into(m_internal_buffered_data->payload_stream);
}
void Download::did_finish(Badge<Client>, bool success, u32 total_size)
{ {
if (!on_finish) if (!on_finish)
return; return;
ReadonlyBytes payload; on_finish(success, total_size);
RefPtr<SharedBuffer> shared_buffer;
if (success && shbuf_id != -1) {
shared_buffer = SharedBuffer::create_from_shbuf_id(shbuf_id);
payload = { shared_buffer->data<void>(), total_size };
}
// FIXME: It's a bit silly that we copy the response headers here just so we can move them into a HashMap with different traits.
HashMap<String, String, CaseInsensitiveStringTraits> caseless_response_headers;
response_headers.for_each_entry([&](auto& name, auto& value) {
caseless_response_headers.set(name, value);
});
on_finish(success, payload, move(shared_buffer), caseless_response_headers, status_code);
} }
void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size) void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size)
@ -68,6 +124,12 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa
on_progress(total_size, downloaded_size); on_progress(total_size, downloaded_size);
} }
void Download::did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)
{
if (on_headers_received)
on_headers_received(response_headers, response_code);
}
void Download::did_request_certificates(Badge<Client>) void Download::did_request_certificates(Badge<Client>)
{ {
if (on_certificate_requested) { if (on_certificate_requested) {

View file

@ -28,10 +28,13 @@
#include <AK/Badge.h> #include <AK/Badge.h>
#include <AK/ByteBuffer.h> #include <AK/ByteBuffer.h>
#include <AK/FileStream.h>
#include <AK/Function.h> #include <AK/Function.h>
#include <AK/MemoryStream.h>
#include <AK/RefCounted.h> #include <AK/RefCounted.h>
#include <AK/String.h> #include <AK/String.h>
#include <AK/WeakPtr.h> #include <AK/WeakPtr.h>
#include <LibCore/Notifier.h>
#include <LibIPC/Forward.h> #include <LibIPC/Forward.h>
namespace Protocol { namespace Protocol {
@ -51,20 +54,65 @@ public:
} }
int id() const { return m_download_id; } int id() const { return m_download_id; }
int fd() const { return m_fd; }
bool stop(); bool stop();
Function<void(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish; void stream_into(OutputStream&);
bool should_buffer_all_input() const { return m_should_buffer_all_input; }
/// Note: Will override `on_finish', and `on_headers_received', and expects `on_buffered_download_finish' to be set!
void set_should_buffer_all_input(bool);
/// Note: Must be set before `set_should_buffer_all_input(true)`.
Function<void(bool success, u32 total_size, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code, ReadonlyBytes payload)> on_buffered_download_finish;
Function<void(bool success, u32 total_size)> on_finish;
Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress; Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
Function<CertificateAndKey()> on_certificate_requested; Function<CertificateAndKey()> on_certificate_requested;
void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers); void did_finish(Badge<Client>, bool success, u32 total_size);
void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size); void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
void did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code);
void did_request_certificates(Badge<Client>); void did_request_certificates(Badge<Client>);
RefPtr<Core::Notifier>& write_notifier(Badge<Client>) { return m_write_notifier; }
void set_download_fd(Badge<Client>, int fd) { m_fd = fd; }
private: private:
explicit Download(Client&, i32 download_id); explicit Download(Client&, i32 download_id);
WeakPtr<Client> m_client; WeakPtr<Client> m_client;
int m_download_id { -1 }; int m_download_id { -1 };
RefPtr<Core::Notifier> m_write_notifier;
int m_fd { -1 };
bool m_should_buffer_all_input { false };
struct InternalBufferedData {
InternalBufferedData(int fd)
: read_stream(fd)
{
}
InputFileStream read_stream;
DuplexMemoryStream payload_stream;
HashMap<String, String, CaseInsensitiveStringTraits> response_headers;
Optional<u32> response_code;
};
struct InternalStreamData {
InternalStreamData(int fd)
: read_stream(fd)
{
}
InputFileStream read_stream;
RefPtr<Core::Notifier> read_notifier;
bool success;
u32 total_size { 0 };
bool download_done { false };
};
OwnPtr<InternalBufferedData> m_internal_buffered_data;
OwnPtr<InternalStreamData> m_internal_stream_data;
}; };
} }

View file

@ -92,10 +92,10 @@ void XMLHttpRequest::send()
// we need to make ResourceLoader give us more detailed updates than just "done" and "error". // we need to make ResourceLoader give us more detailed updates than just "done" and "error".
ResourceLoader::the().load( ResourceLoader::the().load(
m_window->document().complete_url(m_url), m_window->document().complete_url(m_url),
[weak_this = make_weak_ptr()](auto& data, auto&) { [weak_this = make_weak_ptr()](auto data, auto&) {
if (!weak_this) if (!weak_this)
return; return;
const_cast<XMLHttpRequest&>(*weak_this).m_response = data; const_cast<XMLHttpRequest&>(*weak_this).m_response = ByteBuffer::copy(data);
const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done); const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done);
const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load)); const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load));
}, },

View file

@ -128,7 +128,7 @@ void HTMLScriptElement::prepare_script(Badge<HTMLDocumentParser>)
// FIXME: This load should be made asynchronous and the parser should spin an event loop etc. // FIXME: This load should be made asynchronous and the parser should spin an event loop etc.
ResourceLoader::the().load_sync( ResourceLoader::the().load_sync(
url, url,
[this, url](auto& data, auto&) { [this, url](auto data, auto&) {
if (data.is_null()) { if (data.is_null()) {
dbg() << "HTMLScriptElement: Failed to load " << url; dbg() << "HTMLScriptElement: Failed to load " << url;
return; return;

View file

@ -171,6 +171,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
return true; return true;
if (url.protocol() == "http" || url.protocol() == "https") { if (url.protocol() == "http" || url.protocol() == "https") {
#if 0
URL favicon_url; URL favicon_url;
favicon_url.set_protocol(url.protocol()); favicon_url.set_protocol(url.protocol());
favicon_url.set_host(url.host()); favicon_url.set_host(url.host());
@ -191,6 +192,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
if (auto* page = frame().page()) if (auto* page = frame().page())
page->client().page_did_change_favicon(*bitmap); page->client().page_did_change_favicon(*bitmap);
}); });
#endif
} }
return true; return true;

View file

@ -84,10 +84,10 @@ static String mime_type_from_content_type(const String& content_type)
return content_type; return content_type;
} }
void Resource::did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers) void Resource::did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers)
{ {
ASSERT(!m_loaded); ASSERT(!m_loaded);
m_encoded_data = data; m_encoded_data = ByteBuffer::copy(data);
m_response_headers = headers; m_response_headers = headers;
m_loaded = true; m_loaded = true;

View file

@ -77,7 +77,7 @@ public:
void for_each_client(Function<void(ResourceClient&)>); void for_each_client(Function<void(ResourceClient&)>);
void did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers); void did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers);
void did_fail(Badge<ResourceLoader>, const String& error); void did_fail(Badge<ResourceLoader>, const String& error);
protected: protected:

View file

@ -53,13 +53,13 @@ ResourceLoader::ResourceLoader()
{ {
} }
void ResourceLoader::load_sync(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) void ResourceLoader::load_sync(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{ {
Core::EventLoop loop; Core::EventLoop loop;
load( load(
url, url,
[&](auto& data, auto& response_headers) { [&](auto data, auto& response_headers) {
success_callback(data, response_headers); success_callback(data, response_headers);
loop.quit(0); loop.quit(0);
}, },
@ -97,7 +97,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
load( load(
request, request,
[=](auto& data, auto& headers) { [=](auto data, auto& headers) {
const_cast<Resource&>(*resource).did_load({}, data, headers); const_cast<Resource&>(*resource).did_load({}, data, headers);
}, },
[=](auto& error) { [=](auto& error) {
@ -107,7 +107,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
return resource; return resource;
} }
void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) void ResourceLoader::load(const LoadRequest& request, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{ {
auto& url = request.url(); auto& url = request.url();
if (is_port_blocked(url.port())) { if (is_port_blocked(url.port())) {
@ -170,7 +170,12 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback("Failed to initiate load"); error_callback("Failed to initiate load");
return; return;
} }
download->on_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback)](bool success, ReadonlyBytes payload, auto, auto& response_headers, auto status_code) { download->on_buffered_download_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback), download](bool success, auto, auto& response_headers, auto status_code, ReadonlyBytes payload) {
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) {
if (error_callback)
error_callback(String::format("HTTP error (%u)", status_code.value()));
return;
}
--m_pending_loads; --m_pending_loads;
if (on_load_counter_change) if (on_load_counter_change)
on_load_counter_change(); on_load_counter_change();
@ -179,13 +184,9 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback("HTTP load failed"); error_callback("HTTP load failed");
return; return;
} }
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) { success_callback(payload, response_headers);
if (error_callback)
error_callback(String::format("HTTP error (%u)", status_code.value()));
return;
}
success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers);
}; };
download->set_should_buffer_all_input(true);
download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey { download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey {
return {}; return {};
}; };
@ -199,7 +200,7 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback(String::format("Protocol not implemented: %s", url.protocol().characters())); error_callback(String::format("Protocol not implemented: %s", url.protocol().characters()));
} }
void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback) void ResourceLoader::load(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{ {
LoadRequest request; LoadRequest request;
request.set_url(url); request.set_url(url);

View file

@ -44,9 +44,9 @@ public:
RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&); RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&);
void load(const LoadRequest&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); void load(const LoadRequest&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); void load(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load_sync(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr); void load_sync(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
Function<void()> on_load_counter_change; Function<void()> on_load_counter_change;

View file

@ -62,16 +62,17 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle
{ {
URL url(message.url()); URL url(message.url());
if (!url.is_valid()) if (!url.is_valid())
return make<Messages::ProtocolServer::StartDownloadResponse>(-1); return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto* protocol = Protocol::find_by_name(url.protocol()); auto* protocol = Protocol::find_by_name(url.protocol());
if (!protocol) if (!protocol)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1); return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer()); auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body());
if (!download) if (!download)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1); return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto id = download->id(); auto id = download->id();
auto fd = download->download_fd();
m_downloads.set(id, move(download)); m_downloads.set(id, move(download));
return make<Messages::ProtocolServer::StartDownloadResponse>(id); return make<Messages::ProtocolServer::StartDownloadResponse>(id, fd);
} }
OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message) OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
@ -86,22 +87,20 @@ OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(
return make<Messages::ProtocolServer::StopDownloadResponse>(success); return make<Messages::ProtocolServer::StopDownloadResponse>(success);
} }
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) void ClientConnection::did_receive_headers(Badge<Download>, Download& download)
{ {
RefPtr<SharedBuffer> buffer;
if (success && download.payload().size() > 0 && !download.payload().is_null()) {
buffer = SharedBuffer::create_with_size(download.payload().size());
memcpy(buffer->data<void>(), download.payload().data(), download.payload().size());
buffer->seal();
buffer->share_with(client_pid());
m_shared_buffers.set(buffer->shbuf_id(), buffer);
}
ASSERT(download.total_size().has_value());
IPC::Dictionary response_headers; IPC::Dictionary response_headers;
for (auto& it : download.response_headers()) for (auto& it : download.response_headers())
response_headers.add(it.key, it.value); response_headers.add(it.key, it.value);
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));
post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code()));
}
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
{
ASSERT(download.total_size().has_value());
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value()));
m_downloads.remove(download.id()); m_downloads.remove(download.id());
} }
@ -121,12 +120,6 @@ OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const M
return make<Messages::ProtocolServer::GreetResponse>(client_id()); return make<Messages::ProtocolServer::GreetResponse>(client_id());
} }
OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message)
{
m_shared_buffers.remove(message.shbuf_id());
return make<Messages::ProtocolServer::DisownSharedBufferResponse>();
}
OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message)
{ {
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr)); auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));

View file

@ -45,6 +45,7 @@ public:
virtual void die() override; virtual void die() override;
void did_receive_headers(Badge<Download>, Download&);
void did_finish_download(Badge<Download>, Download&, bool success); void did_finish_download(Badge<Download>, Download&, bool success);
void did_progress_download(Badge<Download>, Download&); void did_progress_download(Badge<Download>, Download&);
void did_request_certificates(Badge<Download>, Download&); void did_request_certificates(Badge<Download>, Download&);
@ -54,11 +55,9 @@ private:
virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override; virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override;
virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override; virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override; virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override; virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&) override;
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&);
HashMap<i32, OwnPtr<Download>> m_downloads; HashMap<i32, OwnPtr<Download>> m_downloads;
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
}; };
} }

View file

@ -33,9 +33,10 @@ namespace ProtocolServer {
// FIXME: What about rollover? // FIXME: What about rollover?
static i32 s_next_id = 1; static i32 s_next_id = 1;
Download::Download(ClientConnection& client) Download::Download(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream)
: m_client(client) : m_client(client)
, m_id(s_next_id++) , m_id(s_next_id++)
, m_output_stream(move(output_stream))
{ {
} }
@ -48,15 +49,10 @@ void Download::stop()
m_client.did_finish_download({}, *this, false); m_client.did_finish_download({}, *this, false);
} }
void Download::set_payload(const ByteBuffer& payload)
{
m_payload = payload;
m_total_size = payload.size();
}
void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)
{ {
m_response_headers = response_headers; m_response_headers = response_headers;
m_client.did_receive_headers({}, *this);
} }
void Download::set_certificate(String, String) void Download::set_certificate(String, String)

View file

@ -26,8 +26,9 @@
#pragma once #pragma once
#include <AK/ByteBuffer.h> #include <AK/FileStream.h>
#include <AK/HashMap.h> #include <AK/HashMap.h>
#include <AK/NonnullOwnPtr.h>
#include <AK/Optional.h> #include <AK/Optional.h>
#include <AK/RefCounted.h> #include <AK/RefCounted.h>
#include <AK/URL.h> #include <AK/URL.h>
@ -45,30 +46,35 @@ public:
Optional<u32> status_code() const { return m_status_code; } Optional<u32> status_code() const { return m_status_code; }
Optional<u32> total_size() const { return m_total_size; } Optional<u32> total_size() const { return m_total_size; }
size_t downloaded_size() const { return m_downloaded_size; } size_t downloaded_size() const { return m_downloaded_size; }
const ByteBuffer& payload() const { return m_payload; }
const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; } const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; }
void stop(); void stop();
virtual void set_certificate(String, String); virtual void set_certificate(String, String);
// FIXME: Want Badge<Protocol>, but can't make one from HttpProtocol, etc.
void set_download_fd(int fd) { m_download_fd = fd; }
int download_fd() const { return m_download_fd; }
protected: protected:
explicit Download(ClientConnection&); explicit Download(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&);
void did_finish(bool success); void did_finish(bool success);
void did_progress(Optional<u32> total_size, u32 downloaded_size); void did_progress(Optional<u32> total_size, u32 downloaded_size);
void set_status_code(u32 status_code) { m_status_code = status_code; } void set_status_code(u32 status_code) { m_status_code = status_code; }
void did_request_certificates(); void did_request_certificates();
void set_payload(const ByteBuffer&);
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&); void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
void set_downloaded_size(size_t size) { m_downloaded_size = size; }
const OutputFileStream& output_stream() const { return *m_output_stream; }
private: private:
ClientConnection& m_client; ClientConnection& m_client;
i32 m_id { 0 }; i32 m_id { 0 };
int m_download_fd { -1 }; // Passed to client.
URL m_url; URL m_url;
Optional<u32> m_status_code; Optional<u32> m_status_code;
Optional<u32> m_total_size {}; Optional<u32> m_total_size {};
size_t m_downloaded_size { 0 }; size_t m_downloaded_size { 0 };
ByteBuffer m_payload; NonnullOwnPtr<OutputFileStream> m_output_stream;
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers; HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
}; };

View file

@ -30,13 +30,13 @@
namespace ProtocolServer { namespace ProtocolServer {
GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client) : Download(client, move(output_stream))
, m_job(job) , m_job(job)
{ {
m_job->on_finish = [this](bool success) { m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) { if (auto* response = m_job->response()) {
set_payload(response->payload()); set_downloaded_size(this->output_stream().size());
if (!response->meta().is_empty()) { if (!response->meta().is_empty()) {
HashMap<String, String, CaseInsensitiveStringTraits> headers; HashMap<String, String, CaseInsensitiveStringTraits> headers;
headers.set("meta", response->meta()); headers.set("meta", response->meta());
@ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload()
m_job->shutdown(); m_job->shutdown();
} }
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{ {
return adopt_own(*new GeminiDownload(client, move(job))); return adopt_own(*new GeminiDownload(client, move(job), move(output_stream)));
} }
} }

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class GeminiDownload final : public Download { class GeminiDownload final : public Download {
public: public:
virtual ~GeminiDownload() override; virtual ~GeminiDownload() override;
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
private: private:
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
virtual void set_certificate(String certificate, String key) override; virtual void set_certificate(String certificate, String key) override;

View file

@ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol()
{ {
} }
OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, const ByteBuffer&) OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, ReadonlyBytes)
{ {
Gemini::GeminiRequest request; Gemini::GeminiRequest request;
request.set_url(url); request.set_url(url);
auto job = Gemini::GeminiJob::construct(request);
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job); int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = Gemini::GeminiJob::construct(request, *output_stream);
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start(); job->start();
return download; return download;
} }

View file

@ -35,7 +35,7 @@ public:
GeminiProtocol(); GeminiProtocol();
virtual ~GeminiProtocol() override; virtual ~GeminiProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, const ByteBuffer& request_body) override; virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, ReadonlyBytes body) override;
}; };
} }

View file

@ -30,15 +30,21 @@
namespace ProtocolServer { namespace ProtocolServer {
HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client) : Download(client, move(output_stream))
, m_job(job) , m_job(job)
{ {
m_job->on_headers_received = [this](auto& headers, auto response_code) {
if (response_code.has_value())
set_status_code(response_code.value());
set_response_headers(headers);
};
m_job->on_finish = [this](bool success) { m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) { if (auto* response = m_job->response()) {
set_status_code(response->code()); set_status_code(response->code());
set_payload(response->payload());
set_response_headers(response->headers()); set_response_headers(response->headers());
set_downloaded_size(this->output_stream().size());
} }
// if we didn't know the total size, pretend that the download finished successfully // if we didn't know the total size, pretend that the download finished successfully
@ -60,9 +66,9 @@ HttpDownload::~HttpDownload()
m_job->shutdown(); m_job->shutdown();
} }
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{ {
return adopt_own(*new HttpDownload(client, move(job))); return adopt_own(*new HttpDownload(client, move(job), move(output_stream)));
} }
} }

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class HttpDownload final : public Download { class HttpDownload final : public Download {
public: public:
virtual ~HttpDownload() override; virtual ~HttpDownload() override;
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
private: private:
explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
NonnullRefPtr<HTTP::HttpJob> m_job; NonnullRefPtr<HTTP::HttpJob> m_job;
}; };

View file

@ -28,6 +28,7 @@
#include <LibHTTP/HttpRequest.h> #include <LibHTTP/HttpRequest.h>
#include <ProtocolServer/HttpDownload.h> #include <ProtocolServer/HttpDownload.h>
#include <ProtocolServer/HttpProtocol.h> #include <ProtocolServer/HttpProtocol.h>
#include <fcntl.h>
namespace ProtocolServer { namespace ProtocolServer {
@ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol()
{ {
} }
OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
{ {
HTTP::HttpRequest request; HTTP::HttpRequest request;
if (method.equals_ignoring_case("post")) if (method.equals_ignoring_case("post"))
@ -49,9 +50,20 @@ OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const St
request.set_method(HTTP::HttpRequest::Method::GET); request.set_method(HTTP::HttpRequest::Method::GET);
request.set_url(url); request.set_url(url);
request.set_headers(headers); request.set_headers(headers);
request.set_body(request_body); request.set_body(body);
auto job = HTTP::HttpJob::construct(request);
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job); int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = HTTP::HttpJob::construct(request, *output_stream);
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start(); job->start();
return download; return download;
} }

View file

@ -35,7 +35,7 @@ public:
HttpProtocol(); HttpProtocol();
virtual ~HttpProtocol() override; virtual ~HttpProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
}; };
} }

View file

@ -30,15 +30,21 @@
namespace ProtocolServer { namespace ProtocolServer {
HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client) : Download(client, move(output_stream))
, m_job(job) , m_job(job)
{ {
m_job->on_headers_received = [this](auto& headers, auto response_code) {
if (response_code.has_value())
set_status_code(response_code.value());
set_response_headers(headers);
};
m_job->on_finish = [this](bool success) { m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) { if (auto* response = m_job->response()) {
set_status_code(response->code()); set_status_code(response->code());
set_payload(response->payload());
set_response_headers(response->headers()); set_response_headers(response->headers());
set_downloaded_size(this->output_stream().size());
} }
// if we didn't know the total size, pretend that the download finished successfully // if we didn't know the total size, pretend that the download finished successfully
@ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload()
m_job->shutdown(); m_job->shutdown();
} }
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{ {
return adopt_own(*new HttpsDownload(client, move(job))); return adopt_own(*new HttpsDownload(client, move(job), move(output_stream)));
} }
} }

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class HttpsDownload final : public Download { class HttpsDownload final : public Download {
public: public:
virtual ~HttpsDownload() override; virtual ~HttpsDownload() override;
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
private: private:
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
virtual void set_certificate(String certificate, String key) override; virtual void set_certificate(String certificate, String key) override;

View file

@ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol()
{ {
} }
OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
{ {
HTTP::HttpRequest request; HTTP::HttpRequest request;
if (method.equals_ignoring_case("post")) if (method.equals_ignoring_case("post"))
@ -49,9 +49,19 @@ OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const S
request.set_method(HTTP::HttpRequest::Method::GET); request.set_method(HTTP::HttpRequest::Method::GET);
request.set_url(url); request.set_url(url);
request.set_headers(headers); request.set_headers(headers);
request.set_body(request_body); request.set_body(body);
auto job = HTTP::HttpsJob::construct(request);
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job); int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = HTTP::HttpsJob::construct(request, *output_stream);
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start(); job->start();
return download; return download;
} }

View file

@ -35,7 +35,7 @@ public:
HttpsProtocol(); HttpsProtocol();
virtual ~HttpsProtocol() override; virtual ~HttpsProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
}; };
} }

View file

@ -37,7 +37,7 @@ public:
virtual ~Protocol(); virtual ~Protocol();
const String& name() const { return m_name; } const String& name() const { return m_name; }
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) = 0; virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) = 0;
static Protocol* find_by_name(const String&); static Protocol* find_by_name(const String&);

View file

@ -2,7 +2,8 @@ endpoint ProtocolClient = 13
{ {
// Download notifications // Download notifications
DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =| DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =| DownloadFinished(i32 download_id, bool success, u32 total_size) =|
HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional<u32> status_code) =|
// Certificate requests // Certificate requests
CertificateRequested(i32 download_id) => () CertificateRequested(i32 download_id) => ()

View file

@ -3,14 +3,11 @@ endpoint ProtocolServer = 9
// Basic protocol // Basic protocol
Greet() => (i32 client_id) Greet() => (i32 client_id)
// FIXME: It would be nice if the kernel provided a way to avoid this
DisownSharedBuffer(i32 shbuf_id) => ()
// Test if a specific protocol is supported, e.g "http" // Test if a specific protocol is supported, e.g "http"
IsSupportedProtocol(String protocol) => (bool supported) IsSupportedProtocol(String protocol) => (bool supported)
// Download API // Download API
StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id) StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd)
StopDownload(i32 download_id) => (bool success) StopDownload(i32 download_id) => (bool success)
SetCertificate(i32 download_id, String certificate, String key) => (bool success) SetCertificate(i32 download_id, String certificate, String key) => (bool success)
} }

View file

@ -35,7 +35,7 @@
int main(int, char**) int main(int, char**)
{ {
if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) { if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) {
perror("pledge"); perror("pledge");
return 1; return 1;
} }
@ -45,7 +45,7 @@ int main(int, char**)
Core::EventLoop event_loop; Core::EventLoop event_loop;
// FIXME: Establish a connection to LookupServer and then drop "unix"? // FIXME: Establish a connection to LookupServer and then drop "unix"?
if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) { if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) {
perror("pledge"); perror("pledge");
return 1; return 1;
} }

View file

@ -24,6 +24,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/ */
#include <AK/FileStream.h>
#include <AK/GenericLexer.h> #include <AK/GenericLexer.h>
#include <AK/LexicalPath.h> #include <AK/LexicalPath.h>
#include <AK/NumberFormat.h> #include <AK/NumberFormat.h>
@ -116,29 +117,50 @@ private:
bool m_might_be_wrong { false }; bool m_might_be_wrong { false };
}; };
static void do_write(ReadonlyBytes payload) template<typename ConditionT>
{ class ConditionalOutputFileStream final : public OutputFileStream {
size_t length_remaining = payload.size(); public:
size_t length_written = 0; template<typename... Args>
while (length_remaining > 0) { ConditionalOutputFileStream(ConditionT&& condition, Args... args)
auto nwritten = fwrite(payload.data() + length_written, sizeof(char), length_remaining, stdout); : OutputFileStream(args...)
if (nwritten > 0) { , m_condition(condition)
length_remaining -= nwritten; {
length_written += nwritten; }
continue;
}
if (feof(stdout)) { ~ConditionalOutputFileStream()
fprintf(stderr, "pro: unexpected eof while writing\n"); {
if (!m_condition())
return; return;
}
if (ferror(stdout)) { if (!m_buffer.is_empty()) {
fprintf(stderr, "pro: error while writing\n"); OutputFileStream::write(m_buffer);
return; m_buffer.clear();
} }
} }
}
private:
size_t write(ReadonlyBytes bytes) override
{
if (!m_condition()) {
write_to_buffer:;
m_buffer.append(bytes.data(), bytes.size());
return bytes.size();
}
if (!m_buffer.is_empty()) {
auto size = OutputFileStream::write(m_buffer);
m_buffer = m_buffer.slice(size, m_buffer.size() - size);
}
if (!m_buffer.is_empty())
goto write_to_buffer;
return OutputFileStream::write(bytes);
}
ConditionT m_condition;
ByteBuffer m_buffer;
};
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
@ -195,6 +217,8 @@ int main(int argc, char** argv)
timeval prev_time, current_time, time_diff; timeval prev_time, current_time, time_diff;
gettimeofday(&prev_time, nullptr); gettimeofday(&prev_time, nullptr);
bool received_actual_headers = false;
download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) { download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) {
fprintf(stderr, "\r\033[2K"); fprintf(stderr, "\r\033[2K");
if (maybe_total_size.has_value()) { if (maybe_total_size.has_value()) {
@ -215,10 +239,13 @@ int main(int argc, char** argv)
previous_downloaded_size = downloaded_size; previous_downloaded_size = downloaded_size;
prev_time = current_time; prev_time = current_time;
}; };
download->on_finish = [&](bool success, auto payload, auto, auto& response_headers, auto) {
fprintf(stderr, "\033]9;-1;\033\\"); if (save_at_provided_name) {
fprintf(stderr, "\n"); download->on_headers_received = [&](auto& response_headers, auto status_code) {
if (success && save_at_provided_name) { if (received_actual_headers)
return;
dbg() << "Received headers! response code = " << status_code.value_or(0);
received_actual_headers = true; // And not trailers!
String output_name; String output_name;
if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) { if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) {
auto& value = content_disposition.value(); auto& value = content_disposition.value();
@ -245,17 +272,26 @@ int main(int argc, char** argv)
if (freopen(output_name.characters(), "w", stdout) == nullptr) { if (freopen(output_name.characters(), "w", stdout) == nullptr) {
perror("freopen"); perror("freopen");
success = false; // oops!
loop.quit(1); loop.quit(1);
return;
} }
} };
if (success) }
do_write(payload); download->on_finish = [&](bool success, auto) {
else fprintf(stderr, "\033]9;-1;\033\\");
fprintf(stderr, "\n");
if (!success)
fprintf(stderr, "Download failed :(\n"); fprintf(stderr, "Download failed :(\n");
loop.quit(0); loop.quit(0);
}; };
auto output_stream = ConditionalOutputFileStream { [&] { return save_at_provided_name ? received_actual_headers : true; }, stdout };
download->stream_into(output_stream);
dbgprintf("started download with id %d\n", download->id()); dbgprintf("started download with id %d\n", download->id());
return loop.exec(); auto rc = loop.exec();
// FIXME: This shouldn't be needed.
fclose(stdout);
return rc;
} }