diff --git a/AK/BitStream.h b/AK/BitStream.h index 373e6a2447..3ddca57267 100644 --- a/AK/BitStream.h +++ b/AK/BitStream.h @@ -117,9 +117,18 @@ private: MaybeOwned m_stream; }; -/// A stream wrapper class that allows you to read arbitrary amounts of bits -/// in little-endian order from another stream. -class LittleEndianInputBitStream : public Stream { +class LittleEndianBitStream : public Stream { +protected: + using BufferType = u64; + + static constexpr size_t bits_per_byte = 8u; + static constexpr size_t bit_buffer_size = sizeof(BufferType) * bits_per_byte; + + explicit LittleEndianBitStream(MaybeOwned stream) + : m_stream(move(stream)) + { + } + template static constexpr T lsb_mask(T bits) { @@ -129,9 +138,26 @@ class LittleEndianInputBitStream : public Stream { return bits == 0 ? 0 : max >> (digits - bits); } + ALWAYS_INLINE BufferType lsb_aligned_buffer() const + { + return m_bit_offset == bit_buffer_size ? 0 : m_bit_buffer >> m_bit_offset; + } + + ALWAYS_INLINE bool is_aligned_to_byte_boundary() const { return m_bit_count % bits_per_byte == 0; } + + MaybeOwned m_stream; + + BufferType m_bit_buffer { 0 }; + u8 m_bit_offset { 0 }; + u8 m_bit_count { 0 }; +}; + +/// A stream wrapper class that allows you to read arbitrary amounts of bits +/// in little-endian order from another stream. +class LittleEndianInputBitStream : public LittleEndianBitStream { public: explicit LittleEndianInputBitStream(MaybeOwned stream) - : m_stream(move(stream)) + : LittleEndianBitStream(move(stream)) { } @@ -217,17 +243,7 @@ public: return remaining_bits; } - /// Whether we are (accidentally or intentionally) at a byte boundary right now. - ALWAYS_INLINE bool is_aligned_to_byte_boundary() const { return m_bit_count % bits_per_byte == 0; } - private: - using BufferType = u64; - - ALWAYS_INLINE BufferType lsb_aligned_buffer() const - { - return m_bit_offset == bit_buffer_size ? 0 : m_bit_buffer >> m_bit_offset; - } - ErrorOr refill_buffer_from_stream() { size_t bits_to_read = bit_buffer_size - m_bit_count; @@ -242,15 +258,6 @@ private: return {}; } - - static constexpr size_t bits_per_byte = 8u; - static constexpr size_t bit_buffer_size = sizeof(BufferType) * bits_per_byte; - - MaybeOwned m_stream; - - BufferType m_bit_buffer { 0 }; - u8 m_bit_offset { 0 }; - u8 m_bit_count { 0 }; }; /// A stream wrapper class that allows you to write arbitrary amounts of bits @@ -333,10 +340,10 @@ private: /// A stream wrapper class that allows you to write arbitrary amounts of bits /// in little-endian order to another stream. -class LittleEndianOutputBitStream : public Stream { +class LittleEndianOutputBitStream : public LittleEndianBitStream { public: explicit LittleEndianOutputBitStream(MaybeOwned stream) - : m_stream(move(stream)) + : LittleEndianBitStream(move(stream)) { } @@ -347,28 +354,52 @@ public: virtual ErrorOr write_some(ReadonlyBytes bytes) override { - VERIFY(m_bit_offset == 0); + VERIFY(is_aligned_to_byte_boundary()); + + if (m_bit_count > 0) + TRY(flush_buffer_to_stream()); + return m_stream->write_some(bytes); } template - ErrorOr write_bits(T value, size_t bit_count) + ErrorOr write_bits(T value, size_t count) { - VERIFY(m_bit_offset <= 7); + if (m_bit_count == bit_buffer_size) { + TRY(flush_buffer_to_stream()); + } else if (auto remaining = bit_buffer_size - m_bit_count; count >= remaining) { + m_bit_buffer |= (static_cast(value) & lsb_mask(remaining)) << m_bit_count; + m_bit_count = bit_buffer_size; - size_t input_offset = 0; - while (input_offset < bit_count) { - u8 next_bit = (value >> input_offset) & 1; - input_offset++; + if (remaining != sizeof(value) * bits_per_byte) + value >>= remaining; + count -= remaining; - m_current_byte |= next_bit << m_bit_offset; - m_bit_offset++; + TRY(flush_buffer_to_stream()); + } - if (m_bit_offset > 7) { - TRY(m_stream->write_value(m_current_byte)); - m_bit_offset = 0; - m_current_byte = 0; - } + if (count == 0) + return {}; + + m_bit_buffer |= static_cast(value) << m_bit_count; + m_bit_count += count; + + return {}; + } + + ALWAYS_INLINE ErrorOr flush_buffer_to_stream() + { + auto bytes_to_write = m_bit_count / bits_per_byte; + TRY(m_stream->write_until_depleted({ &m_bit_buffer, bytes_to_write })); + + if (m_bit_count == bit_buffer_size) { + m_bit_buffer = 0; + m_bit_count = 0; + } else { + auto bits_written = bytes_to_write * bits_per_byte; + + m_bit_buffer >>= bits_written; + m_bit_count -= bits_written; } return {}; @@ -390,23 +421,16 @@ public: size_t bit_offset() const { - return m_bit_offset; + return m_bit_count; } ErrorOr align_to_byte_boundary() { - if (m_bit_offset == 0) - return {}; + if (auto offset = m_bit_count % bits_per_byte; offset != 0) + TRY(write_bits(0u, bits_per_byte - offset)); - TRY(write_bits(0u, 8 - m_bit_offset)); - VERIFY(m_bit_offset == 0); return {}; } - -private: - MaybeOwned m_stream; - u8 m_current_byte { 0 }; - size_t m_bit_offset { 0 }; }; } diff --git a/Tests/AK/TestBitStream.cpp b/Tests/AK/TestBitStream.cpp index a76b4a778a..e67a00659d 100644 --- a/Tests/AK/TestBitStream.cpp +++ b/Tests/AK/TestBitStream.cpp @@ -22,6 +22,8 @@ TEST_CASE(little_endian_bit_stream_input_output_match) { MUST(bit_write_stream.write_bits(0b1111u, 4)); MUST(bit_write_stream.write_bits(0b1111u, 4)); + MUST(bit_write_stream.flush_buffer_to_stream()); + auto result = MUST(bit_read_stream.read_bits(4)); EXPECT_EQ(0b1111u, result); result = MUST(bit_read_stream.read_bits(4)); @@ -30,6 +32,8 @@ TEST_CASE(little_endian_bit_stream_input_output_match) { MUST(bit_write_stream.write_bits(0b0000u, 4)); MUST(bit_write_stream.write_bits(0b0000u, 4)); + MUST(bit_write_stream.flush_buffer_to_stream()); + auto result = MUST(bit_read_stream.read_bits(4)); EXPECT_EQ(0b0000u, result); result = MUST(bit_read_stream.read_bits(4)); @@ -40,6 +44,8 @@ TEST_CASE(little_endian_bit_stream_input_output_match) { MUST(bit_write_stream.write_bits(0b1000u, 4)); MUST(bit_write_stream.write_bits(0b1000u, 4)); + MUST(bit_write_stream.flush_buffer_to_stream()); + auto result = MUST(bit_read_stream.read_bits(4)); EXPECT_EQ(0b1000u, result); result = MUST(bit_read_stream.read_bits(4)); @@ -50,6 +56,8 @@ TEST_CASE(little_endian_bit_stream_input_output_match) { MUST(bit_write_stream.write_bits(0b1000u, 4)); MUST(bit_write_stream.write_bits(0b0100u, 4)); + MUST(bit_write_stream.flush_buffer_to_stream()); + auto result = MUST(bit_read_stream.read_bits(4)); EXPECT_EQ(0b1000u, result); result = MUST(bit_read_stream.read_bits(4)); @@ -59,6 +67,8 @@ TEST_CASE(little_endian_bit_stream_input_output_match) // Test a pattern that spans multiple bytes. { MUST(bit_write_stream.write_bits(0b1101001000100001u, 16)); + MUST(bit_write_stream.flush_buffer_to_stream()); + auto result = MUST(bit_read_stream.read_bits(16)); EXPECT_EQ(0b1101001000100001u, result); } diff --git a/Userland/Libraries/LibCompress/Deflate.cpp b/Userland/Libraries/LibCompress/Deflate.cpp index e48ce01f8c..03f6e5f7d0 100644 --- a/Userland/Libraries/LibCompress/Deflate.cpp +++ b/Userland/Libraries/LibCompress/Deflate.cpp @@ -1050,6 +1050,7 @@ ErrorOr DeflateCompressor::final_flush() VERIFY(!m_finished); m_finished = true; TRY(flush()); + TRY(m_output_stream->flush_buffer_to_stream()); return {}; }