1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-31 19:58:11 +00:00

ProtocolServer+LibProtocol: Introduce a server for handling downloads

This patch adds ProtocolServer, a server that handles network requests
on behalf of its clients. The first protocol implemented is HTTP.

The idea here is to use a plug-in architecture where any number of
protocols can be added and implemented without having to mess around
with each client program that wants to use the protocol.

A simple client API is provided through LibProtocol::Client. :^)
This commit is contained in:
Andreas Kling 2019-11-23 21:45:33 +01:00
parent 61f611bf3c
commit fd4349a9f2
21 changed files with 475 additions and 0 deletions

View file

@ -107,6 +107,7 @@ cp ../Servers/WindowServer/WindowServer mnt/bin/WindowServer
cp ../Servers/AudioServer/AudioServer mnt/bin/AudioServer
cp ../Servers/TTYServer/TTYServer mnt/bin/TTYServer
cp ../Servers/TelnetServer/TelnetServer mnt/bin/TelnetServer
cp ../Servers/ProtocolServer/ProtocolServer mnt/bin/ProtocolServer
cp ../Shell/Shell mnt/bin/Shell
echo "done"

View file

@ -31,6 +31,7 @@ build_targets="$build_targets ../Libraries/LibPthread"
# Build IPC servers before their client code to ensure the IPC definitions are available.
build_targets="$build_targets ../Servers/AudioServer"
build_targets="$build_targets ../Servers/LookupServer"
build_targets="$build_targets ../Servers/ProtocolServer"
build_targets="$build_targets ../AK"
@ -42,6 +43,7 @@ build_targets="$build_targets ../Libraries/LibM"
build_targets="$build_targets ../Libraries/LibPCIDB"
build_targets="$build_targets ../Libraries/LibVT"
build_targets="$build_targets ../Libraries/LibMarkdown"
build_targets="$build_targets ../Libraries/LibProtocol"
build_targets="$build_targets ../Applications/About"
build_targets="$build_targets ../Applications/Calculator"

View file

@ -0,0 +1,45 @@
#include <LibProtocol/Client.h>
#include <SharedBuffer.h>
namespace LibProtocol {
Client::Client()
: ConnectionNG(*this, "/tmp/psportal")
{
}
void Client::handshake()
{
auto response = send_sync<ProtocolServer::Greet>(getpid());
set_server_pid(response->server_pid());
set_my_client_id(response->client_id());
}
bool Client::is_supported_protocol(const String& protocol)
{
return send_sync<ProtocolServer::IsSupportedProtocol>(protocol)->supported();
}
i32 Client::start_download(const String& url)
{
return send_sync<ProtocolServer::StartDownload>(url)->download_id();
}
bool Client::stop_download(i32 download_id)
{
return send_sync<ProtocolServer::StopDownload>(download_id)->success();
}
void Client::handle(const ProtocolClient::DownloadFinished& message)
{
if (on_download_finish)
on_download_finish(message.download_id(), message.success());
}
void Client::handle(const ProtocolClient::DownloadProgress& message)
{
if (on_download_progress)
on_download_progress(message.download_id(), message.total_size(), message.downloaded_size());
}
}

View file

@ -0,0 +1,29 @@
#pragma once
#include <LibCore/CoreIPCClient.h>
#include <ProtocolServer/ProtocolClientEndpoint.h>
#include <ProtocolServer/ProtocolServerEndpoint.h>
namespace LibProtocol {
class Client : public IPC::Client::ConnectionNG<ProtocolClientEndpoint, ProtocolServerEndpoint>
, public ProtocolClientEndpoint {
C_OBJECT(Client)
public:
Client();
virtual void handshake() override;
bool is_supported_protocol(const String&);
i32 start_download(const String& url);
bool stop_download(i32 download_id);
Function<void(i32 download_id, bool success)> on_download_finish;
Function<void(i32 download_id, u64 total_size, u64 downloaded_size)> on_download_progress;
private:
virtual void handle(const ProtocolClient::DownloadProgress&) override;
virtual void handle(const ProtocolClient::DownloadFinished&) override;
};
}

View file

@ -0,0 +1,20 @@
include ../../Makefile.common
OBJS = \
Client.o
LIBRARY = libprotocol.a
DEFINES += -DUSERLAND
all: $(LIBRARY)
$(LIBRARY): $(OBJS)
@echo "LIB $@"; $(AR) rcs $@ $(OBJS) $(LIBS)
.cpp.o:
@echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $<
-include $(OBJS:%.o=%.d)
clean:
@echo "CLEAN"; rm -f $(LIBRARY) $(OBJS) *.d

View file

@ -28,6 +28,7 @@ LDFLAGS = \
-L$(SERENITY_BASE_DIR)/Libraries/LibMarkdown \
-L$(SERENITY_BASE_DIR)/Libraries/LibThread \
-L$(SERENITY_BASE_DIR)/Libraries/LibVT \
-L$(SERENITY_BASE_DIR)/Libraries/LibProtocol \
-L$(SERENITY_BASE_DIR)/Libraries/LibAudio
CLANG_FLAGS = -Wconsumed -m32 -ffreestanding -march=i686

View file

@ -0,0 +1,55 @@
#include <ProtocolServer/Download.h>
#include <ProtocolServer/PSClientConnection.h>
// FIXME: What about rollover?
static i32 s_next_id = 1;
static HashMap<i32, RefPtr<Download>>& all_downloads()
{
static HashMap<i32, RefPtr<Download>> map;
return map;
}
Download* Download::find_by_id(i32 id)
{
return all_downloads().get(id).value_or(nullptr);
}
Download::Download(PSClientConnection& client)
: m_id(s_next_id++)
, m_client(client.make_weak_ptr())
{
all_downloads().set(m_id, this);
}
Download::~Download()
{
}
void Download::stop()
{
all_downloads().remove(m_id);
}
void Download::did_finish(bool success)
{
if (!m_client) {
dbg() << "Download::did_finish() after the client already disconnected.";
return;
}
m_client->did_finish_download({}, *this, success);
all_downloads().remove(m_id);
}
void Download::did_progress(size_t total_size, size_t downloaded_size)
{
if (!m_client) {
// FIXME: We should also abort the download in this situation, I guess!
dbg() << "Download::did_progress() after the client already disconnected.";
return;
}
m_total_size = total_size;
m_downloaded_size = downloaded_size;
m_client->did_progress_download({}, *this);
}

View file

@ -0,0 +1,35 @@
#pragma once
#include <AK/RefCounted.h>
#include <AK/URL.h>
#include <AK/WeakPtr.h>
class PSClientConnection;
class Download : public RefCounted<Download> {
public:
virtual ~Download();
static Download* find_by_id(i32);
i32 id() const { return m_id; }
URL url() const { return m_url; }
size_t total_size() const { return m_total_size; }
size_t downloaded_size() const { return m_downloaded_size; }
void stop();
protected:
explicit Download(PSClientConnection&);
void did_finish(bool success);
void did_progress(size_t total_size, size_t downloaded_size);
private:
i32 m_id;
URL m_url;
size_t m_total_size { 0 };
size_t m_downloaded_size { 0 };
WeakPtr<PSClientConnection> m_client;
};

View file

@ -0,0 +1,20 @@
#include <LibCore/CHttpJob.h>
#include <ProtocolServer/HttpDownload.h>
HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job)
: Download(client)
, m_job(job)
{
m_job->on_finish = [this](bool success) {
did_finish(success);
};
}
HttpDownload::~HttpDownload()
{
}
NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job)
{
return adopt(*new HttpDownload(client, move(job)));
}

View file

@ -0,0 +1,18 @@
#pragma once
#include <AK/Badge.h>
#include <ProtocolServer/Download.h>
class CHttpJob;
class HttpProtocol;
class HttpDownload final : public Download {
public:
virtual ~HttpDownload() override;
static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<CHttpJob>&&);
private:
explicit HttpDownload(PSClientConnection&, NonnullRefPtr<CHttpJob>&&);
NonnullRefPtr<CHttpJob> m_job;
};

View file

@ -0,0 +1,24 @@
#include <LibCore/CHttpJob.h>
#include <LibCore/CHttpRequest.h>
#include <ProtocolServer/HttpDownload.h>
#include <ProtocolServer/HttpProtocol.h>
HttpProtocol::HttpProtocol()
: Protocol("http")
{
}
HttpProtocol::~HttpProtocol()
{
}
RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
{
CHttpRequest request;
request.set_method(CHttpRequest::Method::GET);
request.set_url(url);
auto job = request.schedule();
if (!job)
return nullptr;
return HttpDownload::create_with_job({}, client, (CHttpJob&)*job);
}

View file

@ -0,0 +1,11 @@
#pragma once
#include <ProtocolServer/Protocol.h>
class HttpProtocol final : public Protocol {
public:
HttpProtocol();
virtual ~HttpProtocol() override;
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
};

View file

@ -0,0 +1,35 @@
include ../../Makefile.common
OBJS = \
PSClientConnection.o \
Protocol.o \
Download.o \
HttpProtocol.o \
HttpDownload.o \
main.o
APP = ProtocolServer
DEFINES += -DUSERLAND
all: $(APP)
*.cpp: ProtocolServerEndpoint.h ProtocolClientEndpoint.h
ProtocolServerEndpoint.h: ProtocolServer.ipc
@echo "IPC $<"; $(IPCCOMPILER) $< > $@
ProtocolClientEndpoint.h: ProtocolClient.ipc
@echo "IPC $<"; $(IPCCOMPILER) $< > $@
$(APP): $(OBJS)
$(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -ldraw
.cpp.o:
@echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $<
-include $(OBJS:%.o=%.d)
clean:
@echo "CLEAN"; rm -f $(APP) $(OBJS) *.d ProtocolClientEndpoint.h ProtocolServerEndpoint.h

View file

@ -0,0 +1,63 @@
#include <ProtocolServer/Download.h>
#include <ProtocolServer/PSClientConnection.h>
#include <ProtocolServer/Protocol.h>
#include <ProtocolServer/ProtocolClientEndpoint.h>
static HashMap<int, RefPtr<PSClientConnection>> s_connections;
PSClientConnection::PSClientConnection(CLocalSocket& socket, int client_id)
: ConnectionNG(*this, socket, client_id)
{
s_connections.set(client_id, *this);
}
PSClientConnection::~PSClientConnection()
{
}
void PSClientConnection::die()
{
s_connections.remove(client_id());
}
OwnPtr<ProtocolServer::IsSupportedProtocolResponse> PSClientConnection::handle(const ProtocolServer::IsSupportedProtocol& message)
{
bool supported = Protocol::find_by_name(message.protocol().to_lowercase());
return make<ProtocolServer::IsSupportedProtocolResponse>(supported);
}
OwnPtr<ProtocolServer::StartDownloadResponse> PSClientConnection::handle(const ProtocolServer::StartDownload& message)
{
URL url(message.url());
ASSERT(url.is_valid());
auto* protocol = Protocol::find_by_name(url.protocol());
ASSERT(protocol);
auto download = protocol->start_download(*this, url);
return make<ProtocolServer::StartDownloadResponse>(download->id());
}
OwnPtr<ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const ProtocolServer::StopDownload& message)
{
auto* download = Download::find_by_id(message.download_id());
bool success = false;
if (download) {
download->stop();
}
return make<ProtocolServer::StopDownloadResponse>(success);
}
void PSClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
{
post_message(ProtocolClient::DownloadFinished(download.id(), success));
}
void PSClientConnection::did_progress_download(Badge<Download>, Download& download)
{
post_message(ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size()));
}
OwnPtr<ProtocolServer::GreetResponse> PSClientConnection::handle(const ProtocolServer::Greet& message)
{
set_client_pid(message.client_pid());
return make<ProtocolServer::GreetResponse>(getpid(), client_id());
}

View file

@ -0,0 +1,26 @@
#pragma once
#include <AK/Badge.h>
#include <LibCore/CoreIPCServer.h>
#include <ProtocolServer/ProtocolServerEndpoint.h>
class Download;
class PSClientConnection final : public IPC::Server::ConnectionNG<ProtocolServerEndpoint>
, public ProtocolServerEndpoint {
C_OBJECT(PSClientConnection)
public:
explicit PSClientConnection(CLocalSocket&, int client_id);
~PSClientConnection() override;
virtual void die() override;
void did_finish_download(Badge<Download>, Download&, bool success);
void did_progress_download(Badge<Download>, Download&);
private:
virtual OwnPtr<ProtocolServer::GreetResponse> handle(const ProtocolServer::Greet&) override;
virtual OwnPtr<ProtocolServer::IsSupportedProtocolResponse> handle(const ProtocolServer::IsSupportedProtocol&) override;
virtual OwnPtr<ProtocolServer::StartDownloadResponse> handle(const ProtocolServer::StartDownload&) override;
virtual OwnPtr<ProtocolServer::StopDownloadResponse> handle(const ProtocolServer::StopDownload&) override;
};

View file

@ -0,0 +1,23 @@
#include <AK/HashMap.h>
#include <ProtocolServer/Protocol.h>
static HashMap<String, Protocol*>& all_protocols()
{
static HashMap<String, Protocol*> map;
return map;
}
Protocol* Protocol::find_by_name(const String& name)
{
return all_protocols().get(name).value_or(nullptr);
}
Protocol::Protocol(const String& name)
{
all_protocols().set(name, this);
}
Protocol::~Protocol()
{
ASSERT_NOT_REACHED();
}

View file

@ -0,0 +1,23 @@
#pragma once
#include <AK/RefPtr.h>
#include <AK/URL.h>
class Download;
class PSClientConnection;
class Protocol {
public:
virtual ~Protocol();
const String& name() const { return m_name; }
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
static Protocol* find_by_name(const String&);
protected:
explicit Protocol(const String& name);
private:
String m_name;
};

View file

@ -0,0 +1,6 @@
endpoint ProtocolClient = 13
{
// Download notifications
DownloadProgress(i32 download_id, u32 total_size, u32 downloaded_size) =|
DownloadFinished(i32 download_id, bool success) =|
}

View file

@ -0,0 +1,12 @@
endpoint ProtocolServer = 9
{
// Basic protocol
Greet(i32 client_pid) => (i32 server_pid, i32 client_id)
// Test if a specific protocol is supported, e.g "http"
IsSupportedProtocol(String protocol) => (bool supported)
// Download API
StartDownload(String url) => (i32 download_id)
StopDownload(i32 download_id) => (bool success)
}

View file

@ -0,0 +1,25 @@
#include <LibCore/CEventLoop.h>
#include <LibCore/CLocalServer.h>
#include <LibCore/CoreIPCServer.h>
#include <ProtocolServer/HttpProtocol.h>
#include <ProtocolServer/PSClientConnection.h>
int main(int, char**)
{
CEventLoop event_loop;
(void)*new HttpProtocol;
auto server = CLocalServer::construct();
unlink("/tmp/psportal");
server->listen("/tmp/psportal");
server->on_ready_to_accept = [&] {
auto client_socket = server->accept();
if (!client_socket) {
dbg() << "ProtocolServer: accept failed.";
return;
}
static int s_next_client_id = 0;
int client_id = ++s_next_client_id;
IPC::Server::new_connection_ng_for_client<PSClientConnection>(*client_socket, client_id);
};
return event_loop.exec();
}

View file

@ -106,6 +106,7 @@ int main(int, char**)
signal(SIGCHLD, sigchld_handler);
start_process("/bin/ProtocolServer", {}, lowest_prio);
start_process("/bin/LookupServer", {}, lowest_prio);
start_process("/bin/WindowServer", {}, highest_prio);
start_process("/bin/AudioServer", {}, highest_prio);