diff --git a/Cargo.lock b/Cargo.lock index 5f8bddf94..6abb5014a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3379,6 +3379,7 @@ name = "uu_yes" version = "0.0.18" dependencies = [ "clap", + "itertools", "nix", "uucore", ] diff --git a/src/uu/yes/Cargo.toml b/src/uu/yes/Cargo.toml index fd3d9ddc0..9d661fb0d 100644 --- a/src/uu/yes/Cargo.toml +++ b/src/uu/yes/Cargo.toml @@ -16,6 +16,7 @@ path = "src/yes.rs" [dependencies] clap = { workspace=true } +itertools = { workspace=true } [target.'cfg(unix)'.dependencies] uucore = { workspace=true, features=["pipes", "signals"] } diff --git a/src/uu/yes/src/yes.rs b/src/uu/yes/src/yes.rs index 41bfeddca..fd5124064 100644 --- a/src/uu/yes/src/yes.rs +++ b/src/uu/yes/src/yes.rs @@ -7,8 +7,11 @@ /* last synced with: yes (GNU coreutils) 8.13 */ -use clap::{Arg, ArgAction, Command}; -use std::borrow::Cow; +// cSpell:ignore strs + +use clap::{builder::ValueParser, Arg, ArgAction, Command}; +use std::error::Error; +use std::ffi::OsString; use std::io::{self, Write}; use uucore::error::{UResult, USimpleError}; #[cfg(unix)] @@ -28,19 +31,11 @@ const BUF_SIZE: usize = 16 * 1024; pub fn uumain(args: impl uucore::Args) -> UResult<()> { let matches = uu_app().try_get_matches_from(args)?; - let string = if let Some(values) = matches.get_many::("STRING") { - let mut result = values.fold(String::new(), |res, s| res + s + " "); - result.pop(); - result.push('\n'); - Cow::from(result) - } else { - Cow::from("y\n") - }; + let mut buffer = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut buffer, matches.get_many::("STRING")).unwrap(); + prepare_buffer(&mut buffer); - let mut buffer = [0; BUF_SIZE]; - let bytes = prepare_buffer(&string, &mut buffer); - - match exec(bytes) { + match exec(&buffer) { Ok(()) => Ok(()), Err(err) if err.kind() == io::ErrorKind::BrokenPipe => Ok(()), Err(err) => Err(USimpleError::new(1, format!("standard output: {err}"))), @@ -51,21 +46,73 @@ pub fn uu_app() -> Command { Command::new(uucore::util_name()) .about(ABOUT) .override_usage(format_usage(USAGE)) - .arg(Arg::new("STRING").action(ArgAction::Append)) + .arg( + Arg::new("STRING") + .value_parser(ValueParser::os_string()) + .action(ArgAction::Append), + ) .infer_long_args(true) } -fn prepare_buffer<'a>(input: &'a str, buffer: &'a mut [u8; BUF_SIZE]) -> &'a [u8] { - if input.len() < BUF_SIZE / 2 { - let mut size = 0; - while size < BUF_SIZE - input.len() { - let (_, right) = buffer.split_at_mut(size); - right[..input.len()].copy_from_slice(input.as_bytes()); - size += input.len(); - } - &buffer[..size] +// Copies words from `i` into `buf`, separated by spaces. +fn args_into_buffer<'a>( + buf: &mut Vec, + i: Option>, +) -> Result<(), Box> { + // TODO: this should be replaced with let/else once available in the MSRV. + let i = if let Some(i) = i { + i } else { - input.as_bytes() + buf.extend_from_slice(b"y\n"); + return Ok(()); + }; + + // On Unix (and wasi), OsStrs are just &[u8]'s underneath... + #[cfg(any(unix, target_os = "wasi"))] + { + #[cfg(unix)] + use std::os::unix::ffi::OsStrExt; + #[cfg(target_os = "wasi")] + use std::os::wasi::ffi::OsStrExt; + + for part in itertools::intersperse(i.map(|a| a.as_bytes()), b" ") { + buf.extend_from_slice(part); + } + } + + // But, on Windows, we must hop through a String. + #[cfg(not(any(unix, target_os = "wasi")))] + { + for part in itertools::intersperse(i.map(|a| a.to_str()), Some(" ")) { + let bytes = match part { + Some(part) => part.as_bytes(), + None => return Err("arguments contain invalid UTF-8".into()), + }; + buf.extend_from_slice(bytes); + } + } + + buf.push(b'\n'); + + Ok(()) +} + +// Assumes buf holds a single output line forged from the command line arguments, copies it +// repeatedly until the buffer holds as many copies as it can under BUF_SIZE. +fn prepare_buffer(buf: &mut Vec) { + if buf.len() * 2 > BUF_SIZE { + return; + } + + assert!(!buf.is_empty()); + + let line_len = buf.len(); + let target_size = line_len * (BUF_SIZE / line_len); + + while buf.len() < target_size { + let to_copy = std::cmp::min(target_size - buf.len(), buf.len()); + debug_assert_eq!(to_copy % line_len, 0); + buf.extend_from_within(..to_copy); } } @@ -88,3 +135,67 @@ pub fn exec(bytes: &[u8]) -> io::Result<()> { stdout.write_all(bytes)?; } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prepare_buffer() { + let tests = [ + (150, 16350), + (1000, 16000), + (4093, 16372), + (4099, 12297), + (4111, 12333), + (2, 16384), + (3, 16383), + (4, 16384), + (5, 16380), + (8192, 16384), + (8191, 16382), + (8193, 8193), + (10000, 10000), + (15000, 15000), + (25000, 25000), + ]; + + for (line, final_len) in tests { + let mut v = std::iter::repeat(b'a').take(line).collect::>(); + prepare_buffer(&mut v); + assert_eq!(v.len(), final_len); + } + } + + #[test] + fn test_args_into_buf() { + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut v, None::>).unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "y\n"); + } + + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer(&mut v, Some([OsString::from("foo")].iter())).unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "foo\n"); + } + + { + let mut v = Vec::with_capacity(BUF_SIZE); + args_into_buffer( + &mut v, + Some( + [ + OsString::from("foo"), + OsString::from("bar baz"), + OsString::from("qux"), + ] + .iter(), + ), + ) + .unwrap(); + assert_eq!(String::from_utf8(v).unwrap(), "foo bar baz qux\n"); + } + } +} diff --git a/tests/by-util/test_yes.rs b/tests/by-util/test_yes.rs index c054a6e5f..89a68e7e1 100644 --- a/tests/by-util/test_yes.rs +++ b/tests/by-util/test_yes.rs @@ -1,3 +1,4 @@ +use std::ffi::OsStr; use std::process::{ExitStatus, Stdio}; #[cfg(unix)] @@ -15,8 +16,10 @@ fn check_termination(result: ExitStatus) { assert!(result.success(), "yes did not exit successfully"); } +const NO_ARGS: &[&str] = &[]; + /// Run `yes`, capture some of the output, close the pipe, and verify it. -fn run(args: &[&str], expected: &[u8]) { +fn run(args: &[impl AsRef], expected: &[u8]) { let mut cmd = new_ucmd!(); let mut child = cmd.args(args).set_stdout(Stdio::piped()).run_no_wait(); let buf = child.stdout_exact_bytes(expected.len()); @@ -34,7 +37,7 @@ fn test_invalid_arg() { #[test] fn test_simple() { - run(&[], b"y\ny\ny\ny\n"); + run(NO_ARGS, b"y\ny\ny\ny\n"); } #[test] @@ -44,7 +47,7 @@ fn test_args() { #[test] fn test_long_output() { - run(&[], "y\n".repeat(512 * 1024).as_bytes()); + run(NO_ARGS, "y\n".repeat(512 * 1024).as_bytes()); } /// Test with an output that seems likely to get mangled in case of incomplete writes. @@ -88,3 +91,20 @@ fn test_piped_to_dev_full() { } } } + +#[test] +#[cfg(any(unix, target_os = "wasi"))] +fn test_non_utf8() { + #[cfg(unix)] + use std::os::unix::ffi::OsStrExt; + #[cfg(target_os = "wasi")] + use std::os::wasi::ffi::OsStrExt; + + run( + &[ + OsStr::from_bytes(b"\xbf\xff\xee"), + OsStr::from_bytes(b"bar"), + ], + &b"\xbf\xff\xee bar\n".repeat(5000), + ); +}