diff --git a/Libraries/LibProtocol/Client.cpp b/Libraries/LibProtocol/Client.cpp index d37317d9dd..86bb8d6c2f 100644 --- a/Libraries/LibProtocol/Client.cpp +++ b/Libraries/LibProtocol/Client.cpp @@ -33,7 +33,8 @@ bool Client::stop_download(i32 download_id) void Client::handle(const ProtocolClient::DownloadFinished& message) { if (on_download_finish) - on_download_finish(message.download_id(), message.success()); + on_download_finish(message.download_id(), message.success(), message.total_size(), message.shared_buffer_id()); + send_sync(message.shared_buffer_id()); } void Client::handle(const ProtocolClient::DownloadProgress& message) diff --git a/Libraries/LibProtocol/Client.h b/Libraries/LibProtocol/Client.h index 86b4867f60..ed816f41ca 100644 --- a/Libraries/LibProtocol/Client.h +++ b/Libraries/LibProtocol/Client.h @@ -18,7 +18,7 @@ public: i32 start_download(const String& url); bool stop_download(i32 download_id); - Function on_download_finish; + Function on_download_finish; Function on_download_progress; private: diff --git a/Servers/ProtocolServer/Download.cpp b/Servers/ProtocolServer/Download.cpp index d6fe75bef4..a2d6904ce0 100644 --- a/Servers/ProtocolServer/Download.cpp +++ b/Servers/ProtocolServer/Download.cpp @@ -31,6 +31,12 @@ void Download::stop() all_downloads().remove(m_id); } +void Download::set_payload(const ByteBuffer& payload) +{ + m_payload = payload; + m_total_size = payload.size(); +} + void Download::did_finish(bool success) { if (!m_client) { diff --git a/Servers/ProtocolServer/Download.h b/Servers/ProtocolServer/Download.h index d12a6b7e52..eb2e1ac529 100644 --- a/Servers/ProtocolServer/Download.h +++ b/Servers/ProtocolServer/Download.h @@ -17,6 +17,7 @@ public: size_t total_size() const { return m_total_size; } size_t downloaded_size() const { return m_downloaded_size; } + const ByteBuffer& payload() const { return m_payload; } void stop(); @@ -25,11 +26,13 @@ protected: void did_finish(bool success); void did_progress(size_t total_size, size_t downloaded_size); + void set_payload(const ByteBuffer&); private: i32 m_id; URL m_url; size_t m_total_size { 0 }; size_t m_downloaded_size { 0 }; + ByteBuffer m_payload; WeakPtr m_client; }; diff --git a/Servers/ProtocolServer/HttpDownload.cpp b/Servers/ProtocolServer/HttpDownload.cpp index 76dba63491..b20c553a58 100644 --- a/Servers/ProtocolServer/HttpDownload.cpp +++ b/Servers/ProtocolServer/HttpDownload.cpp @@ -1,4 +1,5 @@ #include +#include #include HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr&& job) @@ -6,6 +7,7 @@ HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr&& , m_job(job) { m_job->on_finish = [this](bool success) { + set_payload(m_job->response()->payload()); did_finish(success); }; } diff --git a/Servers/ProtocolServer/PSClientConnection.cpp b/Servers/ProtocolServer/PSClientConnection.cpp index fb93268efe..be0e67129d 100644 --- a/Servers/ProtocolServer/PSClientConnection.cpp +++ b/Servers/ProtocolServer/PSClientConnection.cpp @@ -2,6 +2,7 @@ #include #include #include +#include static HashMap> s_connections; @@ -48,7 +49,15 @@ OwnPtr PSClientConnection::handle(const Pr void PSClientConnection::did_finish_download(Badge, Download& download, bool success) { - post_message(ProtocolClient::DownloadFinished(download.id(), success)); + RefPtr buffer; + if (success && !download.payload().is_null()) { + buffer = SharedBuffer::create_with_size(download.payload().size()); + memcpy(buffer->data(), download.payload().data(), download.payload().size()); + buffer->seal(); + buffer->share_with(client_pid()); + m_shared_buffers.set(buffer->shared_buffer_id(), buffer); + } + post_message(ProtocolClient::DownloadFinished(download.id(), success, download.total_size(), buffer ? buffer->shared_buffer_id() : -1)); } void PSClientConnection::did_progress_download(Badge, Download& download) @@ -61,3 +70,9 @@ OwnPtr PSClientConnection::handle(const ProtocolS set_client_pid(message.client_pid()); return make(getpid(), client_id()); } + +OwnPtr PSClientConnection::handle(const ProtocolServer::DisownSharedBuffer& message) +{ + m_shared_buffers.remove(message.shared_buffer_id()); + return make(); +} diff --git a/Servers/ProtocolServer/PSClientConnection.h b/Servers/ProtocolServer/PSClientConnection.h index 190e9bd062..fd6b71f195 100644 --- a/Servers/ProtocolServer/PSClientConnection.h +++ b/Servers/ProtocolServer/PSClientConnection.h @@ -5,6 +5,7 @@ #include class Download; +class SharedBuffer; class PSClientConnection final : public IPC::Server::ConnectionNG , public ProtocolServerEndpoint { @@ -23,4 +24,7 @@ private: virtual OwnPtr handle(const ProtocolServer::IsSupportedProtocol&) override; virtual OwnPtr handle(const ProtocolServer::StartDownload&) override; virtual OwnPtr handle(const ProtocolServer::StopDownload&) override; + virtual OwnPtr handle(const ProtocolServer::DisownSharedBuffer&) override; + + HashMap> m_shared_buffers; }; diff --git a/Servers/ProtocolServer/ProtocolClient.ipc b/Servers/ProtocolServer/ProtocolClient.ipc index df88714f50..58a408add9 100644 --- a/Servers/ProtocolServer/ProtocolClient.ipc +++ b/Servers/ProtocolServer/ProtocolClient.ipc @@ -2,5 +2,5 @@ endpoint ProtocolClient = 13 { // Download notifications DownloadProgress(i32 download_id, u32 total_size, u32 downloaded_size) =| - DownloadFinished(i32 download_id, bool success) =| + DownloadFinished(i32 download_id, bool success, u32 total_size, i32 shared_buffer_id) =| } diff --git a/Servers/ProtocolServer/ProtocolServer.ipc b/Servers/ProtocolServer/ProtocolServer.ipc index 90af2d33ec..6ed9eba623 100644 --- a/Servers/ProtocolServer/ProtocolServer.ipc +++ b/Servers/ProtocolServer/ProtocolServer.ipc @@ -3,6 +3,9 @@ endpoint ProtocolServer = 9 // Basic protocol Greet(i32 client_pid) => (i32 server_pid, i32 client_id) + // FIXME: It would be nice if the kernel provided a way to avoid this + DisownSharedBuffer(i32 shared_buffer_id) => () + // Test if a specific protocol is supported, e.g "http" IsSupportedProtocol(String protocol) => (bool supported) diff --git a/Userland/pro.cpp b/Userland/pro.cpp index ccfc6955d5..215f8489c6 100644 --- a/Userland/pro.cpp +++ b/Userland/pro.cpp @@ -1,5 +1,6 @@ #include #include +#include #include int main(int argc, char** argv) @@ -13,8 +14,14 @@ int main(int argc, char** argv) printf("supports HTTP? %u\n", protocol_client->is_supported_protocol("http")); printf(" supports FTP? %u\n", protocol_client->is_supported_protocol("ftp")); - protocol_client->on_download_finish = [&](i32 download_id, bool success) { - printf("download %d finished, success=%u\n", download_id, success); + protocol_client->on_download_finish = [&](i32 download_id, bool success, u32 total_size, i32 shared_buffer_id) { + printf("download %d finished, success=%u, shared_buffer_id=%d\n", download_id, success, shared_buffer_id); + if (success) { + ASSERT(shared_buffer_id != -1); + auto shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id); + auto payload_bytes = ByteBuffer::wrap(shared_buffer->data(), total_size); + write(STDOUT_FILENO, payload_bytes.data(), payload_bytes.size()); + } loop.quit(0); };