1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-06-01 08:28:11 +00:00

LibCore: Implement the Serenity Stream API classes

The Serenity Stream API is the name for the new set of classes intended
to replace IODevice and its descendants. It provides more flexibility
for subclasses by allowing each subclass to override all the possible
functionalities according to their wishes.

Stream is the base class which implements majority of the functionality
expected from a readable and writable stream. SeekableStream adds
seeking on top, and provides a couple utility methods which derive from
seek. Socket adds a couple of BSD socket utility functions such as
checking whether there is data available to read and checking the
pending bytes on the socket.

As for the concrete classes, there is File which is a SeekableStream and
is intended to operate on file-like objects; and TCPSocket, UDPSocket
and LocalSocket, which handle TCP, UDP and UNIX sockets respectively.

The concrete classes do not do buffering by default. For buffering
functionality, a set of augmentative classes named BufferedSeekable and
BufferedSocket have been implemented, intended to wrap a SeekableStream
and a Socket, respectively.
This commit is contained in:
sin-ack 2021-09-01 21:30:13 +00:00 committed by Ali Mohammad Pur
parent 69ef211925
commit 19e13117ad
3 changed files with 1462 additions and 0 deletions

View file

@ -27,6 +27,7 @@ set(SOURCES
Property.cpp
SecretString.cpp
Socket.cpp
Stream.cpp
StandardPaths.cpp
System.cpp
TCPServer.cpp

View file

@ -0,0 +1,547 @@
/*
* Copyright (c) 2018-2021, Andreas Kling <kling@serenityos.org>
* Copyright (c) 2021, sin-ack <sin-ack@protonmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include "Stream.h"
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#ifdef __serenity__
# include <serenity.h>
#endif
namespace Core::Stream {
bool Stream::read_or_error(Bytes buffer)
{
VERIFY(buffer.size());
size_t nread = 0;
do {
if (is_eof())
return false;
auto result = read(buffer.slice(nread));
if (result.is_error()) {
if (result.error().is_errno() && result.error().code() == EINTR) {
continue;
}
return false;
}
nread += result.value();
} while (nread < buffer.size());
return true;
}
bool Stream::write_or_error(ReadonlyBytes buffer)
{
VERIFY(buffer.size());
size_t nwritten = 0;
do {
auto result = write(buffer.slice(nwritten));
if (result.is_error()) {
if (result.error().is_errno() && result.error().code() == EINTR) {
continue;
}
return false;
}
nwritten += result.value();
} while (nwritten < buffer.size());
return true;
}
ErrorOr<off_t> SeekableStream::tell() const
{
// Seek with 0 and SEEK_CUR does not modify anything despite the const_cast,
// so it's safe to do this.
return const_cast<SeekableStream*>(this)->seek(0, SeekMode::FromCurrentPosition);
}
ErrorOr<off_t> SeekableStream::size()
{
auto original_position = TRY(tell());
auto seek_result = seek(0, SeekMode::FromEndPosition);
if (seek_result.is_error()) {
// Let's try to restore the original position, just in case.
auto restore_result = seek(original_position, SeekMode::SetPosition);
if (restore_result.is_error()) {
dbgln("Core::SeekableStream::size: Couldn't restore initial position, stream might have incorrect position now!");
}
return seek_result.release_error();
}
TRY(seek(original_position, SeekMode::SetPosition));
return seek_result.value();
}
ErrorOr<File> File::open(StringView const& filename, OpenMode mode, mode_t permissions)
{
File file { mode };
TRY(file.open_path(filename, permissions));
return file;
}
ErrorOr<File> File::adopt_fd(int fd, OpenMode mode)
{
if (fd < 0) {
return Error::from_errno(EBADF);
}
if (!has_any_flag(mode, OpenMode::ReadWrite)) {
dbgln("Core::File::adopt_fd: Attempting to adopt a file with neither Read nor Write specified in mode");
return Error::from_errno(EINVAL);
}
File file { mode };
file.m_fd = fd;
return file;
}
ErrorOr<void> File::open_path(StringView const& filename, mode_t permissions)
{
VERIFY(m_fd == -1);
int flags = 0;
if (has_flag(m_mode, OpenMode::ReadWrite)) {
flags |= O_RDWR | O_CREAT;
} else if (has_flag(m_mode, OpenMode::Read)) {
flags |= O_RDONLY;
} else if (has_flag(m_mode, OpenMode::Write)) {
flags |= O_WRONLY | O_CREAT;
bool should_truncate = !has_any_flag(m_mode, OpenMode::Append | OpenMode::MustBeNew);
if (should_truncate)
flags |= O_TRUNC;
}
if (has_flag(m_mode, OpenMode::Append))
flags |= O_APPEND;
if (has_flag(m_mode, OpenMode::Truncate))
flags |= O_TRUNC;
if (has_flag(m_mode, OpenMode::MustBeNew))
flags |= O_EXCL;
if (!has_flag(m_mode, OpenMode::KeepOnExec))
flags |= O_CLOEXEC;
if (!has_flag(m_mode, OpenMode::Nonblocking))
flags |= O_NONBLOCK;
#ifdef __serenity__
int fd = ::serenity_open(filename.characters_without_null_termination(), filename.length(), flags, permissions);
#else
String filename_with_null_terminator { filename };
int fd = ::open(filename_with_null_terminator.characters(), flags, permissions);
#endif
if (fd < 0) {
return Error::from_errno(errno);
}
m_fd = fd;
return {};
}
bool File::is_readable() const { return has_flag(m_mode, OpenMode::Read); }
bool File::is_writable() const { return has_flag(m_mode, OpenMode::Write); }
ErrorOr<size_t> File::read(Bytes buffer)
{
if (!has_flag(m_mode, OpenMode::Read)) {
// NOTE: POSIX says that if the fd is not open for reading, the call
// will return EBADF. Since we already know whether we can or
// can't read the file, let's avoid a syscall.
return EBADF;
}
ssize_t rc = ::read(m_fd, buffer.data(), buffer.size());
if (rc < 0) {
return Error::from_errno(errno);
}
m_last_read_was_eof = rc == 0;
return rc;
}
ErrorOr<size_t> File::write(ReadonlyBytes buffer)
{
if (!has_flag(m_mode, OpenMode::Write)) {
// NOTE: Same deal as Read.
return EBADF;
}
ssize_t rc = ::write(m_fd, buffer.data(), buffer.size());
if (rc < 0) {
return Error::from_errno(errno);
}
return rc;
}
bool File::is_eof() const { return m_last_read_was_eof; }
bool File::is_open() const { return m_fd >= 0; }
void File::close()
{
if (!is_open()) {
return;
}
// NOTE: The closing of the file can be interrupted by a signal, in which
// case EINTR will be returned by the close syscall. So let's try closing
// the file until we aren't interrupted by rude signals. :^)
int rc;
do {
rc = ::close(m_fd);
} while (rc < 0 && errno == EINTR);
VERIFY(rc == 0);
m_fd = -1;
}
ErrorOr<off_t> File::seek(i64 offset, SeekMode mode)
{
int syscall_mode;
switch (mode) {
case SeekMode::SetPosition:
syscall_mode = SEEK_SET;
break;
case SeekMode::FromCurrentPosition:
syscall_mode = SEEK_CUR;
break;
case SeekMode::FromEndPosition:
syscall_mode = SEEK_END;
break;
default:
VERIFY_NOT_REACHED();
}
off_t rc = lseek(m_fd, offset, syscall_mode);
if (rc < 0) {
return Error::from_errno(errno);
}
m_last_read_was_eof = false;
return rc;
}
ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
{
int socket_domain;
switch (domain) {
case SocketDomain::Inet:
socket_domain = AF_INET;
break;
case SocketDomain::Local:
socket_domain = AF_LOCAL;
break;
default:
VERIFY_NOT_REACHED();
}
int socket_type;
switch (type) {
case SocketType::Stream:
socket_type = SOCK_STREAM;
break;
case SocketType::Datagram:
socket_type = SOCK_DGRAM;
break;
default:
VERIFY_NOT_REACHED();
}
int rc = ::socket(socket_domain, socket_type, 0);
if (rc < 0) {
return Error::from_errno(errno);
}
return rc;
}
Result<IPv4Address, SocketError> Socket::resolve_host(String const& host, SocketType type)
{
int socket_type;
switch (type) {
case SocketType::Stream:
socket_type = SOCK_STREAM;
break;
case SocketType::Datagram:
socket_type = SOCK_DGRAM;
break;
default:
VERIFY_NOT_REACHED();
}
struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = socket_type;
hints.ai_flags = 0;
hints.ai_protocol = 0;
struct addrinfo* results = nullptr;
int rc = getaddrinfo(host.characters(), nullptr, &hints, &results);
if (rc != 0) {
if (rc == EAI_SYSTEM) {
return SocketError { Error::from_errno(errno) };
} else {
return SocketError { static_cast<GetAddrInfoError>(rc) };
}
}
auto socket_address = bit_cast<struct sockaddr_in*>(results->ai_addr);
NetworkOrdered<u32> network_ordered_address { socket_address->sin_addr.s_addr };
freeaddrinfo(results);
return IPv4Address { network_ordered_address };
}
ErrorOr<void> Socket::connect_local(int fd, String const& path)
{
auto address = SocketAddress::local(path);
auto maybe_sockaddr = address.to_sockaddr_un();
if (!maybe_sockaddr.has_value()) {
dbgln("Core::Stream::Socket::connect_local: Could not obtain a sockaddr_un");
return Error::from_errno(EINVAL);
}
auto addr = maybe_sockaddr.release_value();
int rc = ::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
if (rc < 0) {
return Error::from_errno(errno);
}
return {};
}
ErrorOr<void> Socket::connect_inet(int fd, SocketAddress const& address)
{
auto addr = address.to_sockaddr_in();
int rc = ::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
if (rc < 0) {
return Error::from_errno(errno);
}
return {};
}
ErrorOr<size_t> PosixSocketHelper::read(Bytes buffer)
{
if (!is_open()) {
return ENOTCONN;
}
ssize_t rc = ::recv(m_fd, buffer.data(), buffer.size(), 0);
if (rc < 0) {
return Error::from_errno(errno);
}
m_last_read_was_eof = rc == 0;
// If a socket read is EOF, then no more data can be read from it because
// the protocol has disconnected. In this case, we can just disable the
// notifier if we have one.
if (m_last_read_was_eof && m_notifier)
m_notifier->set_enabled(false);
return rc;
}
ErrorOr<size_t> PosixSocketHelper::write(ReadonlyBytes buffer)
{
if (!is_open()) {
return ENOTCONN;
}
ssize_t rc = ::send(m_fd, buffer.data(), buffer.size(), 0);
if (rc < 0) {
return Error::from_errno(errno);
}
return rc;
}
void PosixSocketHelper::close()
{
if (!is_open()) {
return;
}
if (m_notifier)
m_notifier->set_enabled(false);
int rc;
do {
rc = ::close(m_fd);
} while (rc < 0 && errno == EINTR);
VERIFY(rc == 0);
m_fd = -1;
}
ErrorOr<bool> PosixSocketHelper::can_read_without_blocking(int timeout) const
{
struct pollfd the_fd = { .fd = m_fd, .events = POLLIN, .revents = 0 };
int rc;
do {
rc = ::poll(&the_fd, 1, timeout);
} while (rc < 0 && errno == EINTR);
if (rc < 0) {
return Error::from_errno(errno);
}
return (the_fd.revents & POLLIN) > 0;
}
ErrorOr<void> PosixSocketHelper::set_blocking(bool enabled)
{
int value = enabled ? 0 : 1;
int rc = ::ioctl(m_fd, FIONBIO, &value);
if (rc < 0) {
return Error::from_errno(errno);
}
return {};
}
ErrorOr<void> PosixSocketHelper::set_close_on_exec(bool enabled)
{
int flags = ::fcntl(m_fd, F_GETFD);
if (flags < 0)
return Error::from_errno(errno);
if (enabled)
flags |= FD_CLOEXEC;
else
flags &= ~FD_CLOEXEC;
int rc = ::fcntl(m_fd, F_SETFD, flags);
if (rc < 0)
return Error::from_errno(errno);
return {};
}
void PosixSocketHelper::setup_notifier()
{
if (!m_notifier)
m_notifier = Core::Notifier::construct(m_fd, Core::Notifier::Read);
}
Result<TCPSocket, SocketError> TCPSocket::connect(String const& host, u16 port)
{
auto ip_address = TRY(resolve_host(host, SocketType::Stream));
auto maybe_socket = connect(SocketAddress { ip_address, port });
if (maybe_socket.is_error()) {
return SocketError { maybe_socket.release_error() };
}
return maybe_socket.release_value();
}
ErrorOr<TCPSocket> TCPSocket::connect(SocketAddress const& address)
{
TCPSocket socket;
auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Stream));
socket.m_helper.set_fd(fd);
auto result = connect_inet(fd, address);
if (result.is_error()) {
::close(fd);
return result.release_error();
}
socket.setup_notifier();
return socket;
}
ErrorOr<TCPSocket> TCPSocket::adopt_fd(int fd)
{
if (fd < 0) {
return Error::from_errno(EBADF);
}
TCPSocket socket;
socket.m_helper.set_fd(fd);
socket.setup_notifier();
return socket;
}
ErrorOr<size_t> PosixSocketHelper::pending_bytes() const
{
if (!is_open()) {
return ENOTCONN;
}
int value;
int rc = ::ioctl(m_fd, FIONREAD, &value);
if (rc < 0) {
return Error::from_errno(errno);
}
return static_cast<size_t>(value);
}
Result<UDPSocket, SocketError> UDPSocket::connect(String const& host, u16 port)
{
auto ip_address = TRY(resolve_host(host, SocketType::Datagram));
auto maybe_socket = connect(SocketAddress { ip_address, port });
if (maybe_socket.is_error()) {
return SocketError { maybe_socket.release_error() };
}
return maybe_socket.release_value();
}
ErrorOr<UDPSocket> UDPSocket::connect(SocketAddress const& address)
{
UDPSocket socket;
auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Datagram));
socket.m_helper.set_fd(fd);
auto result = connect_inet(fd, address);
if (result.is_error()) {
::close(fd);
return result.release_error();
}
socket.setup_notifier();
return socket;
}
ErrorOr<LocalSocket> LocalSocket::connect(String const& path)
{
LocalSocket socket;
auto fd = TRY(create_fd(SocketDomain::Local, SocketType::Stream));
socket.m_helper.set_fd(fd);
auto result = connect_local(fd, path);
if (result.is_error()) {
::close(fd);
return result.release_error();
}
socket.setup_notifier();
return socket;
}
}

View file

@ -0,0 +1,914 @@
/*
* Copyright (c) 2021, sin-ack <sin-ack@protonmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/EnumBits.h>
#include <AK/Function.h>
#include <AK/IPv4Address.h>
#include <AK/Noncopyable.h>
#include <AK/Result.h>
#include <AK/Span.h>
#include <AK/String.h>
#include <AK/Variant.h>
#include <LibCore/Notifier.h>
#include <LibCore/SocketAddress.h>
#include <errno.h>
#include <netdb.h>
namespace Core::Stream {
/// The base, abstract class for stream operations. This class defines the
/// operations one can perform on every stream in LibCore.
class Stream {
public:
virtual bool is_readable() const { return false; }
/// Reads into a buffer, with the maximum size being the size of the buffer.
/// The amount of bytes read can be smaller than the size of the buffer.
/// Returns either the amount of bytes read, or an errno in the case of
/// failure.
virtual ErrorOr<size_t> read(Bytes) = 0;
/// Tries to fill the entire buffer through reading. Returns whether the
/// buffer was filled without an error.
virtual bool read_or_error(Bytes);
virtual bool is_writable() const { return false; }
/// Tries to write the entire contents of the buffer. It is possible for
/// less than the full buffer to be written. Returns either the amount of
/// bytes written into the stream, or an errno in the case of failure.
virtual ErrorOr<size_t> write(ReadonlyBytes) = 0;
/// Same as write, but does not return until either the entire buffer
/// contents are written or an error occurs. Returns whether the entire
/// contents were written without an error.
virtual bool write_or_error(ReadonlyBytes);
/// Returns whether the stream has reached the end of file. For sockets,
/// this most likely means that the protocol has disconnected (in the case
/// of TCP). For seekable streams, this means the end of the file. Note that
/// is_eof will only return true _after_ a read with 0 length, so this
/// method should be called after a read.
virtual bool is_eof() const = 0;
virtual bool is_open() const = 0;
virtual void close() = 0;
virtual ~Stream()
{
}
};
enum class SeekMode {
SetPosition,
FromCurrentPosition,
FromEndPosition,
};
/// Adds seekability to Core::Stream. Classes inheriting from SeekableStream
/// will be seekable to any point in the stream.
class SeekableStream : public Stream {
public:
/// Seeks to the given position in the given mode. Returns either the
/// current position of the file, or an errno in the case of an error.
virtual ErrorOr<off_t> seek(i64 offset, SeekMode) = 0;
/// Returns the current position of the file, or an errno in the case of
/// an error.
virtual ErrorOr<off_t> tell() const;
/// Returns the total size of the stream, or an errno in the case of an
/// error. May not preserve the original position on the stream on failure.
virtual ErrorOr<off_t> size();
};
enum class GetAddrInfoError {
NoAddressInFamily = EAI_ADDRFAMILY,
TemporaryFailure = EAI_AGAIN,
PermanentFailure = EAI_FAIL,
BadFlags = EAI_BADFLAGS,
UnsupportedFamily = EAI_FAMILY,
OutOfMemory = EAI_MEMORY,
NoNetworkAddressesForHost = EAI_NODATA,
UnknownService = EAI_NONAME,
ServiceNotAvailable = EAI_SERVICE,
UnsupportedSocketType = EAI_SOCKTYPE,
System = EAI_SYSTEM,
};
class SocketError {
public:
SocketError(GetAddrInfoError error)
: m_value(error)
{
}
SocketError(ErrorOr<void> const& error)
: m_value(error)
{
}
// TRY() compatibility
SocketError release_error() { return *this; }
void release_value() { }
bool is_error() const
{
return m_value.has<GetAddrInfoError>() || m_value.get<ErrorOr<void>>().is_error();
}
bool is_success() const { return !is_error(); }
bool is_kresult() { return m_value.has<ErrorOr<void>>(); }
bool is_getaddrinfo_error() { return m_value.has<GetAddrInfoError>(); }
ErrorOr<void> as_kresult() { return m_value.get<ErrorOr<void>>(); }
GetAddrInfoError as_getaddrinfo_error()
{
return m_value.get<GetAddrInfoError>();
}
StringView getaddrinfo_error_string()
{
VERIFY(is_getaddrinfo_error());
return { gai_strerror(static_cast<int>(as_getaddrinfo_error())) };
}
private:
Variant<ErrorOr<void>, GetAddrInfoError> m_value;
};
/// The Socket class is the base class for all concrete BSD-style socket
/// classes. Sockets are non-seekable streams which can be read byte-wise.
class Socket : public Stream {
public:
Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default;
/// Checks how many bytes of data are currently available to read on the
/// socket. For datagram-based socket, this is the size of the first
/// datagram that can be read. Returns either the amount of bytes, or an
/// errno in the case of failure.
virtual ErrorOr<size_t> pending_bytes() const = 0;
/// Returns whether there's any data that can be immediately read, or an
/// errno on failure.
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const = 0;
// Sets the blocking mode of the socket. If blocking mode is disabled, reads
// will fail with EAGAIN when there's no data available to read, and writes
// will fail with EAGAIN when the data cannot be written without blocking
// (due to the send buffer being full, for example).
virtual ErrorOr<void> set_blocking(bool enabled) = 0;
// Sets the close-on-exec mode of the socket. If close-on-exec mode is
// enabled, then the socket will be automatically closed by the kernel when
// an exec call happens.
virtual ErrorOr<void> set_close_on_exec(bool enabled) = 0;
Function<void()> on_ready_to_read;
protected:
enum class SocketDomain {
Local,
Inet,
};
enum class SocketType {
Stream,
Datagram,
};
Socket()
{
}
static ErrorOr<int> create_fd(SocketDomain, SocketType);
// FIXME: This will need to be updated when IPv6 socket arrives. Perhaps a
// base class for all address types is appropriate.
static Result<IPv4Address, SocketError> resolve_host(String const&, SocketType);
static ErrorOr<void> connect_local(int fd, String const& path);
static ErrorOr<void> connect_inet(int fd, SocketAddress const&);
};
/// A reusable socket maintains state about being connected in addition to
/// normal Socket capabilities, and can be reconnected once disconnected.
class ReusableSocket : public Socket {
public:
/// Returns whether the socket is currently connected.
virtual bool is_connected() = 0;
/// Reconnects the socket to the given host and port. Returns EALREADY if
/// is_connected() is true.
virtual SocketError reconnect(String const& host, u16 port) = 0;
/// Connects the socket to the given socket address (IP address + port).
/// Returns EALREADY is_connected() is true.
virtual ErrorOr<void> reconnect(SocketAddress const&) = 0;
};
// Concrete classes.
enum class OpenMode : unsigned {
NotOpen = 0,
Read = 1,
Write = 2,
ReadWrite = 3,
Append = 4,
Truncate = 8,
MustBeNew = 16,
KeepOnExec = 32,
Nonblocking = 64,
};
AK_ENUM_BITWISE_OPERATORS(OpenMode)
class File final : public SeekableStream {
AK_MAKE_NONCOPYABLE(File);
public:
static ErrorOr<File> open(StringView const& filename, OpenMode, mode_t = 0644);
static ErrorOr<File> adopt_fd(int fd, OpenMode);
File(File&& other) { operator=(move(other)); }
File& operator=(File&& other)
{
if (&other == this)
return *this;
m_mode = exchange(other.m_mode, OpenMode::NotOpen);
m_fd = exchange(other.m_fd, -1);
m_last_read_was_eof = exchange(other.m_last_read_was_eof, false);
return *this;
}
virtual bool is_readable() const override;
virtual ErrorOr<size_t> read(Bytes) override;
virtual bool is_writable() const override;
virtual ErrorOr<size_t> write(ReadonlyBytes) override;
virtual bool is_eof() const override;
virtual bool is_open() const override;
virtual void close() override;
virtual ErrorOr<off_t> seek(i64 offset, SeekMode) override;
virtual ~File() override { close(); }
private:
File(OpenMode mode)
: m_mode(mode)
{
}
ErrorOr<void> open_path(StringView const& filename, mode_t);
OpenMode m_mode { OpenMode::NotOpen };
int m_fd { -1 };
bool m_last_read_was_eof { false };
};
class PosixSocketHelper {
AK_MAKE_NONCOPYABLE(PosixSocketHelper);
public:
template<typename T>
PosixSocketHelper(Badge<T>) requires(IsBaseOf<Socket, T>) { }
PosixSocketHelper(PosixSocketHelper&& other)
{
operator=(move(other));
}
PosixSocketHelper& operator=(PosixSocketHelper&& other)
{
m_fd = exchange(other.m_fd, -1);
m_last_read_was_eof = exchange(other.m_last_read_was_eof, false);
m_notifier = move(other.m_notifier);
return *this;
}
int fd() const { return m_fd; }
void set_fd(int fd) { m_fd = fd; }
ErrorOr<size_t> read(Bytes);
ErrorOr<size_t> write(ReadonlyBytes);
bool is_eof() const { return !is_open() || m_last_read_was_eof; }
bool is_open() const { return m_fd != -1; }
void close();
ErrorOr<size_t> pending_bytes() const;
ErrorOr<bool> can_read_without_blocking(int timeout) const;
ErrorOr<void> set_blocking(bool enabled);
ErrorOr<void> set_close_on_exec(bool enabled);
void setup_notifier();
RefPtr<Core::Notifier> notifier() { return m_notifier; }
private:
int m_fd { -1 };
bool m_last_read_was_eof { false };
RefPtr<Core::Notifier> m_notifier;
};
class TCPSocket final : public Socket {
public:
static Result<TCPSocket, SocketError> connect(String const& host, u16 port);
static ErrorOr<TCPSocket> connect(SocketAddress const& address);
static ErrorOr<TCPSocket> adopt_fd(int fd);
TCPSocket(TCPSocket&& other)
: Socket(static_cast<Socket&&>(other))
, m_helper(move(other.m_helper))
{
if (is_open())
setup_notifier();
}
TCPSocket& operator=(TCPSocket&& other)
{
Socket::operator=(static_cast<Socket&&>(other));
m_helper = move(other.m_helper);
if (is_open())
setup_notifier();
return *this;
}
virtual bool is_readable() const override { return is_open(); }
virtual bool is_writable() const override { return is_open(); }
virtual ErrorOr<size_t> read(Bytes buffer) override { return m_helper.read(buffer); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
virtual bool is_eof() const override { return m_helper.is_eof(); }
virtual bool is_open() const override { return m_helper.is_open(); };
virtual void close() override { m_helper.close(); };
virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
virtual ~TCPSocket() override { close(); }
private:
TCPSocket()
{
}
void setup_notifier()
{
VERIFY(is_open());
m_helper.setup_notifier();
m_helper.notifier()->on_ready_to_read = [this] {
if (on_ready_to_read)
on_ready_to_read();
};
}
PosixSocketHelper m_helper { Badge<TCPSocket> {} };
};
class UDPSocket final : public Socket {
public:
static Result<UDPSocket, SocketError> connect(String const& host, u16 port);
static ErrorOr<UDPSocket> connect(SocketAddress const& address);
UDPSocket(UDPSocket&& other)
: Socket(static_cast<Socket&&>(other))
, m_helper(move(other.m_helper))
{
if (is_open())
setup_notifier();
}
UDPSocket& operator=(UDPSocket&& other)
{
Socket::operator=(static_cast<Socket&&>(other));
m_helper = move(other.m_helper);
if (is_open())
setup_notifier();
return *this;
}
virtual ErrorOr<size_t> read(Bytes buffer) override
{
auto pending_bytes = TRY(this->pending_bytes());
if (pending_bytes > buffer.size()) {
// With UDP datagrams, reading a datagram into a buffer that's
// smaller than the datagram's size will cause the rest of the
// datagram to be discarded. That's not very nice, so let's bail
// early, telling the caller that he should allocate a bigger
// buffer.
return EMSGSIZE;
}
return m_helper.read(buffer);
}
virtual bool is_readable() const override { return is_open(); }
virtual bool is_writable() const override { return is_open(); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
virtual bool is_eof() const override { return m_helper.is_eof(); }
virtual bool is_open() const override { return m_helper.is_open(); }
virtual void close() override { m_helper.close(); }
virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
virtual ~UDPSocket() override { close(); }
private:
UDPSocket() { }
void setup_notifier()
{
VERIFY(is_open());
m_helper.setup_notifier();
m_helper.notifier()->on_ready_to_read = [this] {
if (on_ready_to_read)
on_ready_to_read();
};
}
PosixSocketHelper m_helper { Badge<UDPSocket> {} };
};
class LocalSocket final : public Socket {
public:
static ErrorOr<LocalSocket> connect(String const& path);
LocalSocket(LocalSocket&& other)
: Socket(static_cast<Socket&&>(other))
, m_helper(move(other.m_helper))
{
if (is_open())
setup_notifier();
}
LocalSocket& operator=(LocalSocket&& other)
{
Socket::operator=(static_cast<Socket&&>(other));
m_helper = move(other.m_helper);
if (is_open())
setup_notifier();
return *this;
}
virtual bool is_readable() const override { return is_open(); }
virtual bool is_writable() const override { return is_open(); }
virtual ErrorOr<size_t> read(Bytes buffer) override { return m_helper.read(buffer); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
virtual bool is_eof() const override { return m_helper.is_eof(); }
virtual bool is_open() const override { return m_helper.is_open(); }
virtual void close() override { m_helper.close(); }
virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
virtual ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
virtual ~LocalSocket() { close(); }
private:
LocalSocket() { }
void setup_notifier()
{
VERIFY(is_open());
m_helper.setup_notifier();
m_helper.notifier()->on_ready_to_read = [this] {
if (on_ready_to_read)
on_ready_to_read();
};
}
PosixSocketHelper m_helper { Badge<LocalSocket> {} };
};
// Buffered stream wrappers
template<typename T>
concept StreamLike = IsBaseOf<Stream, T>;
template<typename T>
concept SeekableStreamLike = IsBaseOf<SeekableStream, T>;
template<typename T>
concept SocketLike = IsBaseOf<Socket, T>;
template<typename T>
class BufferedHelper {
AK_MAKE_NONCOPYABLE(BufferedHelper);
public:
template<StreamLike U>
BufferedHelper(Badge<U>, T stream, ByteBuffer buffer)
: m_stream(move(stream))
, m_buffer(move(buffer))
{
}
BufferedHelper(BufferedHelper&& other)
: m_stream(move(other.m_stream))
, m_buffer(move(other.m_buffer))
, m_buffered_size(exchange(other.m_buffered_size, 0))
{
}
BufferedHelper& operator=(BufferedHelper&& other)
{
m_stream = move(other.m_stream);
m_buffer = move(other.m_buffer);
m_buffered_size = exchange(other.m_buffered_size, 0);
return *this;
}
template<template<typename> typename BufferedType>
static ErrorOr<BufferedType<T>> create_buffered(T&& stream, size_t buffer_size)
{
if (!buffer_size)
return EINVAL;
if (!stream.is_open())
return ENOTCONN;
auto maybe_buffer = ByteBuffer::create_uninitialized(buffer_size);
if (!maybe_buffer.has_value())
return ENOMEM;
return BufferedType<T> { move(stream), maybe_buffer.release_value() };
}
T& stream() { return m_stream; }
T const& stream() const { return m_stream; }
ErrorOr<size_t> read(Bytes buffer)
{
if (!stream().is_open())
return ENOTCONN;
if (!buffer.size())
return ENOBUFS;
// Let's try to take all we can from the buffer first.
size_t buffer_nread = 0;
if (m_buffered_size > 0) {
// FIXME: Use a circular buffer to avoid shifting the buffer
// contents.
size_t amount_to_take = min(buffer.size(), m_buffered_size);
auto slice_to_take = m_buffer.span().slice(0, amount_to_take);
auto slice_to_shift = m_buffer.span().slice(amount_to_take);
slice_to_take.copy_to(buffer);
buffer_nread += amount_to_take;
if (amount_to_take < m_buffered_size) {
m_buffer.overwrite(0, slice_to_shift.data(), m_buffered_size - amount_to_take);
}
m_buffered_size -= amount_to_take;
}
// If the buffer satisfied the request, then we need not continue.
if (buffer_nread == buffer.size()) {
return buffer_nread;
}
// Otherwise, let's try an extra read just in case there's something
// in our receive buffer.
auto stream_nread = TRY(m_stream.read(buffer.slice(buffer_nread)));
return buffer_nread + stream_nread;
}
// Reads into the buffer until \n is encountered.
// The size of the Bytes object is the maximum amount of bytes that will be
// read. Returns the amount of bytes read.
ErrorOr<size_t> read_line(Bytes buffer)
{
return read_until(buffer, "\n"sv);
}
ErrorOr<size_t> read_until(Bytes buffer, StringView const& candidate)
{
return read_until_any_of(buffer, Array { candidate });
}
template<size_t N>
ErrorOr<size_t> read_until_any_of(Bytes buffer, Array<StringView, N> candidates)
{
if (!stream().is_open())
return ENOTCONN;
if (buffer.is_empty())
return ENOBUFS;
// We fill the buffer through can_read_line.
if (!TRY(can_read_line()))
return 0;
if (stream().is_eof()) {
if (buffer.size() < m_buffered_size) {
// Normally, reading from an EOFed stream and receiving bytes
// would mean that the stream is no longer EOF. However, it's
// possible with a buffered stream that the user is able to read
// the buffer contents even when the underlying stream is EOF.
// We already violate this invariant once by giving the user the
// chance to read the remaining buffer contents, but if the user
// doesn't give us a big enough buffer, then we would be
// violating the invariant twice the next time the user attempts
// to read, which is No Good. So let's give a descriptive error
// to the caller about why it can't read.
return EMSGSIZE;
}
m_buffer.span().slice(0, m_buffered_size).copy_to(buffer);
return exchange(m_buffered_size, 0);
}
size_t longest_match = 0;
size_t maximum_offset = min(m_buffered_size, buffer.size());
for (size_t offset = 0; offset < maximum_offset; offset++) {
// The intention here is to try to match all of the possible
// delimiter candidates and try to find the longest one we can
// remove from the buffer after copying up to the delimiter to the
// user buffer.
StringView remaining_buffer { m_buffer.span().offset(offset), maximum_offset - offset };
for (auto candidate : candidates) {
if (candidate.length() > offset)
continue;
if (remaining_buffer.starts_with(candidate))
longest_match = max(longest_match, candidate.length());
}
if (longest_match > 0) {
auto buffer_to_take = m_buffer.span().slice(0, offset);
auto buffer_to_shift = m_buffer.span().slice(offset + longest_match);
buffer_to_take.copy_to(buffer);
m_buffer.overwrite(0, buffer_to_shift.data(), buffer_to_shift.size());
return offset;
}
}
// If we still haven't found anything, then it's most likely the case
// that the delimiter ends beyond the length of the caller-passed
// buffer. Let's just fill the caller's buffer up.
auto readable_size = min(m_buffered_size, buffer.size());
auto buffer_to_take = m_buffer.span().slice(0, readable_size);
auto buffer_to_shift = m_buffer.span().slice(readable_size);
buffer_to_take.copy_to(buffer);
m_buffer.overwrite(0, buffer_to_shift.data(), buffer_to_shift.size());
return readable_size;
}
// Returns whether a line can be read, populating the buffer in the process.
ErrorOr<bool> can_read_line()
{
if (stream().is_eof() && m_buffered_size > 0)
return true;
if (m_buffer.span().slice(0, m_buffered_size).contains_slow('\n'))
return true;
if (!stream().is_readable())
return false;
while (m_buffered_size < m_buffer.size()) {
auto populated_slice = TRY(populate_read_buffer());
if (stream().is_eof()) {
// We give the user one last hurrah to read the remaining
// contents as a "line".
return m_buffered_size > 0;
}
if (populated_slice.contains_slow('\n'))
return true;
}
return false;
}
bool is_eof() const
{
if (m_buffered_size > 0) {
return false;
}
return stream().is_eof();
}
size_t buffer_size() const
{
return m_buffer.size();
}
void clear_buffer()
{
m_buffered_size = 0;
}
private:
ErrorOr<ReadonlyBytes> populate_read_buffer()
{
if (m_buffered_size == m_buffer.size())
return ReadonlyBytes {};
auto fillable_slice = m_buffer.span().slice(m_buffered_size);
auto nread = TRY(m_stream.read(fillable_slice));
m_buffered_size += nread;
return fillable_slice.slice(0, nread);
}
T m_stream;
// FIXME: Replacing this with a circular buffer would be really nice and
// would avoid excessive copies; however, right now
// AK::CircularDuplexBuffer inlines its entire contents, and that
// would make for a very large object on the stack.
//
// The proper fix is to make a CircularQueue which uses a buffer on
// the heap.
ByteBuffer m_buffer;
size_t m_buffered_size { 0 };
};
// NOTE: A Buffered which accepts any Stream could be added here, but it is not
// needed at the moment.
template<SeekableStreamLike T>
class BufferedSeekable final : public SeekableStream {
friend BufferedHelper<T>;
public:
static ErrorOr<BufferedSeekable<T>> create(T&& stream, size_t buffer_size = 16384)
{
return BufferedHelper<T>::template create_buffered<BufferedSeekable>(move(stream), buffer_size);
}
BufferedSeekable(BufferedSeekable&& other) = default;
BufferedSeekable& operator=(BufferedSeekable&& other) = default;
virtual bool is_readable() const override { return m_helper.stream().is_readable(); }
virtual ErrorOr<size_t> read(Bytes buffer) override { return m_helper.read(move(buffer)); }
virtual bool is_writable() const override { return m_helper.stream().is_writable(); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.stream().write(buffer); }
virtual bool is_eof() const override { return m_helper.is_eof(); }
virtual bool is_open() const override { return m_helper.stream().is_open(); }
virtual void close() override { m_helper.stream().close(); }
virtual ErrorOr<off_t> seek(i64 offset, SeekMode mode) override
{
auto result = TRY(m_helper.stream().seek(offset, mode));
m_helper.clear_buffer();
return result;
}
ErrorOr<size_t> read_line(Bytes buffer) { return m_helper.read_line(move(buffer)); }
ErrorOr<size_t> read_until(Bytes buffer, StringView const& candidate) { return m_helper.read_until(move(buffer), move(candidate)); }
template<size_t N>
ErrorOr<size_t> read_until_any_of(Bytes buffer, Array<StringView, N> candidates) { return m_helper.read_until_any_of(move(buffer), move(candidates)); }
ErrorOr<bool> can_read_line() { return m_helper.can_read_line(); }
size_t buffer_size() const { return m_helper.buffer_size(); }
virtual ~BufferedSeekable() override { }
private:
BufferedSeekable(T stream, ByteBuffer buffer)
: m_helper(Badge<BufferedSeekable<T>> {}, move(stream), buffer)
{
}
BufferedHelper<T> m_helper;
};
template<SocketLike T>
class BufferedSocket final : public Socket {
friend BufferedHelper<T>;
public:
static ErrorOr<BufferedSocket<T>> create(T&& stream, size_t buffer_size = 16384)
{
return BufferedHelper<T>::template create_buffered<BufferedSocket>(move(stream), buffer_size);
}
BufferedSocket(BufferedSocket&& other)
: Socket(static_cast<Socket&&>(other))
, m_helper(move(other.m_helper))
{
setup_notifier();
}
BufferedSocket& operator=(BufferedSocket&& other)
{
Socket::operator=(static_cast<Socket&&>(other));
m_helper = move(other.m_helper);
setup_notifier();
return *this;
}
virtual bool is_readable() const override { return m_helper.stream().is_readable(); }
virtual ErrorOr<size_t> read(Bytes buffer) override { return m_helper.read(move(buffer)); }
virtual bool is_writable() const override { return m_helper.stream().is_writable(); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.stream().write(buffer); }
virtual bool is_eof() const override { return m_helper.is_eof(); }
virtual bool is_open() const override { return m_helper.stream().is_open(); }
virtual void close() override { m_helper.stream().close(); }
virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.stream().pending_bytes(); }
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.stream().can_read_without_blocking(timeout); }
virtual ErrorOr<void> set_blocking(bool enabled) override { return m_helper.stream().set_blocking(enabled); }
virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.stream().set_close_on_exec(enabled); }
ErrorOr<size_t> read_line(Bytes buffer) { return m_helper.read_line(move(buffer)); }
ErrorOr<size_t> read_until(Bytes buffer, StringView const& candidate) { return m_helper.read_until(move(buffer), move(candidate)); }
template<size_t N>
ErrorOr<size_t> read_until_any_of(Bytes buffer, Array<StringView, N> candidates) { return m_helper.read_until_any_of(move(buffer), move(candidates)); }
ErrorOr<bool> can_read_line() { return m_helper.can_read_line(); }
size_t buffer_size() const { return m_helper.buffer_size(); }
virtual ~BufferedSocket() override { }
private:
BufferedSocket(T stream, ByteBuffer buffer)
: m_helper(Badge<BufferedSocket<T>> {}, move(stream), buffer)
{
setup_notifier();
}
void setup_notifier()
{
m_helper.stream().on_ready_to_read = [this] {
if (on_ready_to_read)
on_ready_to_read();
};
}
BufferedHelper<T> m_helper;
};
using BufferedFile = BufferedSeekable<File>;
using BufferedTCPSocket = BufferedSocket<TCPSocket>;
using BufferedUDPSocket = BufferedSocket<UDPSocket>;
using BufferedLocalSocket = BufferedSocket<LocalSocket>;
/// A BasicReusableSocket allows one to use one of the base Core::Stream classes
/// as a ReusableSocket. It does not preserve any connection state or options,
/// and instead just recreates the stream when reconnecting.
template<SocketLike T>
class BasicReusableSocket final : public ReusableSocket {
public:
static Result<BasicReusableSocket<T>, SocketError> connect(String const& host, u16 port)
{
return BasicReusableSocket { TRY(T::connect(host, port)) };
}
static ErrorOr<BasicReusableSocket<T>> connect(SocketAddress const& address)
{
return BasicReusableSocket { TRY(T::connect(address)) };
}
virtual bool is_connected() override
{
return m_socket.is_open();
}
virtual SocketError reconnect(String const& host, u16 port) override
{
if (is_connected())
return SocketError { Error::from_errno(EALREADY) };
m_socket = TRY(T::connect(host, port));
return SocketError { {} };
}
virtual ErrorOr<void> reconnect(SocketAddress const& address) override
{
if (is_connected())
return Error::from_errno(EALREADY);
m_socket = TRY(T::connect(address));
return {};
}
virtual bool is_readable() const override { return m_socket.is_readable(); }
virtual ErrorOr<size_t> read(Bytes buffer) override { return m_socket.read(move(buffer)); }
virtual bool is_writable() const override { return m_socket.is_writable(); }
virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_socket.write(buffer); }
virtual bool is_eof() const override { return m_socket.is_eof(); }
virtual bool is_open() const override { return m_socket.is_open(); }
virtual void close() override { m_socket.close(); }
virtual ErrorOr<size_t> pending_bytes() const override { return m_socket.pending_bytes(); }
virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_socket.can_read_without_blocking(timeout); }
virtual ErrorOr<void> set_blocking(bool enabled) override { return m_socket.set_blocking(enabled); }
virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_socket.set_close_on_exec(enabled); }
private:
BasicReusableSocket(T&& socket)
: m_socket(move(socket))
{
}
T m_socket;
};
using ReusableTCPSocket = BasicReusableSocket<TCPSocket>;
using ReusableUDPSocket = BasicReusableSocket<UDPSocket>;
}