diff --git a/Userland/Libraries/LibCrypto/Checksum/ChecksummingStream.h b/Userland/Libraries/LibCrypto/Checksum/ChecksummingStream.h new file mode 100644 index 0000000000..5c71d00d29 --- /dev/null +++ b/Userland/Libraries/LibCrypto/Checksum/ChecksummingStream.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023, kleines Filmröllchen + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace Crypto::Checksum { + +// A stream wrapper type which passes all read and written data through a checksum function. +template +requires( + IsBaseOf, ChecksumFunctionType>, + // Require checksum function to be constructible without arguments, since we have no initial data. + requires() { + ChecksumFunctionType {}; + }) +class ChecksummingStream : public Stream { +public: + virtual ~ChecksummingStream() = default; + + ChecksummingStream(MaybeOwned stream) + : m_stream(move(stream)) + { + } + + virtual ErrorOr read_some(Bytes bytes) override + { + auto const written_bytes = TRY(m_stream->read_some(bytes)); + update(written_bytes); + return written_bytes; + } + + virtual ErrorOr read_until_filled(Bytes bytes) override + { + TRY(m_stream->read_until_filled(bytes)); + update(bytes); + return {}; + } + + virtual ErrorOr write_some(ReadonlyBytes bytes) override + { + auto bytes_written = TRY(m_stream->write_some(bytes)); + // Only update with the bytes that were actually written + update(bytes.trim(bytes_written)); + return bytes_written; + } + + virtual ErrorOr write_until_depleted(ReadonlyBytes bytes) override + { + update(bytes); + return m_stream->write_until_depleted(bytes); + } + + virtual bool is_eof() const override { return m_stream->is_eof(); } + virtual bool is_open() const override { return m_stream->is_open(); } + virtual void close() override { m_stream->close(); } + + ChecksumType digest() + { + return m_checksum.digest(); + } + +private: + ALWAYS_INLINE void update(ReadonlyBytes bytes) + { + m_checksum.update(bytes); + } + + MaybeOwned m_stream; + ChecksumFunctionType m_checksum {}; +}; + +}