From 3870ee252ae884cc1dba53d57ee74b34c8703467 Mon Sep 17 00:00:00 2001 From: Jed Denlea Date: Sun, 14 May 2023 15:29:54 -0700 Subject: [PATCH] yes: support non-UTF-8 args Also, tighten the creation of the output buffer. Rather than copy "y\n" 8192 times, or any other input some number of times, it can be doubled in place using Vec::extend_from_within. --- Cargo.lock | 1 + src/uu/yes/Cargo.toml | 1 + src/uu/yes/src/yes.rs | 161 ++++++++++++++++++++++++++++++++------ tests/by-util/test_yes.rs | 26 +++++- 4 files changed, 161 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f8bddf94..6abb5014a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3379,6 +3379,7 @@ name = "uu_yes" version = "0.0.18" dependencies = [ "clap", + "itertools", "nix", "uucore", ] diff --git a/src/uu/yes/Cargo.toml b/src/uu/yes/Cargo.toml index fd3d9ddc0..9d661fb0d 100644 --- a/src/uu/yes/Cargo.toml +++ b/src/uu/yes/Cargo.toml @@ -16,6 +16,7 @@ path = "src/yes.rs" [dependencies] clap = { workspace=true } +itertools = { workspace=true } [target.'cfg(unix)'.dependencies] uucore = { workspace=true, features=["pipes", "signals"] } diff --git a/src/uu/yes/src/yes.rs b/src/uu/yes/src/yes.rs index 41bfeddca..fd5124064 100644 --- a/src/uu/yes/src/yes.rs +++ b/src/uu/yes/src/yes.rs @@ -7,8 +7,11 @@ /* last synced with: yes (GNU coreutils) 8.13 */ -use clap::{Arg, ArgAction, Command}; -use std::borrow::Cow; +// cSpell:ignore strs + +use clap::{builder::ValueParser, Arg, ArgAction, Command}; +use std::error::Error; +use std::ffi::OsString; use std::io::{self, Write}; use uucore::error::{UResult, USimpleError}; #[cfg(unix)] @@ -28,19 +31,11 @@ const BUF_SIZE: usize = 16 * 1024; pub fn uumain(args: impl uucore::Args) -> UResult<()> { let matches = uu_app().try_get_matches_from(args)?; - let string = if let Some(values) = matches.get_many::<String>("STRING") { - let mut result = values.fold(String::new(), |res, s| res + s + " "); - result.pop(); - result.push('\n'); - Cow::from(result) - } else {
- Cow::from("y\n") - }; + let mut buffer = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut buffer, matches.get_many::<OsString>("STRING")).unwrap(); + prepare_buffer(&mut buffer); - let mut buffer = [0; BUF_SIZE]; - let bytes = prepare_buffer(&string, &mut buffer); - - match exec(bytes) { + match exec(&buffer) { Ok(()) => Ok(()), Err(err) if err.kind() == io::ErrorKind::BrokenPipe => Ok(()), Err(err) => Err(USimpleError::new(1, format!("standard output: {err}"))), @@ -51,21 +46,73 @@ pub fn uu_app() -> Command { Command::new(uucore::util_name()) .about(ABOUT) .override_usage(format_usage(USAGE)) - .arg(Arg::new("STRING").action(ArgAction::Append)) + .arg( + Arg::new("STRING") + .value_parser(ValueParser::os_string()) + .action(ArgAction::Append), + ) .infer_long_args(true) } -fn prepare_buffer<'a>(input: &'a str, buffer: &'a mut [u8; BUF_SIZE]) -> &'a [u8] { - if input.len() < BUF_SIZE / 2 { - let mut size = 0; - while size < BUF_SIZE - input.len() { - let (_, right) = buffer.split_at_mut(size); - right[..input.len()].copy_from_slice(input.as_bytes()); - size += input.len(); - } - &buffer[..size] +// Copies words from `i` into `buf`, separated by spaces. +fn args_into_buffer<'a>( + buf: &mut Vec<u8>, + i: Option<impl Iterator<Item = &'a OsString>>, +) -> Result<(), Box<dyn Error>> { + // TODO: this should be replaced with let/else once available in the MSRV. + let i = if let Some(i) = i { + i } else { - input.as_bytes() + buf.extend_from_slice(b"y\n"); + return Ok(()); + }; + + // On Unix (and wasi), OsStrs are just &[u8]'s underneath... + #[cfg(any(unix, target_os = "wasi"))] + { + #[cfg(unix)] + use std::os::unix::ffi::OsStrExt; + #[cfg(target_os = "wasi")] + use std::os::wasi::ffi::OsStrExt; + + for part in itertools::intersperse(i.map(|a| a.as_bytes()), b" ") { + buf.extend_from_slice(part); + } + } + + // But, on Windows, we must hop through a String.
+ #[cfg(not(any(unix, target_os = "wasi")))] + { + for part in itertools::intersperse(i.map(|a| a.to_str()), Some(" ")) { + let bytes = match part { + Some(part) => part.as_bytes(), + None => return Err("arguments contain invalid UTF-8".into()), + }; + buf.extend_from_slice(bytes); + } + } + + buf.push(b'\n'); + + Ok(()) +} + +// Assumes buf holds a single output line forged from the command line arguments, copies it +// repeatedly until the buffer holds as many copies as it can under BUF_SIZE. +fn prepare_buffer(buf: &mut Vec<u8>) { + if buf.len() * 2 > BUF_SIZE { + return; + } + + assert!(!buf.is_empty()); + + let line_len = buf.len(); + let target_size = line_len * (BUF_SIZE / line_len); + + while buf.len() < target_size { + let to_copy = std::cmp::min(target_size - buf.len(), buf.len()); + debug_assert_eq!(to_copy % line_len, 0); + buf.extend_from_within(..to_copy); } } @@ -88,3 +135,67 @@ pub fn exec(bytes: &[u8]) -> io::Result<()> { stdout.write_all(bytes)?; } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prepare_buffer() { + let tests = [ + (150, 16350), + (1000, 16000), + (4093, 16372), + (4099, 12297), + (4111, 12333), + (2, 16384), + (3, 16383), + (4, 16384), + (5, 16380), + (8192, 16384), + (8191, 16382), + (8193, 8193), + (10000, 10000), + (15000, 15000), + (25000, 25000), + ]; + + for (line, final_len) in tests { + let mut v = std::iter::repeat(b'a').take(line).collect::<Vec<_>>(); + prepare_buffer(&mut v); + assert_eq!(v.len(), final_len); + } + } + + #[test] + fn test_args_into_buf() { + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut v, None::<std::slice::Iter<OsString>>).unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "y\n"); + } + + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut v, Some([OsString::from("foo")].iter())).unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "foo\n"); + } + + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer( + &mut v, + Some( + [ +
OsString::from("foo"), + OsString::from("bar baz"), + OsString::from("qux"), + ] + .iter(), + ), + ) + .unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "foo bar baz qux\n"); + } + } +} diff --git a/tests/by-util/test_yes.rs b/tests/by-util/test_yes.rs index c054a6e5f..89a68e7e1 100644 --- a/tests/by-util/test_yes.rs +++ b/tests/by-util/test_yes.rs @@ -1,3 +1,4 @@ +use std::ffi::OsStr; use std::process::{ExitStatus, Stdio}; #[cfg(unix)] @@ -15,8 +16,10 @@ fn check_termination(result: ExitStatus) { assert!(result.success(), "yes did not exit successfully"); } +const NO_ARGS: &[&str] = &[]; + /// Run `yes`, capture some of the output, close the pipe, and verify it. -fn run(args: &[&str], expected: &[u8]) { +fn run(args: &[impl AsRef<OsStr>], expected: &[u8]) { let mut cmd = new_ucmd!(); let mut child = cmd.args(args).set_stdout(Stdio::piped()).run_no_wait(); let buf = child.stdout_exact_bytes(expected.len()); @@ -34,7 +37,7 @@ fn test_invalid_arg() { #[test] fn test_simple() { - run(&[], b"y\ny\ny\ny\n"); + run(NO_ARGS, b"y\ny\ny\ny\n"); } #[test] @@ -44,7 +47,7 @@ fn test_args() { #[test] fn test_long_output() { - run(&[], "y\n".repeat(512 * 1024).as_bytes()); + run(NO_ARGS, "y\n".repeat(512 * 1024).as_bytes()); } /// Test with an output that seems likely to get mangled in case of incomplete writes. @@ -88,3 +91,20 @@ fn test_piped_to_dev_full() { } } } + +#[test] +#[cfg(any(unix, target_os = "wasi"))] +fn test_non_utf8() { + #[cfg(unix)] + use std::os::unix::ffi::OsStrExt; + #[cfg(target_os = "wasi")] + use std::os::wasi::ffi::OsStrExt; + + run( + &[ + OsStr::from_bytes(b"\xbf\xff\xee"), + OsStr::from_bytes(b"bar"), + ], + &b"\xbf\xff\xee bar\n".repeat(5000), + ); +}