1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-31 07:48:11 +00:00

LibCompress: Port DeflateDecompressor to Core::Stream

This commit is contained in:
Tim Schumacher 2022-12-02 22:01:44 +01:00 committed by Linus Groh
parent f909cfbe75
commit 30abd47099
8 changed files with 129 additions and 196 deletions

View file

@ -10,6 +10,8 @@
#include <AK/MemoryStream.h> #include <AK/MemoryStream.h>
#include <AK/Random.h> #include <AK/Random.h>
#include <LibCompress/Deflate.h> #include <LibCompress/Deflate.h>
#include <LibCore/InputBitStream.h>
#include <LibCore/MemoryStream.h>
#include <cstring> #include <cstring>
TEST_CASE(canonical_code_simple) TEST_CASE(canonical_code_simple)
@ -27,11 +29,11 @@ TEST_CASE(canonical_code_simple)
}; };
auto const huffman = Compress::CanonicalCode::from_bytes(code).value(); auto const huffman = Compress::CanonicalCode::from_bytes(code).value();
auto memory_stream = InputMemoryStream { input }; auto memory_stream = MUST(Core::Stream::MemoryStream::construct(input));
auto bit_stream = InputBitStream { memory_stream }; auto bit_stream = MUST(Core::Stream::LittleEndianInputBitStream::construct(move(memory_stream)));
for (size_t idx = 0; idx < 9; ++idx) for (size_t idx = 0; idx < 9; ++idx)
EXPECT_EQ(huffman.read_symbol(bit_stream), output[idx]); EXPECT_EQ(MUST(huffman.read_symbol(*bit_stream)), output[idx]);
} }
TEST_CASE(canonical_code_complex) TEST_CASE(canonical_code_complex)
@ -47,11 +49,11 @@ TEST_CASE(canonical_code_complex)
}; };
auto const huffman = Compress::CanonicalCode::from_bytes(code).value(); auto const huffman = Compress::CanonicalCode::from_bytes(code).value();
auto memory_stream = InputMemoryStream { input }; auto memory_stream = MUST(Core::Stream::MemoryStream::construct(input));
auto bit_stream = InputBitStream { memory_stream }; auto bit_stream = MUST(Core::Stream::LittleEndianInputBitStream::construct(move(memory_stream)));
for (size_t idx = 0; idx < 12; ++idx) for (size_t idx = 0; idx < 12; ++idx)
EXPECT_EQ(huffman.read_symbol(bit_stream), output[idx]); EXPECT_EQ(MUST(huffman.read_symbol(*bit_stream)), output[idx]);
} }
TEST_CASE(deflate_decompress_compressed_block) TEST_CASE(deflate_decompress_compressed_block)
@ -118,7 +120,7 @@ TEST_CASE(deflate_round_trip_store)
auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::STORE); auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::STORE);
EXPECT(compressed.has_value()); EXPECT(compressed.has_value());
auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value()); auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value());
EXPECT(uncompressed.has_value()); EXPECT(!uncompressed.is_error());
EXPECT(uncompressed.value() == original); EXPECT(uncompressed.value() == original);
} }
@ -130,7 +132,7 @@ TEST_CASE(deflate_round_trip_compress)
auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::FAST); auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::FAST);
EXPECT(compressed.has_value()); EXPECT(compressed.has_value());
auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value()); auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value());
EXPECT(uncompressed.has_value()); EXPECT(!uncompressed.is_error());
EXPECT(uncompressed.value() == original); EXPECT(uncompressed.value() == original);
} }
@ -143,7 +145,7 @@ TEST_CASE(deflate_round_trip_compress_large)
auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::FAST); auto compressed = Compress::DeflateCompressor::compress_all(original, Compress::DeflateCompressor::CompressionLevel::FAST);
EXPECT(compressed.has_value()); EXPECT(compressed.has_value());
auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value()); auto uncompressed = Compress::DeflateDecompressor::decompress_all(compressed.value());
EXPECT(uncompressed.has_value()); EXPECT(!uncompressed.is_error());
EXPECT(uncompressed.value() == original); EXPECT(uncompressed.value() == original);
} }

View file

@ -10,6 +10,7 @@
#include <AK/BinaryHeap.h> #include <AK/BinaryHeap.h>
#include <AK/BinarySearch.h> #include <AK/BinarySearch.h>
#include <AK/MemoryStream.h> #include <AK/MemoryStream.h>
#include <LibCore/MemoryStream.h>
#include <string.h> #include <string.h>
#include <LibCompress/Deflate.h> #include <LibCompress/Deflate.h>
@ -98,14 +99,14 @@ Optional<CanonicalCode> CanonicalCode::from_bytes(ReadonlyBytes bytes)
return code; return code;
} }
u32 CanonicalCode::read_symbol(InputBitStream& stream) const ErrorOr<u32> CanonicalCode::read_symbol(Core::Stream::LittleEndianInputBitStream& stream) const
{ {
u32 code_bits = 1; u32 code_bits = 1;
for (;;) { for (;;) {
code_bits = code_bits << 1 | stream.read_bits(1); code_bits = code_bits << 1 | TRY(stream.read_bits(1));
if (code_bits >= (1 << 16)) if (code_bits >= (1 << 16))
return UINT32_MAX; // the maximum symbol in deflate is 288, so we use UINT32_MAX (an impossible value) to indicate an error return Error::from_string_literal("Symbol exceeds maximum symbol number");
// FIXME: This is very inefficient and could greatly be improved by implementing this // FIXME: This is very inefficient and could greatly be improved by implementing this
// algorithm: https://www.hanshq.net/zip.html#huffdec // algorithm: https://www.hanshq.net/zip.html#huffdec
@ -127,17 +128,15 @@ DeflateDecompressor::CompressedBlock::CompressedBlock(DeflateDecompressor& decom
{ {
} }
bool DeflateDecompressor::CompressedBlock::try_read_more() ErrorOr<bool> DeflateDecompressor::CompressedBlock::try_read_more()
{ {
if (m_eof == true) if (m_eof == true)
return false; return false;
auto const symbol = m_literal_codes.read_symbol(m_decompressor.m_input_stream); auto const symbol = TRY(m_literal_codes.read_symbol(*m_decompressor.m_input_stream));
if (symbol >= 286) { // invalid deflate literal/length symbol if (symbol >= 286)
m_decompressor.set_fatal_error(); return Error::from_string_literal("Invalid deflate literal/length symbol");
return false;
}
if (symbol < 256) { if (symbol < 256) {
m_decompressor.m_output_stream << static_cast<u8>(symbol); m_decompressor.m_output_stream << static_cast<u8>(symbol);
@ -146,26 +145,23 @@ bool DeflateDecompressor::CompressedBlock::try_read_more()
m_eof = true; m_eof = true;
return false; return false;
} else { } else {
if (!m_distance_codes.has_value()) { if (!m_distance_codes.has_value())
m_decompressor.set_fatal_error(); return Error::from_string_literal("Distance codes have not been initialized");
return false;
}
auto const length = m_decompressor.decode_length(symbol); auto const length = TRY(m_decompressor.decode_length(symbol));
auto const distance_symbol = m_distance_codes.value().read_symbol(m_decompressor.m_input_stream); auto const distance_symbol = TRY(m_distance_codes.value().read_symbol(*m_decompressor.m_input_stream));
if (distance_symbol >= 30) { // invalid deflate distance symbol if (distance_symbol >= 30)
m_decompressor.set_fatal_error(); return Error::from_string_literal("Invalid deflate distance symbol");
return false;
} auto const distance = TRY(m_decompressor.decode_distance(distance_symbol));
auto const distance = m_decompressor.decode_distance(distance_symbol);
for (size_t idx = 0; idx < length; ++idx) { for (size_t idx = 0; idx < length; ++idx) {
u8 byte = 0; u8 byte = 0;
m_decompressor.m_output_stream.read({ &byte, sizeof(byte) }, distance); m_decompressor.m_output_stream.read({ &byte, sizeof(byte) }, distance);
if (m_decompressor.m_output_stream.handle_any_error()) {
m_decompressor.set_fatal_error(); if (m_decompressor.m_output_stream.handle_any_error())
return false; // a back reference was requested that was too far back (outside our current sliding window) return Error::from_string_literal("A back reference was requested that was too far back");
}
m_decompressor.m_output_stream << byte; m_decompressor.m_output_stream << byte;
} }
@ -179,7 +175,7 @@ DeflateDecompressor::UncompressedBlock::UncompressedBlock(DeflateDecompressor& d
{ {
} }
bool DeflateDecompressor::UncompressedBlock::try_read_more() ErrorOr<bool> DeflateDecompressor::UncompressedBlock::try_read_more()
{ {
if (m_bytes_remaining == 0) if (m_bytes_remaining == 0)
return false; return false;
@ -187,13 +183,13 @@ bool DeflateDecompressor::UncompressedBlock::try_read_more()
auto const nread = min(m_bytes_remaining, m_decompressor.m_output_stream.remaining_contiguous_space()); auto const nread = min(m_bytes_remaining, m_decompressor.m_output_stream.remaining_contiguous_space());
m_bytes_remaining -= nread; m_bytes_remaining -= nread;
m_decompressor.m_input_stream >> m_decompressor.m_output_stream.reserve_contiguous_space(nread); TRY(m_decompressor.m_input_stream->read(m_decompressor.m_output_stream.reserve_contiguous_space(nread)));
return true; return true;
} }
DeflateDecompressor::DeflateDecompressor(InputStream& stream) DeflateDecompressor::DeflateDecompressor(Core::Stream::Handle<Core::Stream::Stream> stream)
: m_input_stream(stream) : m_input_stream(make<Core::Stream::LittleEndianInputBitStream>(move(stream)))
{ {
} }
@ -205,42 +201,28 @@ DeflateDecompressor::~DeflateDecompressor()
m_uncompressed_block.~UncompressedBlock(); m_uncompressed_block.~UncompressedBlock();
} }
size_t DeflateDecompressor::read(Bytes bytes) ErrorOr<Bytes> DeflateDecompressor::read(Bytes bytes)
{ {
size_t total_read = 0; size_t total_read = 0;
while (total_read < bytes.size()) { while (total_read < bytes.size()) {
if (has_any_error())
break;
auto slice = bytes.slice(total_read); auto slice = bytes.slice(total_read);
if (m_state == State::Idle) { if (m_state == State::Idle) {
if (m_read_final_bock) if (m_read_final_bock)
break; break;
m_read_final_bock = m_input_stream.read_bit(); m_read_final_bock = TRY(m_input_stream->read_bit());
auto const block_type = m_input_stream.read_bits(2); auto const block_type = TRY(m_input_stream->read_bits(2));
if (m_input_stream.has_any_error()) {
set_fatal_error();
break;
}
if (block_type == 0b00) { if (block_type == 0b00) {
m_input_stream.align_to_byte_boundary(); m_input_stream->align_to_byte_boundary();
LittleEndian<u16> length, negated_length; LittleEndian<u16> length, negated_length;
m_input_stream >> length >> negated_length; TRY(m_input_stream->read(length.bytes()));
TRY(m_input_stream->read(negated_length.bytes()));
if (m_input_stream.has_any_error()) { if ((length ^ 0xffff) != negated_length)
set_fatal_error(); return Error::from_string_literal("Calculated negated length does not equal stored negated length");
break;
}
if ((length ^ 0xffff) != negated_length) {
set_fatal_error();
break;
}
m_state = State::ReadingUncompressedBlock; m_state = State::ReadingUncompressedBlock;
new (&m_uncompressed_block) UncompressedBlock(*this, length); new (&m_uncompressed_block) UncompressedBlock(*this, length);
@ -258,12 +240,7 @@ size_t DeflateDecompressor::read(Bytes bytes)
if (block_type == 0b10) { if (block_type == 0b10) {
CanonicalCode literal_codes; CanonicalCode literal_codes;
Optional<CanonicalCode> distance_codes; Optional<CanonicalCode> distance_codes;
decode_codes(literal_codes, distance_codes); TRY(decode_codes(literal_codes, distance_codes));
if (m_input_stream.has_any_error()) {
set_fatal_error();
break;
}
m_state = State::ReadingCompressedBlock; m_state = State::ReadingCompressedBlock;
new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes); new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes);
@ -271,22 +248,16 @@ size_t DeflateDecompressor::read(Bytes bytes)
continue; continue;
} }
set_fatal_error(); return Error::from_string_literal("Unhandled block type for Idle state");
break;
} }
if (m_state == State::ReadingCompressedBlock) { if (m_state == State::ReadingCompressedBlock) {
auto nread = m_output_stream.read(slice); auto nread = m_output_stream.read(slice);
while (nread < slice.size() && m_compressed_block.try_read_more()) { while (nread < slice.size() && TRY(m_compressed_block.try_read_more())) {
nread += m_output_stream.read(slice.slice(nread)); nread += m_output_stream.read(slice.slice(nread));
} }
if (m_input_stream.has_any_error()) {
set_fatal_error();
break;
}
total_read += nread; total_read += nread;
if (nread == slice.size()) if (nread == slice.size())
break; break;
@ -300,15 +271,10 @@ size_t DeflateDecompressor::read(Bytes bytes)
if (m_state == State::ReadingUncompressedBlock) { if (m_state == State::ReadingUncompressedBlock) {
auto nread = m_output_stream.read(slice); auto nread = m_output_stream.read(slice);
while (nread < slice.size() && m_uncompressed_block.try_read_more()) { while (nread < slice.size() && TRY(m_uncompressed_block.try_read_more())) {
nread += m_output_stream.read(slice.slice(nread)); nread += m_output_stream.read(slice.slice(nread));
} }
if (m_input_stream.has_any_error()) {
set_fatal_error();
break;
}
total_read += nread; total_read += nread;
if (nread == slice.size()) if (nread == slice.size())
break; break;
@ -321,63 +287,42 @@ size_t DeflateDecompressor::read(Bytes bytes)
VERIFY_NOT_REACHED(); VERIFY_NOT_REACHED();
} }
return total_read;
return bytes.slice(0, total_read);
} }
bool DeflateDecompressor::read_or_error(Bytes bytes) bool DeflateDecompressor::is_eof() const { return m_state == State::Idle && m_read_final_bock; }
{
if (read(bytes) < bytes.size()) {
set_fatal_error();
return false;
}
ErrorOr<size_t> DeflateDecompressor::write(ReadonlyBytes)
{
VERIFY_NOT_REACHED();
}
bool DeflateDecompressor::is_open() const
{
return true; return true;
} }
bool DeflateDecompressor::discard_or_error(size_t count) void DeflateDecompressor::close()
{ {
u8 buffer[4096];
size_t ndiscarded = 0;
while (ndiscarded < count) {
if (unreliable_eof()) {
set_fatal_error();
return false;
}
ndiscarded += read({ buffer, min<size_t>(count - ndiscarded, 4096) });
}
return true;
} }
bool DeflateDecompressor::unreliable_eof() const { return m_state == State::Idle && m_read_final_bock; } ErrorOr<ByteBuffer> DeflateDecompressor::decompress_all(ReadonlyBytes bytes)
bool DeflateDecompressor::handle_any_error()
{ {
bool handled_errors = m_input_stream.handle_any_error(); auto memory_stream = TRY(Core::Stream::MemoryStream::construct(bytes));
return Stream::handle_any_error() || handled_errors; DeflateDecompressor deflate_stream { move(memory_stream) };
}
Optional<ByteBuffer> DeflateDecompressor::decompress_all(ReadonlyBytes bytes)
{
InputMemoryStream memory_stream { bytes };
DeflateDecompressor deflate_stream { memory_stream };
DuplexMemoryStream output_stream; DuplexMemoryStream output_stream;
u8 buffer[4096]; auto buffer = TRY(ByteBuffer::create_uninitialized(4096));
while (!deflate_stream.has_any_error() && !deflate_stream.unreliable_eof()) { while (!deflate_stream.is_eof()) {
auto const nread = deflate_stream.read({ buffer, sizeof(buffer) }); auto const slice = TRY(deflate_stream.read(buffer));
output_stream.write_or_error({ buffer, nread }); output_stream.write_or_error(slice);
} }
if (deflate_stream.handle_any_error())
return {};
return output_stream.copy_into_contiguous_buffer(); return output_stream.copy_into_contiguous_buffer();
} }
u32 DeflateDecompressor::decode_length(u32 symbol) ErrorOr<u32> DeflateDecompressor::decode_length(u32 symbol)
{ {
// FIXME: I can't quite follow the algorithm here, but it seems to work. // FIXME: I can't quite follow the algorithm here, but it seems to work.
@ -386,7 +331,7 @@ u32 DeflateDecompressor::decode_length(u32 symbol)
if (symbol <= 284) { if (symbol <= 284) {
auto extra_bits = (symbol - 261) / 4; auto extra_bits = (symbol - 261) / 4;
return (((symbol - 265) % 4 + 4) << extra_bits) + 3 + m_input_stream.read_bits(extra_bits); return (((symbol - 265) % 4 + 4) << extra_bits) + 3 + TRY(m_input_stream->read_bits(extra_bits));
} }
if (symbol == 285) if (symbol == 285)
@ -395,7 +340,7 @@ u32 DeflateDecompressor::decode_length(u32 symbol)
VERIFY_NOT_REACHED(); VERIFY_NOT_REACHED();
} }
u32 DeflateDecompressor::decode_distance(u32 symbol) ErrorOr<u32> DeflateDecompressor::decode_distance(u32 symbol)
{ {
// FIXME: I can't quite follow the algorithm here, but it seems to work. // FIXME: I can't quite follow the algorithm here, but it seems to work.
@ -404,86 +349,73 @@ u32 DeflateDecompressor::decode_distance(u32 symbol)
if (symbol <= 29) { if (symbol <= 29) {
auto extra_bits = (symbol / 2) - 1; auto extra_bits = (symbol / 2) - 1;
return ((symbol % 2 + 2) << extra_bits) + 1 + m_input_stream.read_bits(extra_bits); return ((symbol % 2 + 2) << extra_bits) + 1 + TRY(m_input_stream->read_bits(extra_bits));
} }
VERIFY_NOT_REACHED(); VERIFY_NOT_REACHED();
} }
void DeflateDecompressor::decode_codes(CanonicalCode& literal_code, Optional<CanonicalCode>& distance_code) ErrorOr<void> DeflateDecompressor::decode_codes(CanonicalCode& literal_code, Optional<CanonicalCode>& distance_code)
{ {
auto literal_code_count = m_input_stream.read_bits(5) + 257; auto literal_code_count = TRY(m_input_stream->read_bits(5)) + 257;
auto distance_code_count = m_input_stream.read_bits(5) + 1; auto distance_code_count = TRY(m_input_stream->read_bits(5)) + 1;
auto code_length_count = m_input_stream.read_bits(4) + 4; auto code_length_count = TRY(m_input_stream->read_bits(4)) + 4;
// First we have to extract the code lengths of the code that was used to encode the code lengths of // First we have to extract the code lengths of the code that was used to encode the code lengths of
// the code that was used to encode the block. // the code that was used to encode the block.
u8 code_lengths_code_lengths[19] = { 0 }; u8 code_lengths_code_lengths[19] = { 0 };
for (size_t i = 0; i < code_length_count; ++i) { for (size_t i = 0; i < code_length_count; ++i) {
code_lengths_code_lengths[code_lengths_code_lengths_order[i]] = m_input_stream.read_bits(3); code_lengths_code_lengths[code_lengths_code_lengths_order[i]] = TRY(m_input_stream->read_bits(3));
} }
// Now we can extract the code that was used to encode the code lengths of the code that was used to // Now we can extract the code that was used to encode the code lengths of the code that was used to
// encode the block. // encode the block.
auto code_length_code_result = CanonicalCode::from_bytes({ code_lengths_code_lengths, sizeof(code_lengths_code_lengths) }); auto code_length_code_result = CanonicalCode::from_bytes({ code_lengths_code_lengths, sizeof(code_lengths_code_lengths) });
if (!code_length_code_result.has_value()) { if (!code_length_code_result.has_value())
set_fatal_error(); return Error::from_string_literal("Failed to decode code length code");
return;
}
auto const code_length_code = code_length_code_result.value(); auto const code_length_code = code_length_code_result.value();
// Next we extract the code lengths of the code that was used to encode the block. // Next we extract the code lengths of the code that was used to encode the block.
Vector<u8> code_lengths; Vector<u8> code_lengths;
while (code_lengths.size() < literal_code_count + distance_code_count) { while (code_lengths.size() < literal_code_count + distance_code_count) {
auto symbol = code_length_code.read_symbol(m_input_stream); auto symbol = TRY(code_length_code.read_symbol(*m_input_stream));
if (symbol == UINT32_MAX) {
set_fatal_error();
return;
}
if (symbol < deflate_special_code_length_copy) { if (symbol < deflate_special_code_length_copy) {
code_lengths.append(static_cast<u8>(symbol)); code_lengths.append(static_cast<u8>(symbol));
continue; continue;
} else if (symbol == deflate_special_code_length_zeros) { } else if (symbol == deflate_special_code_length_zeros) {
auto nrepeat = 3 + m_input_stream.read_bits(3); auto nrepeat = 3 + TRY(m_input_stream->read_bits(3));
for (size_t j = 0; j < nrepeat; ++j) for (size_t j = 0; j < nrepeat; ++j)
code_lengths.append(0); code_lengths.append(0);
continue; continue;
} else if (symbol == deflate_special_code_length_long_zeros) { } else if (symbol == deflate_special_code_length_long_zeros) {
auto nrepeat = 11 + m_input_stream.read_bits(7); auto nrepeat = 11 + TRY(m_input_stream->read_bits(7));
for (size_t j = 0; j < nrepeat; ++j) for (size_t j = 0; j < nrepeat; ++j)
code_lengths.append(0); code_lengths.append(0);
continue; continue;
} else { } else {
VERIFY(symbol == deflate_special_code_length_copy); VERIFY(symbol == deflate_special_code_length_copy);
if (code_lengths.is_empty()) { if (code_lengths.is_empty())
set_fatal_error(); return Error::from_string_literal("Found no codes to copy before a copy block");
return;
}
auto nrepeat = 3 + m_input_stream.read_bits(2); auto nrepeat = 3 + TRY(m_input_stream->read_bits(2));
for (size_t j = 0; j < nrepeat; ++j) for (size_t j = 0; j < nrepeat; ++j)
code_lengths.append(code_lengths.last()); code_lengths.append(code_lengths.last());
} }
} }
if (code_lengths.size() != literal_code_count + distance_code_count) { if (code_lengths.size() != literal_code_count + distance_code_count)
set_fatal_error(); return Error::from_string_literal("Number of code lengths does not match the sum of codes");
return;
}
// Now we extract the code that was used to encode literals and lengths in the block. // Now we extract the code that was used to encode literals and lengths in the block.
auto literal_code_result = CanonicalCode::from_bytes(code_lengths.span().trim(literal_code_count)); auto literal_code_result = CanonicalCode::from_bytes(code_lengths.span().trim(literal_code_count));
if (!literal_code_result.has_value()) { if (!literal_code_result.has_value())
set_fatal_error(); Error::from_string_literal("Failed to decode the literal code");
return;
}
literal_code = literal_code_result.value(); literal_code = literal_code_result.value();
// Now we extract the code that was used to encode distances in the block. // Now we extract the code that was used to encode distances in the block.
@ -491,20 +423,18 @@ void DeflateDecompressor::decode_codes(CanonicalCode& literal_code, Optional<Can
if (distance_code_count == 1) { if (distance_code_count == 1) {
auto length = code_lengths[literal_code_count]; auto length = code_lengths[literal_code_count];
if (length == 0) { if (length == 0)
return; return {};
} else if (length != 1) { else if (length != 1)
set_fatal_error(); return Error::from_string_literal("Length for a single distance code is longer than 1");
return;
}
} }
auto distance_code_result = CanonicalCode::from_bytes(code_lengths.span().slice(literal_code_count)); auto distance_code_result = CanonicalCode::from_bytes(code_lengths.span().slice(literal_code_count));
if (!distance_code_result.has_value()) { if (!distance_code_result.has_value())
set_fatal_error(); Error::from_string_literal("Failed to decode the distance code");
return;
}
distance_code = distance_code_result.value(); distance_code = distance_code_result.value();
return {};
} }
DeflateCompressor::DeflateCompressor(OutputStream& stream, CompressionLevel compression_level) DeflateCompressor::DeflateCompressor(OutputStream& stream, CompressionLevel compression_level)

View file

@ -13,13 +13,15 @@
#include <AK/Endian.h> #include <AK/Endian.h>
#include <AK/Vector.h> #include <AK/Vector.h>
#include <LibCompress/DeflateTables.h> #include <LibCompress/DeflateTables.h>
#include <LibCore/InputBitStream.h>
#include <LibCore/Stream.h>
namespace Compress { namespace Compress {
class CanonicalCode { class CanonicalCode {
public: public:
CanonicalCode() = default; CanonicalCode() = default;
u32 read_symbol(InputBitStream&) const; ErrorOr<u32> read_symbol(Core::Stream::LittleEndianInputBitStream&) const;
void write_symbol(OutputBitStream&, u32) const; void write_symbol(OutputBitStream&, u32) const;
static CanonicalCode const& fixed_literal_codes(); static CanonicalCode const& fixed_literal_codes();
@ -37,13 +39,13 @@ private:
Array<u16, 288> m_bit_code_lengths {}; Array<u16, 288> m_bit_code_lengths {};
}; };
class DeflateDecompressor final : public InputStream { class DeflateDecompressor final : public Core::Stream::Stream {
private: private:
class CompressedBlock { class CompressedBlock {
public: public:
CompressedBlock(DeflateDecompressor&, CanonicalCode literal_codes, Optional<CanonicalCode> distance_codes); CompressedBlock(DeflateDecompressor&, CanonicalCode literal_codes, Optional<CanonicalCode> distance_codes);
bool try_read_more(); ErrorOr<bool> try_read_more();
private: private:
bool m_eof { false }; bool m_eof { false };
@ -57,7 +59,7 @@ private:
public: public:
UncompressedBlock(DeflateDecompressor&, size_t); UncompressedBlock(DeflateDecompressor&, size_t);
bool try_read_more(); ErrorOr<bool> try_read_more();
private: private:
DeflateDecompressor& m_decompressor; DeflateDecompressor& m_decompressor;
@ -74,22 +76,21 @@ public:
friend CompressedBlock; friend CompressedBlock;
friend UncompressedBlock; friend UncompressedBlock;
DeflateDecompressor(InputStream&); DeflateDecompressor(Core::Stream::Handle<Core::Stream::Stream> stream);
~DeflateDecompressor(); ~DeflateDecompressor();
size_t read(Bytes) override; virtual ErrorOr<Bytes> read(Bytes) override;
bool read_or_error(Bytes) override; virtual ErrorOr<size_t> write(ReadonlyBytes) override;
bool discard_or_error(size_t) override; virtual bool is_eof() const override;
virtual bool is_open() const override;
virtual void close() override;
bool unreliable_eof() const override; static ErrorOr<ByteBuffer> decompress_all(ReadonlyBytes);
bool handle_any_error() override;
static Optional<ByteBuffer> decompress_all(ReadonlyBytes);
private: private:
u32 decode_length(u32); ErrorOr<u32> decode_length(u32);
u32 decode_distance(u32); ErrorOr<u32> decode_distance(u32);
void decode_codes(CanonicalCode& literal_code, Optional<CanonicalCode>& distance_code); ErrorOr<void> decode_codes(CanonicalCode& literal_code, Optional<CanonicalCode>& distance_code);
bool m_read_final_bock { false }; bool m_read_final_bock { false };
@ -99,7 +100,7 @@ private:
UncompressedBlock m_uncompressed_block; UncompressedBlock m_uncompressed_block;
}; };
InputBitStream m_input_stream; Core::Stream::Handle<Core::Stream::LittleEndianInputBitStream> m_input_stream;
CircularDuplexStream<32 * KiB> m_output_stream; CircularDuplexStream<32 * KiB> m_output_stream;
}; };

View file

@ -59,14 +59,11 @@ ErrorOr<Bytes> GzipDecompressor::read(Bytes bytes)
auto slice = bytes.slice(total_read); auto slice = bytes.slice(total_read);
if (m_current_member.has_value()) { if (m_current_member.has_value()) {
size_t nread = current_member().m_stream.read(slice); auto current_slice = TRY(current_member().m_stream.read(slice));
current_member().m_checksum.update(slice.trim(nread)); current_member().m_checksum.update(current_slice);
current_member().m_nread += nread; current_member().m_nread += current_slice.size();
if (current_member().m_stream.handle_any_error()) if (current_slice.size() < slice.size()) {
return Error::from_string_literal("Underlying DeflateDecompressor indicated an error");
if (nread < slice.size()) {
LittleEndian<u32> crc32, input_size; LittleEndian<u32> crc32, input_size;
TRY(m_input_stream->read(crc32.bytes())); TRY(m_input_stream->read(crc32.bytes()));
TRY(m_input_stream->read(input_size.bytes())); TRY(m_input_stream->read(input_size.bytes()));
@ -79,11 +76,11 @@ ErrorOr<Bytes> GzipDecompressor::read(Bytes bytes)
m_current_member.clear(); m_current_member.clear();
total_read += nread; total_read += current_slice.size();
continue; continue;
} }
total_read += nread; total_read += current_slice.size();
continue; continue;
} else { } else {
auto current_partial_header_slice = Bytes { m_partial_header, sizeof(BlockHeader) }.slice(m_partial_header_offset); auto current_partial_header_slice = Bytes { m_partial_header, sizeof(BlockHeader) }.slice(m_partial_header_offset);

View file

@ -58,13 +58,11 @@ private:
public: public:
Member(BlockHeader header, Core::Stream::Stream& stream) Member(BlockHeader header, Core::Stream::Stream& stream)
: m_header(header) : m_header(header)
, m_adapted_ak_stream(make<Core::Stream::WrapInAKInputStream>(stream)) , m_stream(Core::Stream::Handle<Core::Stream::Stream>(stream))
, m_stream(*m_adapted_ak_stream)
{ {
} }
BlockHeader m_header; BlockHeader m_header;
NonnullOwnPtr<InputStream> m_adapted_ak_stream;
DeflateDecompressor m_stream; DeflateDecompressor m_stream;
Crypto::Checksum::CRC32 m_checksum; Crypto::Checksum::CRC32 m_checksum;
size_t m_nread { 0 }; size_t m_nread { 0 };

View file

@ -44,7 +44,10 @@ Zlib::Zlib(ZlibHeader header, ReadonlyBytes data)
Optional<ByteBuffer> Zlib::decompress() Optional<ByteBuffer> Zlib::decompress()
{ {
return DeflateDecompressor::decompress_all(m_data_bytes); auto buffer_or_error = DeflateDecompressor::decompress_all(m_data_bytes);
if (buffer_or_error.is_error())
return {};
return buffer_or_error.release_value();
} }
Optional<ByteBuffer> Zlib::decompress_all(ReadonlyBytes bytes) Optional<ByteBuffer> Zlib::decompress_all(ReadonlyBytes bytes)

View file

@ -60,12 +60,14 @@ static Optional<ByteBuffer> handle_content_encoding(ByteBuffer const& buf, Depre
// "Note: Some non-conformant implementations send the "deflate" // "Note: Some non-conformant implementations send the "deflate"
// compressed data without the zlib wrapper." // compressed data without the zlib wrapper."
dbgln_if(JOB_DEBUG, "Job::handle_content_encoding: Zlib::decompress_all() failed. Trying DeflateDecompressor::decompress_all()"); dbgln_if(JOB_DEBUG, "Job::handle_content_encoding: Zlib::decompress_all() failed. Trying DeflateDecompressor::decompress_all()");
uncompressed = Compress::DeflateDecompressor::decompress_all(buf); auto uncompressed_or_error = Compress::DeflateDecompressor::decompress_all(buf);
if (!uncompressed.has_value()) { if (uncompressed_or_error.is_error()) {
dbgln("Job::handle_content_encoding: DeflateDecompressor::decompress_all() failed."); dbgln("Job::handle_content_encoding: DeflateDecompressor::decompress_all() failed: {}", uncompressed_or_error.error());
return {}; return {};
} }
uncompressed = uncompressed_or_error.release_value();
} }
if constexpr (JOB_DEBUG) { if constexpr (JOB_DEBUG) {

View file

@ -48,8 +48,8 @@ static bool unpack_zip_member(Archive::ZipMember zip_member, bool quiet)
} }
case Archive::ZipCompressionMethod::Deflate: { case Archive::ZipCompressionMethod::Deflate: {
auto decompressed_data = Compress::DeflateDecompressor::decompress_all(zip_member.compressed_data); auto decompressed_data = Compress::DeflateDecompressor::decompress_all(zip_member.compressed_data);
if (!decompressed_data.has_value()) { if (decompressed_data.is_error()) {
warnln("Failed decompressing file {}", zip_member.name); warnln("Failed decompressing file {}: {}", zip_member.name, decompressed_data.error());
return false; return false;
} }
if (decompressed_data.value().size() != zip_member.uncompressed_size) { if (decompressed_data.value().size() != zip_member.uncompressed_size) {