diff --git a/Libraries/LibProtocol/Client.cpp b/Libraries/LibProtocol/Client.cpp index 86bb8d6c2f..baeb14e9ce 100644 --- a/Libraries/LibProtocol/Client.cpp +++ b/Libraries/LibProtocol/Client.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace LibProtocol { @@ -6,6 +7,7 @@ namespace LibProtocol { Client::Client() : ConnectionNG(*this, "/tmp/psportal") { + handshake(); } void Client::handshake() @@ -20,27 +22,36 @@ bool Client::is_supported_protocol(const String& protocol) return send_sync(protocol)->supported(); } -i32 Client::start_download(const String& url) +RefPtr Client::start_download(const String& url) { - return send_sync(url)->download_id(); + i32 download_id = send_sync(url)->download_id(); + auto download = Download::create_from_id({}, *this, download_id); + m_downloads.set(download_id, download); + return download; } -bool Client::stop_download(i32 download_id) +bool Client::stop_download(Badge, Download& download) { - return send_sync(download_id)->success(); + if (!m_downloads.contains(download.id())) + return false; + return send_sync(download.id())->success(); } void Client::handle(const ProtocolClient::DownloadFinished& message) { - if (on_download_finish) - on_download_finish(message.download_id(), message.success(), message.total_size(), message.shared_buffer_id()); + RefPtr download; + if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) { + download->did_finish({}, message.success(), message.total_size(), message.shared_buffer_id()); + } send_sync(message.shared_buffer_id()); + m_downloads.remove(message.download_id()); } void Client::handle(const ProtocolClient::DownloadProgress& message) { - if (on_download_progress) - on_download_progress(message.download_id(), message.total_size(), message.downloaded_size()); + if (auto download = m_downloads.get(message.download_id()).value_or(nullptr)) { + download->did_progress({}, message.total_size(), message.downloaded_size()); + } } } diff --git a/Libraries/LibProtocol/Client.h b/Libraries/LibProtocol/Client.h index ed816f41ca..f334e2a47d 100644 --- a/Libraries/LibProtocol/Client.h +++ b/Libraries/LibProtocol/Client.h @@ -6,6 +6,8 @@ namespace LibProtocol { +class Download; + class Client : public IPC::Client::ConnectionNG , public ProtocolClientEndpoint { C_OBJECT(Client) @@ -15,15 +17,15 @@ public: virtual void handshake() override; bool is_supported_protocol(const String&); - i32 start_download(const String& url); - bool stop_download(i32 download_id); + RefPtr start_download(const String& url); - Function on_download_finish; - Function on_download_progress; + bool stop_download(Badge, Download&); private: virtual void handle(const ProtocolClient::DownloadProgress&) override; virtual void handle(const ProtocolClient::DownloadFinished&) override; + + HashMap> m_downloads; }; } diff --git a/Libraries/LibProtocol/Download.cpp b/Libraries/LibProtocol/Download.cpp new file mode 100644 index 0000000000..51b5712040 --- /dev/null +++ b/Libraries/LibProtocol/Download.cpp @@ -0,0 +1,38 @@ +#include +#include +#include + +namespace LibProtocol { + +Download::Download(Client& client, i32 download_id) + : m_client(client.make_weak_ptr()) + , m_download_id(download_id) +{ +} + +bool Download::stop() +{ + return m_client->stop_download({}, *this); +} + +void Download::did_finish(Badge, bool success, u32 total_size, i32 shared_buffer_id) +{ + if (!on_finish) + return; + + ByteBuffer payload; + RefPtr shared_buffer; + if (success && shared_buffer_id != -1) { + shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id); + payload = ByteBuffer::wrap(shared_buffer->data(), total_size); + } + on_finish(success, payload, move(shared_buffer)); +} + +void Download::did_progress(Badge, u32 total_size, u32 downloaded_size) +{ + if (on_progress) + on_progress(total_size, downloaded_size); +} + +} diff --git a/Libraries/LibProtocol/Download.h b/Libraries/LibProtocol/Download.h new file mode 100644 index 0000000000..7cc5a7796d --- /dev/null +++ b/Libraries/LibProtocol/Download.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include +#include +#include + +class SharedBuffer; + +namespace LibProtocol { + +class Client; + +class Download : public RefCounted { +public: + static NonnullRefPtr create_from_id(Badge, Client& client, i32 download_id) + { + return adopt(*new Download(client, download_id)); + } + + int id() const { return m_download_id; } + bool stop(); + + Function payload_storage)> on_finish; + Function on_progress; + + void did_finish(Badge, bool success, u32 total_size, i32 shared_buffer_id); + void did_progress(Badge, u32 total_size, u32 downloaded_size); + +private: + explicit Download(Client&, i32 download_id); + WeakPtr m_client; + int m_download_id { -1 }; +}; + +} diff --git a/Libraries/LibProtocol/Makefile b/Libraries/LibProtocol/Makefile index 16c751bbde..0f34357db9 100644 --- a/Libraries/LibProtocol/Makefile +++ b/Libraries/LibProtocol/Makefile @@ -1,6 +1,7 @@ include ../../Makefile.common OBJS = \ + Download.o \ Client.o LIBRARY = libprotocol.a diff --git a/Userland/pro.cpp b/Userland/pro.cpp index 9c9a6c83c8..7209e60f7a 100644 --- a/Userland/pro.cpp +++ b/Userland/pro.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include int main(int argc, char** argv) @@ -20,25 +21,19 @@ int main(int argc, char** argv) CEventLoop loop; auto protocol_client = LibProtocol::Client::construct(); - protocol_client->handshake(); - protocol_client->on_download_finish = [&](i32 download_id, bool success, u32 total_size, i32 shared_buffer_id) { - dbgprintf("download %d finished, success=%u, shared_buffer_id=%d\n", download_id, success, shared_buffer_id); - if (success) { - ASSERT(shared_buffer_id != -1); - auto shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id); - auto payload_bytes = ByteBuffer::wrap(shared_buffer->data(), total_size); - write(STDOUT_FILENO, payload_bytes.data(), payload_bytes.size()); - } + auto download = protocol_client->start_download(url.to_string()); + download->on_progress = [](u32 total_size, u32 downloaded_size) { + dbgprintf("download progress: %u / %u\n", downloaded_size, total_size); + }; + download->on_finish = [&](bool success, auto& payload, auto) { + if (success) + write(STDOUT_FILENO, payload.data(), payload.size()); + else + fprintf(stderr, "Download failed :(\n"); loop.quit(0); }; - - protocol_client->on_download_progress = [&](i32 download_id, u32 total_size, u32 downloaded_size) { - dbgprintf("download %d progress: %u / %u\n", download_id, downloaded_size, total_size); - }; - - i32 download_id = protocol_client->start_download(url.to_string()); - dbgprintf("started download with id %d\n", download_id); + dbgprintf("started download with id %d\n", download->id()); return loop.exec(); }