diff --git a/Userland/Libraries/LibCore/CMakeLists.txt b/Userland/Libraries/LibCore/CMakeLists.txt index d3e415e517..db069de046 100644 --- a/Userland/Libraries/LibCore/CMakeLists.txt +++ b/Userland/Libraries/LibCore/CMakeLists.txt @@ -27,6 +27,7 @@ set(SOURCES Property.cpp SecretString.cpp Socket.cpp + Stream.cpp StandardPaths.cpp System.cpp TCPServer.cpp diff --git a/Userland/Libraries/LibCore/Stream.cpp b/Userland/Libraries/LibCore/Stream.cpp new file mode 100644 index 0000000000..c740f4ac0f --- /dev/null +++ b/Userland/Libraries/LibCore/Stream.cpp @@ -0,0 +1,547 @@ +/* + * Copyright (c) 2018-2021, Andreas Kling + * Copyright (c) 2021, sin-ack + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include "Stream.h" +#include +#include +#include +#include +#include +#include +#include +#ifdef __serenity__ +# include +#endif + +namespace Core::Stream { + +bool Stream::read_or_error(Bytes buffer) +{ + VERIFY(buffer.size()); + + size_t nread = 0; + do { + if (is_eof()) + return false; + + auto result = read(buffer.slice(nread)); + if (result.is_error()) { + if (result.error().is_errno() && result.error().code() == EINTR) { + continue; + } + + return false; + } + + nread += result.value(); + } while (nread < buffer.size()); + + return true; +} + +bool Stream::write_or_error(ReadonlyBytes buffer) +{ + VERIFY(buffer.size()); + + size_t nwritten = 0; + do { + auto result = write(buffer.slice(nwritten)); + if (result.is_error()) { + if (result.error().is_errno() && result.error().code() == EINTR) { + continue; + } + + return false; + } + + nwritten += result.value(); + } while (nwritten < buffer.size()); + + return true; +} + +ErrorOr SeekableStream::tell() const +{ + // Seek with 0 and SEEK_CUR does not modify anything despite the const_cast, + // so it's safe to do this. + return const_cast(this)->seek(0, SeekMode::FromCurrentPosition); +} + +ErrorOr SeekableStream::size() +{ + auto original_position = TRY(tell()); + + auto seek_result = seek(0, SeekMode::FromEndPosition); + if (seek_result.is_error()) { + // Let's try to restore the original position, just in case. + auto restore_result = seek(original_position, SeekMode::SetPosition); + if (restore_result.is_error()) { + dbgln("Core::SeekableStream::size: Couldn't restore initial position, stream might have incorrect position now!"); + } + + return seek_result.release_error(); + } + + TRY(seek(original_position, SeekMode::SetPosition)); + return seek_result.value(); +} + +ErrorOr File::open(StringView const& filename, OpenMode mode, mode_t permissions) +{ + File file { mode }; + TRY(file.open_path(filename, permissions)); + return file; +} + +ErrorOr File::adopt_fd(int fd, OpenMode mode) +{ + if (fd < 0) { + return Error::from_errno(EBADF); + } + + if (!has_any_flag(mode, OpenMode::ReadWrite)) { + dbgln("Core::File::adopt_fd: Attempting to adopt a file with neither Read nor Write specified in mode"); + return Error::from_errno(EINVAL); + } + + File file { mode }; + file.m_fd = fd; + return file; +} + +ErrorOr File::open_path(StringView const& filename, mode_t permissions) +{ + VERIFY(m_fd == -1); + + int flags = 0; + if (has_flag(m_mode, OpenMode::ReadWrite)) { + flags |= O_RDWR | O_CREAT; + } else if (has_flag(m_mode, OpenMode::Read)) { + flags |= O_RDONLY; + } else if (has_flag(m_mode, OpenMode::Write)) { + flags |= O_WRONLY | O_CREAT; + bool should_truncate = !has_any_flag(m_mode, OpenMode::Append | OpenMode::MustBeNew); + if (should_truncate) + flags |= O_TRUNC; + } + + if (has_flag(m_mode, OpenMode::Append)) + flags |= O_APPEND; + if (has_flag(m_mode, OpenMode::Truncate)) + flags |= O_TRUNC; + if (has_flag(m_mode, OpenMode::MustBeNew)) + flags |= O_EXCL; + if (!has_flag(m_mode, OpenMode::KeepOnExec)) + flags |= O_CLOEXEC; + if (!has_flag(m_mode, OpenMode::Nonblocking)) + flags |= O_NONBLOCK; + +#ifdef __serenity__ + int fd = ::serenity_open(filename.characters_without_null_termination(), filename.length(), flags, permissions); +#else + String filename_with_null_terminator { filename }; + int fd = ::open(filename_with_null_terminator.characters(), flags, permissions); +#endif + + if (fd < 0) { + return Error::from_errno(errno); + } + + m_fd = fd; + return {}; +} + +bool File::is_readable() const { return has_flag(m_mode, OpenMode::Read); } +bool File::is_writable() const { return has_flag(m_mode, OpenMode::Write); } + +ErrorOr File::read(Bytes buffer) +{ + if (!has_flag(m_mode, OpenMode::Read)) { + // NOTE: POSIX says that if the fd is not open for reading, the call + // will return EBADF. Since we already know whether we can or + // can't read the file, let's avoid a syscall. + return EBADF; + } + + ssize_t rc = ::read(m_fd, buffer.data(), buffer.size()); + if (rc < 0) { + return Error::from_errno(errno); + } + + m_last_read_was_eof = rc == 0; + return rc; +} + +ErrorOr File::write(ReadonlyBytes buffer) +{ + if (!has_flag(m_mode, OpenMode::Write)) { + // NOTE: Same deal as Read. + return EBADF; + } + + ssize_t rc = ::write(m_fd, buffer.data(), buffer.size()); + if (rc < 0) { + return Error::from_errno(errno); + } + + return rc; +} + +bool File::is_eof() const { return m_last_read_was_eof; } + +bool File::is_open() const { return m_fd >= 0; } + +void File::close() +{ + if (!is_open()) { + return; + } + + // NOTE: The closing of the file can be interrupted by a signal, in which + // case EINTR will be returned by the close syscall. So let's try closing + // the file until we aren't interrupted by rude signals. :^) + int rc; + do { + rc = ::close(m_fd); + } while (rc < 0 && errno == EINTR); + + VERIFY(rc == 0); + m_fd = -1; +} + +ErrorOr File::seek(i64 offset, SeekMode mode) +{ + int syscall_mode; + switch (mode) { + case SeekMode::SetPosition: + syscall_mode = SEEK_SET; + break; + case SeekMode::FromCurrentPosition: + syscall_mode = SEEK_CUR; + break; + case SeekMode::FromEndPosition: + syscall_mode = SEEK_END; + break; + default: + VERIFY_NOT_REACHED(); + } + + off_t rc = lseek(m_fd, offset, syscall_mode); + if (rc < 0) { + return Error::from_errno(errno); + } + + m_last_read_was_eof = false; + return rc; +} + +ErrorOr Socket::create_fd(SocketDomain domain, SocketType type) +{ + int socket_domain; + switch (domain) { + case SocketDomain::Inet: + socket_domain = AF_INET; + break; + case SocketDomain::Local: + socket_domain = AF_LOCAL; + break; + default: + VERIFY_NOT_REACHED(); + } + + int socket_type; + switch (type) { + case SocketType::Stream: + socket_type = SOCK_STREAM; + break; + case SocketType::Datagram: + socket_type = SOCK_DGRAM; + break; + default: + VERIFY_NOT_REACHED(); + } + + int rc = ::socket(socket_domain, socket_type, 0); + if (rc < 0) { + return Error::from_errno(errno); + } + + return rc; +} + +Result Socket::resolve_host(String const& host, SocketType type) +{ + int socket_type; + switch (type) { + case SocketType::Stream: + socket_type = SOCK_STREAM; + break; + case SocketType::Datagram: + socket_type = SOCK_DGRAM; + break; + default: + VERIFY_NOT_REACHED(); + } + + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = socket_type; + hints.ai_flags = 0; + hints.ai_protocol = 0; + + struct addrinfo* results = nullptr; + int rc = getaddrinfo(host.characters(), nullptr, &hints, &results); + if (rc != 0) { + if (rc == EAI_SYSTEM) { + return SocketError { Error::from_errno(errno) }; + } else { + return SocketError { static_cast(rc) }; + } + } + + auto socket_address = bit_cast(results->ai_addr); + NetworkOrdered network_ordered_address { socket_address->sin_addr.s_addr }; + + freeaddrinfo(results); + + return IPv4Address { network_ordered_address }; +} + +ErrorOr Socket::connect_local(int fd, String const& path) +{ + auto address = SocketAddress::local(path); + auto maybe_sockaddr = address.to_sockaddr_un(); + if (!maybe_sockaddr.has_value()) { + dbgln("Core::Stream::Socket::connect_local: Could not obtain a sockaddr_un"); + return Error::from_errno(EINVAL); + } + + auto addr = maybe_sockaddr.release_value(); + int rc = ::connect(fd, bit_cast(&addr), sizeof(addr)); + if (rc < 0) { + return Error::from_errno(errno); + } + + return {}; +} + +ErrorOr Socket::connect_inet(int fd, SocketAddress const& address) +{ + auto addr = address.to_sockaddr_in(); + int rc = ::connect(fd, bit_cast(&addr), sizeof(addr)); + if (rc < 0) { + return Error::from_errno(errno); + } + + return {}; +} + +ErrorOr PosixSocketHelper::read(Bytes buffer) +{ + if (!is_open()) { + return ENOTCONN; + } + + ssize_t rc = ::recv(m_fd, buffer.data(), buffer.size(), 0); + if (rc < 0) { + return Error::from_errno(errno); + } + + m_last_read_was_eof = rc == 0; + // If a socket read is EOF, then no more data can be read from it because + // the protocol has disconnected. In this case, we can just disable the + // notifier if we have one. + if (m_last_read_was_eof && m_notifier) + m_notifier->set_enabled(false); + + return rc; +} + +ErrorOr PosixSocketHelper::write(ReadonlyBytes buffer) +{ + if (!is_open()) { + return ENOTCONN; + } + + ssize_t rc = ::send(m_fd, buffer.data(), buffer.size(), 0); + if (rc < 0) { + return Error::from_errno(errno); + } + + return rc; +} + +void PosixSocketHelper::close() +{ + if (!is_open()) { + return; + } + + if (m_notifier) + m_notifier->set_enabled(false); + + int rc; + do { + rc = ::close(m_fd); + } while (rc < 0 && errno == EINTR); + + VERIFY(rc == 0); + m_fd = -1; +} + +ErrorOr PosixSocketHelper::can_read_without_blocking(int timeout) const +{ + struct pollfd the_fd = { .fd = m_fd, .events = POLLIN, .revents = 0 }; + + int rc; + do { + rc = ::poll(&the_fd, 1, timeout); + } while (rc < 0 && errno == EINTR); + + if (rc < 0) { + return Error::from_errno(errno); + } + + return (the_fd.revents & POLLIN) > 0; +} + +ErrorOr PosixSocketHelper::set_blocking(bool enabled) +{ + int value = enabled ? 0 : 1; + int rc = ::ioctl(m_fd, FIONBIO, &value); + + if (rc < 0) { + return Error::from_errno(errno); + } + + return {}; +} + +ErrorOr PosixSocketHelper::set_close_on_exec(bool enabled) +{ + int flags = ::fcntl(m_fd, F_GETFD); + if (flags < 0) + return Error::from_errno(errno); + + if (enabled) + flags |= FD_CLOEXEC; + else + flags &= ~FD_CLOEXEC; + + int rc = ::fcntl(m_fd, F_SETFD, flags); + if (rc < 0) + return Error::from_errno(errno); + + return {}; +} + +void PosixSocketHelper::setup_notifier() +{ + if (!m_notifier) + m_notifier = Core::Notifier::construct(m_fd, Core::Notifier::Read); +} + +Result TCPSocket::connect(String const& host, u16 port) +{ + auto ip_address = TRY(resolve_host(host, SocketType::Stream)); + + auto maybe_socket = connect(SocketAddress { ip_address, port }); + if (maybe_socket.is_error()) { + return SocketError { maybe_socket.release_error() }; + } + return maybe_socket.release_value(); +} + +ErrorOr TCPSocket::connect(SocketAddress const& address) +{ + TCPSocket socket; + + auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Stream)); + socket.m_helper.set_fd(fd); + + auto result = connect_inet(fd, address); + if (result.is_error()) { + ::close(fd); + return result.release_error(); + } + + socket.setup_notifier(); + return socket; +} + +ErrorOr TCPSocket::adopt_fd(int fd) +{ + if (fd < 0) { + return Error::from_errno(EBADF); + } + + TCPSocket socket; + socket.m_helper.set_fd(fd); + socket.setup_notifier(); + return socket; +} + +ErrorOr PosixSocketHelper::pending_bytes() const +{ + if (!is_open()) { + return ENOTCONN; + } + + int value; + int rc = ::ioctl(m_fd, FIONREAD, &value); + if (rc < 0) { + return Error::from_errno(errno); + } + + return static_cast(value); +} + +Result UDPSocket::connect(String const& host, u16 port) +{ + auto ip_address = TRY(resolve_host(host, SocketType::Datagram)); + auto maybe_socket = connect(SocketAddress { ip_address, port }); + if (maybe_socket.is_error()) { + return SocketError { maybe_socket.release_error() }; + } + return maybe_socket.release_value(); +} + +ErrorOr UDPSocket::connect(SocketAddress const& address) +{ + UDPSocket socket; + + auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Datagram)); + socket.m_helper.set_fd(fd); + + auto result = connect_inet(fd, address); + if (result.is_error()) { + ::close(fd); + return result.release_error(); + } + + socket.setup_notifier(); + return socket; +} + +ErrorOr LocalSocket::connect(String const& path) +{ + LocalSocket socket; + + auto fd = TRY(create_fd(SocketDomain::Local, SocketType::Stream)); + socket.m_helper.set_fd(fd); + + auto result = connect_local(fd, path); + if (result.is_error()) { + ::close(fd); + return result.release_error(); + } + + socket.setup_notifier(); + return socket; +} + +} diff --git a/Userland/Libraries/LibCore/Stream.h b/Userland/Libraries/LibCore/Stream.h new file mode 100644 index 0000000000..b024bebbb6 --- /dev/null +++ b/Userland/Libraries/LibCore/Stream.h @@ -0,0 +1,914 @@ +/* + * Copyright (c) 2021, sin-ack + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Core::Stream { + +/// The base, abstract class for stream operations. This class defines the +/// operations one can perform on every stream in LibCore. +class Stream { +public: + virtual bool is_readable() const { return false; } + /// Reads into a buffer, with the maximum size being the size of the buffer. + /// The amount of bytes read can be smaller than the size of the buffer. + /// Returns either the amount of bytes read, or an errno in the case of + /// failure. + virtual ErrorOr read(Bytes) = 0; + /// Tries to fill the entire buffer through reading. Returns whether the + /// buffer was filled without an error. + virtual bool read_or_error(Bytes); + + virtual bool is_writable() const { return false; } + /// Tries to write the entire contents of the buffer. It is possible for + /// less than the full buffer to be written. Returns either the amount of + /// bytes written into the stream, or an errno in the case of failure. + virtual ErrorOr write(ReadonlyBytes) = 0; + /// Same as write, but does not return until either the entire buffer + /// contents are written or an error occurs. Returns whether the entire + /// contents were written without an error. + virtual bool write_or_error(ReadonlyBytes); + + /// Returns whether the stream has reached the end of file. For sockets, + /// this most likely means that the protocol has disconnected (in the case + /// of TCP). For seekable streams, this means the end of the file. Note that + /// is_eof will only return true _after_ a read with 0 length, so this + /// method should be called after a read. + virtual bool is_eof() const = 0; + + virtual bool is_open() const = 0; + virtual void close() = 0; + + virtual ~Stream() + { + } +}; + +enum class SeekMode { + SetPosition, + FromCurrentPosition, + FromEndPosition, +}; + +/// Adds seekability to Core::Stream. Classes inheriting from SeekableStream +/// will be seekable to any point in the stream. +class SeekableStream : public Stream { +public: + /// Seeks to the given position in the given mode. Returns either the + /// current position of the file, or an errno in the case of an error. + virtual ErrorOr seek(i64 offset, SeekMode) = 0; + /// Returns the current position of the file, or an errno in the case of + /// an error. + virtual ErrorOr tell() const; + /// Returns the total size of the stream, or an errno in the case of an + /// error. May not preserve the original position on the stream on failure. + virtual ErrorOr size(); +}; + +enum class GetAddrInfoError { + NoAddressInFamily = EAI_ADDRFAMILY, + TemporaryFailure = EAI_AGAIN, + PermanentFailure = EAI_FAIL, + BadFlags = EAI_BADFLAGS, + UnsupportedFamily = EAI_FAMILY, + OutOfMemory = EAI_MEMORY, + NoNetworkAddressesForHost = EAI_NODATA, + UnknownService = EAI_NONAME, + ServiceNotAvailable = EAI_SERVICE, + UnsupportedSocketType = EAI_SOCKTYPE, + System = EAI_SYSTEM, +}; + +class SocketError { +public: + SocketError(GetAddrInfoError error) + : m_value(error) + { + } + SocketError(ErrorOr const& error) + : m_value(error) + { + } + + // TRY() compatibility + SocketError release_error() { return *this; } + void release_value() { } + + bool is_error() const + { + return m_value.has() || m_value.get>().is_error(); + } + bool is_success() const { return !is_error(); } + + bool is_kresult() { return m_value.has>(); } + bool is_getaddrinfo_error() { return m_value.has(); } + + ErrorOr as_kresult() { return m_value.get>(); } + GetAddrInfoError as_getaddrinfo_error() + { + return m_value.get(); + } + + StringView getaddrinfo_error_string() + { + VERIFY(is_getaddrinfo_error()); + return { gai_strerror(static_cast(as_getaddrinfo_error())) }; + } + +private: + Variant, GetAddrInfoError> m_value; +}; + +/// The Socket class is the base class for all concrete BSD-style socket +/// classes. Sockets are non-seekable streams which can be read byte-wise. +class Socket : public Stream { +public: + Socket(Socket&&) = default; + Socket& operator=(Socket&&) = default; + + /// Checks how many bytes of data are currently available to read on the + /// socket. For datagram-based socket, this is the size of the first + /// datagram that can be read. Returns either the amount of bytes, or an + /// errno in the case of failure. + virtual ErrorOr pending_bytes() const = 0; + /// Returns whether there's any data that can be immediately read, or an + /// errno on failure. + virtual ErrorOr can_read_without_blocking(int timeout = 0) const = 0; + // Sets the blocking mode of the socket. If blocking mode is disabled, reads + // will fail with EAGAIN when there's no data available to read, and writes + // will fail with EAGAIN when the data cannot be written without blocking + // (due to the send buffer being full, for example). + virtual ErrorOr set_blocking(bool enabled) = 0; + // Sets the close-on-exec mode of the socket. If close-on-exec mode is + // enabled, then the socket will be automatically closed by the kernel when + // an exec call happens. + virtual ErrorOr set_close_on_exec(bool enabled) = 0; + + Function on_ready_to_read; + +protected: + enum class SocketDomain { + Local, + Inet, + }; + + enum class SocketType { + Stream, + Datagram, + }; + + Socket() + { + } + + static ErrorOr create_fd(SocketDomain, SocketType); + // FIXME: This will need to be updated when IPv6 socket arrives. Perhaps a + // base class for all address types is appropriate. + static Result resolve_host(String const&, SocketType); + + static ErrorOr connect_local(int fd, String const& path); + static ErrorOr connect_inet(int fd, SocketAddress const&); +}; + +/// A reusable socket maintains state about being connected in addition to +/// normal Socket capabilities, and can be reconnected once disconnected. +class ReusableSocket : public Socket { +public: + /// Returns whether the socket is currently connected. + virtual bool is_connected() = 0; + /// Reconnects the socket to the given host and port. Returns EALREADY if + /// is_connected() is true. + virtual SocketError reconnect(String const& host, u16 port) = 0; + /// Connects the socket to the given socket address (IP address + port). + /// Returns EALREADY is_connected() is true. + virtual ErrorOr reconnect(SocketAddress const&) = 0; +}; + +// Concrete classes. + +enum class OpenMode : unsigned { + NotOpen = 0, + Read = 1, + Write = 2, + ReadWrite = 3, + Append = 4, + Truncate = 8, + MustBeNew = 16, + KeepOnExec = 32, + Nonblocking = 64, +}; + +AK_ENUM_BITWISE_OPERATORS(OpenMode) + +class File final : public SeekableStream { + AK_MAKE_NONCOPYABLE(File); + +public: + static ErrorOr open(StringView const& filename, OpenMode, mode_t = 0644); + static ErrorOr adopt_fd(int fd, OpenMode); + + File(File&& other) { operator=(move(other)); } + + File& operator=(File&& other) + { + if (&other == this) + return *this; + + m_mode = exchange(other.m_mode, OpenMode::NotOpen); + m_fd = exchange(other.m_fd, -1); + m_last_read_was_eof = exchange(other.m_last_read_was_eof, false); + return *this; + } + + virtual bool is_readable() const override; + virtual ErrorOr read(Bytes) override; + virtual bool is_writable() const override; + virtual ErrorOr write(ReadonlyBytes) override; + virtual bool is_eof() const override; + virtual bool is_open() const override; + virtual void close() override; + virtual ErrorOr seek(i64 offset, SeekMode) override; + + virtual ~File() override { close(); } + +private: + File(OpenMode mode) + : m_mode(mode) + { + } + + ErrorOr open_path(StringView const& filename, mode_t); + + OpenMode m_mode { OpenMode::NotOpen }; + int m_fd { -1 }; + bool m_last_read_was_eof { false }; +}; + +class PosixSocketHelper { + AK_MAKE_NONCOPYABLE(PosixSocketHelper); + +public: + template + PosixSocketHelper(Badge) requires(IsBaseOf) { } + + PosixSocketHelper(PosixSocketHelper&& other) + { + operator=(move(other)); + } + + PosixSocketHelper& operator=(PosixSocketHelper&& other) + { + m_fd = exchange(other.m_fd, -1); + m_last_read_was_eof = exchange(other.m_last_read_was_eof, false); + m_notifier = move(other.m_notifier); + return *this; + } + + int fd() const { return m_fd; } + void set_fd(int fd) { m_fd = fd; } + + ErrorOr read(Bytes); + ErrorOr write(ReadonlyBytes); + + bool is_eof() const { return !is_open() || m_last_read_was_eof; } + bool is_open() const { return m_fd != -1; } + void close(); + + ErrorOr pending_bytes() const; + ErrorOr can_read_without_blocking(int timeout) const; + + ErrorOr set_blocking(bool enabled); + ErrorOr set_close_on_exec(bool enabled); + + void setup_notifier(); + RefPtr notifier() { return m_notifier; } + +private: + int m_fd { -1 }; + bool m_last_read_was_eof { false }; + RefPtr m_notifier; +}; + +class TCPSocket final : public Socket { +public: + static Result connect(String const& host, u16 port); + static ErrorOr connect(SocketAddress const& address); + static ErrorOr adopt_fd(int fd); + + TCPSocket(TCPSocket&& other) + : Socket(static_cast(other)) + , m_helper(move(other.m_helper)) + { + if (is_open()) + setup_notifier(); + } + + TCPSocket& operator=(TCPSocket&& other) + { + Socket::operator=(static_cast(other)); + m_helper = move(other.m_helper); + if (is_open()) + setup_notifier(); + + return *this; + } + + virtual bool is_readable() const override { return is_open(); } + virtual bool is_writable() const override { return is_open(); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual bool is_eof() const override { return m_helper.is_eof(); } + virtual bool is_open() const override { return m_helper.is_open(); }; + virtual void close() override { m_helper.close(); }; + virtual ErrorOr pending_bytes() const override { return m_helper.pending_bytes(); } + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); } + ErrorOr set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); } + ErrorOr set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); } + + virtual ~TCPSocket() override { close(); } + +private: + TCPSocket() + { + } + + void setup_notifier() + { + VERIFY(is_open()); + + m_helper.setup_notifier(); + m_helper.notifier()->on_ready_to_read = [this] { + if (on_ready_to_read) + on_ready_to_read(); + }; + } + + PosixSocketHelper m_helper { Badge {} }; +}; + +class UDPSocket final : public Socket { +public: + static Result connect(String const& host, u16 port); + static ErrorOr connect(SocketAddress const& address); + + UDPSocket(UDPSocket&& other) + : Socket(static_cast(other)) + , m_helper(move(other.m_helper)) + { + if (is_open()) + setup_notifier(); + } + + UDPSocket& operator=(UDPSocket&& other) + { + Socket::operator=(static_cast(other)); + m_helper = move(other.m_helper); + if (is_open()) + setup_notifier(); + + return *this; + } + + virtual ErrorOr read(Bytes buffer) override + { + auto pending_bytes = TRY(this->pending_bytes()); + if (pending_bytes > buffer.size()) { + // With UDP datagrams, reading a datagram into a buffer that's + // smaller than the datagram's size will cause the rest of the + // datagram to be discarded. That's not very nice, so let's bail + // early, telling the caller that he should allocate a bigger + // buffer. + return EMSGSIZE; + } + + return m_helper.read(buffer); + } + + virtual bool is_readable() const override { return is_open(); } + virtual bool is_writable() const override { return is_open(); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual bool is_eof() const override { return m_helper.is_eof(); } + virtual bool is_open() const override { return m_helper.is_open(); } + virtual void close() override { m_helper.close(); } + virtual ErrorOr pending_bytes() const override { return m_helper.pending_bytes(); } + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); } + ErrorOr set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); } + ErrorOr set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); } + + virtual ~UDPSocket() override { close(); } + +private: + UDPSocket() { } + + void setup_notifier() + { + VERIFY(is_open()); + + m_helper.setup_notifier(); + m_helper.notifier()->on_ready_to_read = [this] { + if (on_ready_to_read) + on_ready_to_read(); + }; + } + + PosixSocketHelper m_helper { Badge {} }; +}; + +class LocalSocket final : public Socket { +public: + static ErrorOr connect(String const& path); + + LocalSocket(LocalSocket&& other) + : Socket(static_cast(other)) + , m_helper(move(other.m_helper)) + { + if (is_open()) + setup_notifier(); + } + + LocalSocket& operator=(LocalSocket&& other) + { + Socket::operator=(static_cast(other)); + m_helper = move(other.m_helper); + if (is_open()) + setup_notifier(); + + return *this; + } + + virtual bool is_readable() const override { return is_open(); } + virtual bool is_writable() const override { return is_open(); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(buffer); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.write(buffer); } + virtual bool is_eof() const override { return m_helper.is_eof(); } + virtual bool is_open() const override { return m_helper.is_open(); } + virtual void close() override { m_helper.close(); } + virtual ErrorOr pending_bytes() const override { return m_helper.pending_bytes(); } + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); } + virtual ErrorOr set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); } + virtual ErrorOr set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); } + + virtual ~LocalSocket() { close(); } + +private: + LocalSocket() { } + + void setup_notifier() + { + VERIFY(is_open()); + + m_helper.setup_notifier(); + m_helper.notifier()->on_ready_to_read = [this] { + if (on_ready_to_read) + on_ready_to_read(); + }; + } + + PosixSocketHelper m_helper { Badge {} }; +}; + +// Buffered stream wrappers + +template +concept StreamLike = IsBaseOf; +template +concept SeekableStreamLike = IsBaseOf; +template +concept SocketLike = IsBaseOf; + +template +class BufferedHelper { + AK_MAKE_NONCOPYABLE(BufferedHelper); + +public: + template + BufferedHelper(Badge, T stream, ByteBuffer buffer) + : m_stream(move(stream)) + , m_buffer(move(buffer)) + { + } + + BufferedHelper(BufferedHelper&& other) + : m_stream(move(other.m_stream)) + , m_buffer(move(other.m_buffer)) + , m_buffered_size(exchange(other.m_buffered_size, 0)) + { + } + + BufferedHelper& operator=(BufferedHelper&& other) + { + m_stream = move(other.m_stream); + m_buffer = move(other.m_buffer); + m_buffered_size = exchange(other.m_buffered_size, 0); + return *this; + } + + template typename BufferedType> + static ErrorOr> create_buffered(T&& stream, size_t buffer_size) + { + if (!buffer_size) + return EINVAL; + if (!stream.is_open()) + return ENOTCONN; + + auto maybe_buffer = ByteBuffer::create_uninitialized(buffer_size); + if (!maybe_buffer.has_value()) + return ENOMEM; + + return BufferedType { move(stream), maybe_buffer.release_value() }; + } + + T& stream() { return m_stream; } + T const& stream() const { return m_stream; } + + ErrorOr read(Bytes buffer) + { + if (!stream().is_open()) + return ENOTCONN; + if (!buffer.size()) + return ENOBUFS; + + // Let's try to take all we can from the buffer first. + size_t buffer_nread = 0; + if (m_buffered_size > 0) { + // FIXME: Use a circular buffer to avoid shifting the buffer + // contents. + size_t amount_to_take = min(buffer.size(), m_buffered_size); + auto slice_to_take = m_buffer.span().slice(0, amount_to_take); + auto slice_to_shift = m_buffer.span().slice(amount_to_take); + + slice_to_take.copy_to(buffer); + buffer_nread += amount_to_take; + + if (amount_to_take < m_buffered_size) { + m_buffer.overwrite(0, slice_to_shift.data(), m_buffered_size - amount_to_take); + } + m_buffered_size -= amount_to_take; + } + + // If the buffer satisfied the request, then we need not continue. + if (buffer_nread == buffer.size()) { + return buffer_nread; + } + + // Otherwise, let's try an extra read just in case there's something + // in our receive buffer. + auto stream_nread = TRY(m_stream.read(buffer.slice(buffer_nread))); + return buffer_nread + stream_nread; + } + + // Reads into the buffer until \n is encountered. + // The size of the Bytes object is the maximum amount of bytes that will be + // read. Returns the amount of bytes read. + ErrorOr read_line(Bytes buffer) + { + return read_until(buffer, "\n"sv); + } + + ErrorOr read_until(Bytes buffer, StringView const& candidate) + { + return read_until_any_of(buffer, Array { candidate }); + } + + template + ErrorOr read_until_any_of(Bytes buffer, Array candidates) + { + if (!stream().is_open()) + return ENOTCONN; + + if (buffer.is_empty()) + return ENOBUFS; + + // We fill the buffer through can_read_line. + if (!TRY(can_read_line())) + return 0; + + if (stream().is_eof()) { + if (buffer.size() < m_buffered_size) { + // Normally, reading from an EOFed stream and receiving bytes + // would mean that the stream is no longer EOF. However, it's + // possible with a buffered stream that the user is able to read + // the buffer contents even when the underlying stream is EOF. + // We already violate this invariant once by giving the user the + // chance to read the remaining buffer contents, but if the user + // doesn't give us a big enough buffer, then we would be + // violating the invariant twice the next time the user attempts + // to read, which is No Good. So let's give a descriptive error + // to the caller about why it can't read. + return EMSGSIZE; + } + + m_buffer.span().slice(0, m_buffered_size).copy_to(buffer); + return exchange(m_buffered_size, 0); + } + + size_t longest_match = 0; + size_t maximum_offset = min(m_buffered_size, buffer.size()); + for (size_t offset = 0; offset < maximum_offset; offset++) { + // The intention here is to try to match all of the possible + // delimiter candidates and try to find the longest one we can + // remove from the buffer after copying up to the delimiter to the + // user buffer. + StringView remaining_buffer { m_buffer.span().offset(offset), maximum_offset - offset }; + for (auto candidate : candidates) { + if (candidate.length() > offset) + continue; + if (remaining_buffer.starts_with(candidate)) + longest_match = max(longest_match, candidate.length()); + } + + if (longest_match > 0) { + auto buffer_to_take = m_buffer.span().slice(0, offset); + auto buffer_to_shift = m_buffer.span().slice(offset + longest_match); + + buffer_to_take.copy_to(buffer); + m_buffer.overwrite(0, buffer_to_shift.data(), buffer_to_shift.size()); + + return offset; + } + } + + // If we still haven't found anything, then it's most likely the case + // that the delimiter ends beyond the length of the caller-passed + // buffer. Let's just fill the caller's buffer up. + auto readable_size = min(m_buffered_size, buffer.size()); + auto buffer_to_take = m_buffer.span().slice(0, readable_size); + auto buffer_to_shift = m_buffer.span().slice(readable_size); + + buffer_to_take.copy_to(buffer); + m_buffer.overwrite(0, buffer_to_shift.data(), buffer_to_shift.size()); + + return readable_size; + } + + // Returns whether a line can be read, populating the buffer in the process. + ErrorOr can_read_line() + { + if (stream().is_eof() && m_buffered_size > 0) + return true; + + if (m_buffer.span().slice(0, m_buffered_size).contains_slow('\n')) + return true; + + if (!stream().is_readable()) + return false; + + while (m_buffered_size < m_buffer.size()) { + auto populated_slice = TRY(populate_read_buffer()); + + if (stream().is_eof()) { + // We give the user one last hurrah to read the remaining + // contents as a "line". + return m_buffered_size > 0; + } + + if (populated_slice.contains_slow('\n')) + return true; + } + + return false; + } + + bool is_eof() const + { + if (m_buffered_size > 0) { + return false; + } + + return stream().is_eof(); + } + + size_t buffer_size() const + { + return m_buffer.size(); + } + + void clear_buffer() + { + m_buffered_size = 0; + } + +private: + ErrorOr populate_read_buffer() + { + if (m_buffered_size == m_buffer.size()) + return ReadonlyBytes {}; + + auto fillable_slice = m_buffer.span().slice(m_buffered_size); + auto nread = TRY(m_stream.read(fillable_slice)); + m_buffered_size += nread; + return fillable_slice.slice(0, nread); + } + + T m_stream; + // FIXME: Replacing this with a circular buffer would be really nice and + // would avoid excessive copies; however, right now + // AK::CircularDuplexBuffer inlines its entire contents, and that + // would make for a very large object on the stack. + // + // The proper fix is to make a CircularQueue which uses a buffer on + // the heap. + ByteBuffer m_buffer; + size_t m_buffered_size { 0 }; +}; + +// NOTE: A Buffered which accepts any Stream could be added here, but it is not +// needed at the moment. + +template +class BufferedSeekable final : public SeekableStream { + friend BufferedHelper; + +public: + static ErrorOr> create(T&& stream, size_t buffer_size = 16384) + { + return BufferedHelper::template create_buffered(move(stream), buffer_size); + } + + BufferedSeekable(BufferedSeekable&& other) = default; + BufferedSeekable& operator=(BufferedSeekable&& other) = default; + + virtual bool is_readable() const override { return m_helper.stream().is_readable(); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(move(buffer)); } + virtual bool is_writable() const override { return m_helper.stream().is_writable(); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.stream().write(buffer); } + virtual bool is_eof() const override { return m_helper.is_eof(); } + virtual bool is_open() const override { return m_helper.stream().is_open(); } + virtual void close() override { m_helper.stream().close(); } + virtual ErrorOr seek(i64 offset, SeekMode mode) override + { + auto result = TRY(m_helper.stream().seek(offset, mode)); + m_helper.clear_buffer(); + return result; + } + + ErrorOr read_line(Bytes buffer) { return m_helper.read_line(move(buffer)); } + ErrorOr read_until(Bytes buffer, StringView const& candidate) { return m_helper.read_until(move(buffer), move(candidate)); } + template + ErrorOr read_until_any_of(Bytes buffer, Array candidates) { return m_helper.read_until_any_of(move(buffer), move(candidates)); } + ErrorOr can_read_line() { return m_helper.can_read_line(); } + + size_t buffer_size() const { return m_helper.buffer_size(); } + + virtual ~BufferedSeekable() override { } + +private: + BufferedSeekable(T stream, ByteBuffer buffer) + : m_helper(Badge> {}, move(stream), buffer) + { + } + + BufferedHelper m_helper; +}; + +template +class BufferedSocket final : public Socket { + friend BufferedHelper; + +public: + static ErrorOr> create(T&& stream, size_t buffer_size = 16384) + { + return BufferedHelper::template create_buffered(move(stream), buffer_size); + } + + BufferedSocket(BufferedSocket&& other) + : Socket(static_cast(other)) + , m_helper(move(other.m_helper)) + { + setup_notifier(); + } + + BufferedSocket& operator=(BufferedSocket&& other) + { + Socket::operator=(static_cast(other)); + m_helper = move(other.m_helper); + + setup_notifier(); + return *this; + } + + virtual bool is_readable() const override { return m_helper.stream().is_readable(); } + virtual ErrorOr read(Bytes buffer) override { return m_helper.read(move(buffer)); } + virtual bool is_writable() const override { return m_helper.stream().is_writable(); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_helper.stream().write(buffer); } + virtual bool is_eof() const override { return m_helper.is_eof(); } + virtual bool is_open() const override { return m_helper.stream().is_open(); } + virtual void close() override { m_helper.stream().close(); } + virtual ErrorOr pending_bytes() const override { return m_helper.stream().pending_bytes(); } + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override { return m_helper.stream().can_read_without_blocking(timeout); } + virtual ErrorOr set_blocking(bool enabled) override { return m_helper.stream().set_blocking(enabled); } + virtual ErrorOr set_close_on_exec(bool enabled) override { return m_helper.stream().set_close_on_exec(enabled); } + + ErrorOr read_line(Bytes buffer) { return m_helper.read_line(move(buffer)); } + ErrorOr read_until(Bytes buffer, StringView const& candidate) { return m_helper.read_until(move(buffer), move(candidate)); } + template + ErrorOr read_until_any_of(Bytes buffer, Array candidates) { return m_helper.read_until_any_of(move(buffer), move(candidates)); } + ErrorOr can_read_line() { return m_helper.can_read_line(); } + + size_t buffer_size() const { return m_helper.buffer_size(); } + + virtual ~BufferedSocket() override { } + +private: + BufferedSocket(T stream, ByteBuffer buffer) + : m_helper(Badge> {}, move(stream), buffer) + { + setup_notifier(); + } + + void setup_notifier() + { + m_helper.stream().on_ready_to_read = [this] { + if (on_ready_to_read) + on_ready_to_read(); + }; + } + + BufferedHelper m_helper; +}; + +using BufferedFile = BufferedSeekable; +using BufferedTCPSocket = BufferedSocket; +using BufferedUDPSocket = BufferedSocket; +using BufferedLocalSocket = BufferedSocket; + +/// A BasicReusableSocket allows one to use one of the base Core::Stream classes +/// as a ReusableSocket. It does not preserve any connection state or options, +/// and instead just recreates the stream when reconnecting. +template +class BasicReusableSocket final : public ReusableSocket { +public: + static Result, SocketError> connect(String const& host, u16 port) + { + return BasicReusableSocket { TRY(T::connect(host, port)) }; + } + + static ErrorOr> connect(SocketAddress const& address) + { + return BasicReusableSocket { TRY(T::connect(address)) }; + } + + virtual bool is_connected() override + { + return m_socket.is_open(); + } + + virtual SocketError reconnect(String const& host, u16 port) override + { + if (is_connected()) + return SocketError { Error::from_errno(EALREADY) }; + + m_socket = TRY(T::connect(host, port)); + return SocketError { {} }; + } + + virtual ErrorOr reconnect(SocketAddress const& address) override + { + if (is_connected()) + return Error::from_errno(EALREADY); + + m_socket = TRY(T::connect(address)); + return {}; + } + + virtual bool is_readable() const override { return m_socket.is_readable(); } + virtual ErrorOr read(Bytes buffer) override { return m_socket.read(move(buffer)); } + virtual bool is_writable() const override { return m_socket.is_writable(); } + virtual ErrorOr write(ReadonlyBytes buffer) override { return m_socket.write(buffer); } + virtual bool is_eof() const override { return m_socket.is_eof(); } + virtual bool is_open() const override { return m_socket.is_open(); } + virtual void close() override { m_socket.close(); } + virtual ErrorOr pending_bytes() const override { return m_socket.pending_bytes(); } + virtual ErrorOr can_read_without_blocking(int timeout = 0) const override { return m_socket.can_read_without_blocking(timeout); } + virtual ErrorOr set_blocking(bool enabled) override { return m_socket.set_blocking(enabled); } + virtual ErrorOr set_close_on_exec(bool enabled) override { return m_socket.set_close_on_exec(enabled); } + +private: + BasicReusableSocket(T&& socket) + : m_socket(move(socket)) + { + } + + T m_socket; +}; + +using ReusableTCPSocket = BasicReusableSocket; +using ReusableUDPSocket = BasicReusableSocket; + +}