diff --git a/src/uu/hashsum/src/digest.rs b/src/uu/hashsum/src/digest.rs index 531dc7e4f..61f425662 100644 --- a/src/uu/hashsum/src/digest.rs +++ b/src/uu/hashsum/src/digest.rs @@ -189,13 +189,31 @@ pub struct DigestWriter<'a> { /// "\n" before passing input bytes to the [`digest`]. #[allow(dead_code)] binary: bool, - // TODO This is dead code only on non-Windows operating systems. It - // might be better to use a `#[cfg(windows)]` guard here. + + /// Whether the previous + #[allow(dead_code)] + was_last_character_carriage_return: bool, + // TODO These are dead code only on non-Windows operating systems. + // It might be better to use a `#[cfg(windows)]` guard here. } impl<'a> DigestWriter<'a> { pub fn new(digest: &'a mut Box, binary: bool) -> DigestWriter { - DigestWriter { digest, binary } + let was_last_character_carriage_return = false; + DigestWriter { + digest, + binary, + was_last_character_carriage_return, + } + } + + pub fn finalize(&mut self) -> bool { + if self.was_last_character_carriage_return { + self.digest.input(&[b'\r']); + true + } else { + false + } } } @@ -213,22 +231,40 @@ impl<'a> Write for DigestWriter<'a> { return Ok(buf.len()); } - // In Windows text mode, replace each occurrence of "\r\n" - // with "\n". + // The remaining code handles Windows text mode, where we must + // replace each occurrence of "\r\n" with "\n". // - // Find all occurrences of "\r\n", inputting the slice just - // before the "\n" in the previous instance of "\r\n" and - // the beginning of this "\r\n". - // - // FIXME This fails if one call to `write()` ends with the - // "\r" and the next call to `write()` begins with the "\n". + // First, if the last character written was "\r" and the first + // character in the current buffer to write is not "\n", then we + // need to write the "\r" that we buffered from the previous + // call to `write()`. let n = buf.len(); + if self.was_last_character_carriage_return && n > 0 && buf[0] != b'\n' { + self.digest.input(&[b'\r']); + } + + // Next, find all occurrences of "\r\n", inputting the slice + // just before the "\n" in the previous instance of "\r\n" and + // the beginning of this "\r\n". let mut i_prev = 0; for i in memmem::find_iter(buf, b"\r\n") { self.digest.input(&buf[i_prev..i]); i_prev = i + 1; } - self.digest.input(&buf[i_prev..n]); + + // Finally, check whether the last character is "\r". If so, + // buffer it until we know that the next character is not "\n", + // which can only be known on the next call to `write()`. + // + // This all assumes that `write()` will be called on adjacent + // blocks of the input. + if n > 0 && buf[n - 1] == b'\r' { + self.was_last_character_carriage_return = true; + self.digest.input(&buf[i_prev..n - 1]); + } else { + self.was_last_character_carriage_return = false; + self.digest.input(&buf[i_prev..n]); + } // Even though we dropped a "\r" for each "\r\n" we found, we // still report the number of bytes written as `n`. This is @@ -243,3 +279,36 @@ impl<'a> Write for DigestWriter<'a> { Ok(()) } } + +#[cfg(test)] +mod tests { + + /// Test for replacing a "\r\n" sequence with "\n" when the "\r" is + /// at the end of one block and the "\n" is at the beginning of the + /// next block, when reading in blocks. + #[cfg(windows)] + #[test] + fn test_crlf_across_blocks() { + use std::io::Write; + + use crate::digest::Digest; + use crate::digest::DigestWriter; + + // Writing "\r" in one call to `write()`, and then "\n" in another. + let mut digest = Box::new(md5::Context::new()) as Box; + let mut writer_crlf = DigestWriter::new(&mut digest, false); + writer_crlf.write_all(&[b'\r']).unwrap(); + writer_crlf.write_all(&[b'\n']).unwrap(); + writer_crlf.finalize(); + let result_crlf = digest.result_str(); + + // We expect "\r\n" to be replaced with "\n" in text mode on Windows. + let mut digest = Box::new(md5::Context::new()) as Box; + let mut writer_lf = DigestWriter::new(&mut digest, false); + writer_lf.write_all(&[b'\n']).unwrap(); + writer_lf.finalize(); + let result_lf = digest.result_str(); + + assert_eq!(result_crlf, result_lf); + } +} diff --git a/src/uu/hashsum/src/hashsum.rs b/src/uu/hashsum/src/hashsum.rs index 4186043f5..07070ed1b 100644 --- a/src/uu/hashsum/src/hashsum.rs +++ b/src/uu/hashsum/src/hashsum.rs @@ -611,8 +611,16 @@ fn digest_reader( // If `binary` is `false` and the operating system is Windows, then // `DigestWriter` replaces "\r\n" with "\n" before it writes the // bytes into `digest`. Otherwise, it just inserts the bytes as-is. + // + // In order to support replacing "\r\n", we must call `finalize()` + // in order to support the possibility that the last character read + // from the reader was "\r". (This character gets buffered by + // `DigestWriter` and only written if the following character is + // "\n". But when "\r" is the last character read, we need to force + // it to be written.) let mut digest_writer = DigestWriter::new(digest, binary); std::io::copy(reader, &mut digest_writer)?; + digest_writer.finalize(); if digest.output_bits() > 0 { Ok(digest.result_str())