diff --git a/AK/BitStream.h b/AK/BitStream.h index 492ba2f180..3c7ef8ba0b 100644 --- a/AK/BitStream.h +++ b/AK/BitStream.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -46,6 +47,7 @@ public: { return read_bits(1); } + /// Depending on the number of bits to read, the return type can be chosen appropriately. /// This avoids a bunch of static_cast<>'s for the user. // TODO: Support u128, u256 etc. as well: The concepts would be quite complex. @@ -127,16 +129,29 @@ public: // ^Stream virtual ErrorOr read_some(Bytes bytes) override { - if (m_current_byte.has_value() && is_aligned_to_byte_boundary()) { - bytes[0] = m_current_byte.release_value(); - // FIXME: This accidentally slices off the first byte of the returned span. - return m_stream->read_some(bytes.slice(1)); - } align_to_byte_boundary(); - return m_stream->read_some(bytes); + + size_t bytes_read = 0; + auto buffer = bytes; + + if (m_bit_count > 0) { + auto bits_to_read = min(buffer.size() * bits_per_byte, m_bit_count); + auto result = TRY(read_bits(bits_to_read)); + + bytes_read = bits_to_read / bits_per_byte; + buffer.overwrite(0, &result, bytes_read); + + buffer = buffer.slice(bytes_read); + } + + buffer = TRY(m_stream->read_some(buffer)); + bytes_read += buffer.size(); + + return bytes.trim(bytes_read); } + virtual ErrorOr write_some(ReadonlyBytes bytes) override { return m_stream->write_some(bytes); } - virtual bool is_eof() const override { return m_stream->is_eof() && !m_current_byte.has_value(); } + virtual bool is_eof() const override { return m_stream->is_eof() && m_bit_count == 0; } virtual bool is_open() const override { return m_stream->is_open(); } virtual void close() override { @@ -148,71 +163,103 @@ public: { return read_bits(1); } + /// Depending on the number of bits to read, the return type can be chosen appropriately. /// This avoids a bunch of static_cast<>'s for the user. // TODO: Support u128, u256 etc. as well: The concepts would be quite complex. template ErrorOr read_bits(size_t count) { - if constexpr (IsSame) { - VERIFY(count == 1); - } - T result = 0; - - size_t nread = 0; - while (nread < count) { - if (m_current_byte.has_value()) { - if constexpr (!IsSame && !IsSame) { - // read as many bytes as possible directly - if (((count - nread) >= 8) && is_aligned_to_byte_boundary()) { - // shift existing data over - result |= (m_current_byte.value() << nread); - nread += 8; - m_current_byte.clear(); - } else { - auto const bit = (m_current_byte.value() >> m_bit_offset) & 1; - result |= (bit << nread); - ++nread; - if (m_bit_offset++ == 7) - m_current_byte.clear(); - } - } else { - // Always take this branch for booleans or u8: there's no purpose in reading more than a single bit - auto const bit = (m_current_byte.value() >> m_bit_offset) & 1; - if constexpr (IsSame) - result = bit; - else - result |= (bit << nread); - ++nread; - if (m_bit_offset++ == 7) - m_current_byte.clear(); - } - } else { - m_current_byte = TRY(m_stream->read_value()); - m_bit_offset = 0; - } - } + auto result = TRY(peek_bits(count)); + discard_previously_peeked_bits(count); return result; } + template + ErrorOr peek_bits(size_t count) + { + if (count > m_bit_count) + TRY(refill_buffer_from_stream()); + + return lsb_aligned_buffer() & lsb_mask(min(count, m_bit_count)); + } + + ALWAYS_INLINE void discard_previously_peeked_bits(u8 count) + { + m_bit_offset += count; + m_bit_count -= count; + } + /// Discards any sub-byte stream positioning the input stream may be keeping track of. /// Non-bitwise reads will implicitly call this. u8 align_to_byte_boundary() { - u8 remaining_bits = m_current_byte.value_or(0) >> m_bit_offset; - m_current_byte.clear(); + u8 remaining_bits = 0; + + m_bit_buffer = lsb_aligned_buffer(); m_bit_offset = 0; + + if (auto offset = m_bit_count % bits_per_byte; offset != 0) { + remaining_bits = m_bit_buffer & lsb_mask(offset); + discard_previously_peeked_bits(offset); + } + return remaining_bits; } /// Whether we are (accidentally or intentionally) at a byte boundary right now. - ALWAYS_INLINE bool is_aligned_to_byte_boundary() const { return m_bit_offset == 0; } + ALWAYS_INLINE bool is_aligned_to_byte_boundary() const { return m_bit_count % bits_per_byte == 0; } private: - Optional m_current_byte; - size_t m_bit_offset { 0 }; + using BufferType = u64; + + template + static constexpr T lsb_mask(T bits) + { + constexpr auto max = NumericLimits::max(); + constexpr auto digits = NumericLimits::digits(); + + return bits == 0 ? 0 : max >> (digits - bits); + } + + ALWAYS_INLINE BufferType lsb_aligned_buffer() const + { + return m_bit_offset == bit_buffer_size ? 0 : m_bit_buffer >> m_bit_offset; + } + + ErrorOr refill_buffer_from_stream() + { + size_t bits_to_read = bit_buffer_size - m_bit_count; + size_t bytes_to_read = bits_to_read / bits_per_byte; + + BufferType buffer = 0; + + Bytes bytes { &buffer, bytes_to_read }; + size_t bytes_read = 0; + + // FIXME: When the underlying stream is buffered, `read_some` seems to stop before EOF. + do { + auto result = TRY(m_stream->read_some(bytes)); + bytes = bytes.slice(result.size()); + bytes_read += result.size(); + } while (!bytes.is_empty() && !m_stream->is_eof()); + + m_bit_buffer = (buffer << m_bit_count) | lsb_aligned_buffer(); + m_bit_count += bytes_read * bits_per_byte; + m_bit_offset = 0; + + return {}; + } + + static constexpr size_t bits_per_byte = 8u; + static constexpr size_t bit_buffer_size = sizeof(BufferType) * bits_per_byte; + MaybeOwned m_stream; + + BufferType m_bit_buffer { 0 }; + u8 m_bit_offset { 0 }; + u8 m_bit_count { 0 }; }; /// A stream wrapper class that allows you to write arbitrary amounts of bits