diff --git a/src/uu/shuf/BENCHMARKING.md b/src/uu/shuf/BENCHMARKING.md index 58eefc499..d16b1afb0 100644 --- a/src/uu/shuf/BENCHMARKING.md +++ b/src/uu/shuf/BENCHMARKING.md @@ -28,11 +28,11 @@ a range of numbers to randomly sample from. An example of a command that works well for testing: ```shell -hyperfine --warmup 10 "target/release/shuf -i 0-10000000" +hyperfine --warmup 10 "target/release/shuf -i 0-10000000 > /dev/null" ``` To measure the time taken by shuffling an input file, the following command can -be used:: +be used: ```shell hyperfine --warmup 10 "target/release/shuf input.txt > /dev/null" @@ -49,5 +49,14 @@ should be benchmarked separately. In this case, we have to pass the `-n` flag or the command will run forever. An example of a hyperfine command is ```shell -hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000" +hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000 > /dev/null" +``` + +## With huge interval ranges + +When `shuf` runs with huge interval ranges, special care must be taken, so it +should be benchmarked separately also. An example of a hyperfine command is + +```shell +hyperfine --warmup 10 "target/release/shuf -n 100 -i 1000-2000000000 > /dev/null" ``` diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index bab328e2e..a5456e184 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -3,14 +3,15 @@ // For the full copyright and license information, please view the LICENSE // file that was distributed with this source code. -// spell-checker:ignore (ToDO) cmdline evec seps rvec fdata +// spell-checker:ignore (ToDO) cmdline evec nonrepeating seps shufable rvec fdata use clap::{crate_version, Arg, ArgAction, Command}; use memchr::memchr_iter; use rand::prelude::SliceRandom; -use rand::RngCore; +use rand::{Rng, RngCore}; +use std::collections::HashSet; use std::fs::File; -use std::io::{stdin, stdout, BufReader, BufWriter, Read, Write}; +use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write}; use uucore::display::Quotable; use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::{format_usage, help_about, help_usage}; @@ -116,18 +117,16 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { Mode::Echo(args) => { let mut evec = args.iter().map(String::as_bytes).collect::>(); find_seps(&mut evec, options.sep); - shuf_bytes(&mut evec, options)?; + shuf_exec(&mut evec, options)?; } Mode::InputRange((b, e)) => { - let rvec = (b..e).map(|x| format!("{x}")).collect::>(); - let mut rvec = rvec.iter().map(String::as_bytes).collect::>(); - shuf_bytes(&mut rvec, options)?; + shuf_exec(&mut (b, e), options)?; } Mode::Default(filename) => { let fdata = read_input_file(&filename)?; let mut fdata = vec![&fdata[..]]; find_seps(&mut fdata, options.sep); - shuf_bytes(&mut fdata, options)?; + shuf_exec(&mut fdata, options)?; } } @@ -251,7 +250,173 @@ fn find_seps(data: &mut Vec<&[u8]>, sep: u8) { } } -fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> { +trait Shufable { + type Item: Writable; + fn is_empty(&self) -> bool; + fn choose(&self, rng: &mut WrappedRng) -> Self::Item; + // This type shouldn't even be known. However, because we want to support + // Rust 1.70, it is not possible to return "impl Iterator". + // TODO: When the MSRV is raised, rewrite this to return "impl Iterator". + type PartialShuffleIterator<'b>: Iterator + where + Self: 'b; + fn partial_shuffle<'b>( + &'b mut self, + rng: &'b mut WrappedRng, + amount: usize, + ) -> Self::PartialShuffleIterator<'b>; +} + +impl<'a> Shufable for Vec<&'a [u8]> { + type Item = &'a [u8]; + fn is_empty(&self) -> bool { + (**self).is_empty() + } + fn choose(&self, rng: &mut WrappedRng) -> Self::Item { + // Note: "copied()" only copies the reference, not the entire [u8]. + // Returns None if the slice is empty. We checked this before, so + // this is safe. + (**self).choose(rng).unwrap() + } + type PartialShuffleIterator<'b> = std::iter::Copied> where Self: 'b; + fn partial_shuffle<'b>( + &'b mut self, + rng: &'b mut WrappedRng, + amount: usize, + ) -> Self::PartialShuffleIterator<'b> { + // Note: "copied()" only copies the reference, not the entire [u8]. + (**self).partial_shuffle(rng, amount).0.iter().copied() + } +} + +impl Shufable for (usize, usize) { + type Item = usize; + fn is_empty(&self) -> bool { + // Note: This is an inclusive range, so equality means there is 1 element. + self.0 > self.1 + } + fn choose(&self, rng: &mut WrappedRng) -> usize { + rng.gen_range(self.0..self.1) + } + type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b; + fn partial_shuffle<'b>( + &'b mut self, + rng: &'b mut WrappedRng, + amount: usize, + ) -> Self::PartialShuffleIterator<'b> { + NonrepeatingIterator::new(self.0, self.1, rng, amount) + } +} + +enum NumberSet { + AlreadyListed(HashSet), + Remaining(Vec), +} + +struct NonrepeatingIterator<'a> { + begin: usize, + end: usize, // exclusive + rng: &'a mut WrappedRng, + remaining_count: usize, + buf: NumberSet, +} + +impl<'a> NonrepeatingIterator<'a> { + fn new( + begin: usize, + end: usize, + rng: &'a mut WrappedRng, + amount: usize, + ) -> NonrepeatingIterator { + let capped_amount = if begin > end { + 0 + } else { + amount.min(end - begin) + }; + NonrepeatingIterator { + begin, + end, + rng, + remaining_count: capped_amount, + buf: NumberSet::AlreadyListed(HashSet::default()), + } + } + + fn produce(&mut self) -> usize { + debug_assert!(self.begin <= self.end); + match &mut self.buf { + NumberSet::AlreadyListed(already_listed) => { + let chosen = loop { + let guess = self.rng.gen_range(self.begin..self.end); + let newly_inserted = already_listed.insert(guess); + if newly_inserted { + break guess; + } + }; + // Once a significant fraction of the interval has already been enumerated, + // the number of attempts to find a number that hasn't been chosen yet increases. + // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". + let range_size = self.end - self.begin; + if number_set_should_list_remaining(already_listed.len(), range_size) { + let mut remaining = (self.begin..self.end) + .filter(|n| !already_listed.contains(n)) + .collect::>(); + assert!(remaining.len() >= self.remaining_count); + remaining.partial_shuffle(&mut self.rng, self.remaining_count); + remaining.truncate(self.remaining_count); + self.buf = NumberSet::Remaining(remaining); + } + chosen + } + NumberSet::Remaining(remaining_numbers) => { + debug_assert!(!remaining_numbers.is_empty()); + // We only enter produce() when there is at least one actual element remaining, so popping must always return an element. + remaining_numbers.pop().unwrap() + } + } + } +} + +impl<'a> Iterator for NonrepeatingIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + if self.begin > self.end || self.remaining_count == 0 { + return None; + } + self.remaining_count -= 1; + Some(self.produce()) + } +} + +// This could be a method, but it is much easier to test as a stand-alone function. +fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool { + // Arbitrarily determine the switchover point to be around 25%. This is because: + // - HashSet has a large space overhead for the hash table load factor. + // - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same. + // - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway. + // - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet. + // - Finally, "24%" is computationally the simplest: + listed_count >= range_size / 4 +} + +trait Writable { + fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error>; +} + +impl<'a> Writable for &'a [u8] { + fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> { + output.write_all(self) + } +} + +impl Writable for usize { + fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> { + output.write_all(format!("{self}").as_bytes()) + } +} + +fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> { let mut output = BufWriter::new(match opts.output { None => Box::new(stdout()) as Box, Some(s) => { @@ -276,22 +441,18 @@ fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> { if opts.repeat { for _ in 0..opts.head_count { - // Returns None is the slice is empty. We checked this before, so - // this is safe. - let r = input.choose(&mut rng).unwrap(); + let r = input.choose(&mut rng); - output - .write_all(r) + r.write_all_to(&mut output) .map_err_context(|| "write failed".to_string())?; output .write_all(&[opts.sep]) .map_err_context(|| "write failed".to_string())?; } } else { - let (shuffled, _) = input.partial_shuffle(&mut rng, opts.head_count); + let shuffled = input.partial_shuffle(&mut rng, opts.head_count); for r in shuffled { - output - .write_all(r) + r.write_all_to(&mut output) .map_err_context(|| "write failed".to_string())?; output .write_all(&[opts.sep]) @@ -361,3 +522,88 @@ impl RngCore for WrappedRng { } } } + +#[cfg(test)] +// Since the computed value is a bool, it is more readable to write the expected value out: +#[allow(clippy::bool_assert_comparison)] +mod test_number_set_decision { + use super::number_set_should_list_remaining; + + #[test] + fn test_stay_positive_large_remaining_first() { + assert_eq!(false, number_set_should_list_remaining(0, std::usize::MAX)); + } + + #[test] + fn test_stay_positive_large_remaining_second() { + assert_eq!(false, number_set_should_list_remaining(1, std::usize::MAX)); + } + + #[test] + fn test_stay_positive_large_remaining_tenth() { + assert_eq!(false, number_set_should_list_remaining(9, std::usize::MAX)); + } + + #[test] + fn test_stay_positive_smallish_range_first() { + assert_eq!(false, number_set_should_list_remaining(0, 12345)); + } + + #[test] + fn test_stay_positive_smallish_range_second() { + assert_eq!(false, number_set_should_list_remaining(1, 12345)); + } + + #[test] + fn test_stay_positive_smallish_range_tenth() { + assert_eq!(false, number_set_should_list_remaining(9, 12345)); + } + + #[test] + fn test_stay_positive_small_range_not_too_early() { + assert_eq!(false, number_set_should_list_remaining(1, 10)); + } + + // Don't want to test close to the border, in case we decide to change the threshold. + // However, at 50% coverage, we absolutely should switch: + #[test] + fn test_switch_half() { + assert_eq!(true, number_set_should_list_remaining(1234, 2468)); + } + + // Ensure that the decision is monotonous: + #[test] + fn test_switch_late1() { + assert_eq!(true, number_set_should_list_remaining(12340, 12345)); + } + + #[test] + fn test_switch_late2() { + assert_eq!(true, number_set_should_list_remaining(12344, 12345)); + } + + // Ensure that we are overflow-free: + #[test] + fn test_no_crash_exceed_max_size1() { + assert_eq!( + false, + number_set_should_list_remaining(12345, std::usize::MAX) + ); + } + + #[test] + fn test_no_crash_exceed_max_size2() { + assert_eq!( + true, + number_set_should_list_remaining(std::usize::MAX - 1, std::usize::MAX) + ); + } + + #[test] + fn test_no_crash_exceed_max_size3() { + assert_eq!( + true, + number_set_should_list_remaining(std::usize::MAX, std::usize::MAX) + ); + } +} diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index c34c71e3b..76d9b3220 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -88,6 +88,99 @@ fn test_zero_termination_multi() { assert_eq!(result_seq, input_seq, "Output is not a permutation"); } +#[test] +fn test_very_large_range() { + let num_samples = 10; + let result = new_ucmd!() + .arg("-n") + .arg(&num_samples.to_string()) + .arg("-i0-1234567890") + .succeeds(); + result.no_stderr(); + + let result_seq: Vec = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), num_samples, "Miscounted output length!"); + assert!( + result_seq.iter().all(|x| (0..=1234567890).contains(x)), + "Output includes element not from range: {}", + result.stdout_str() + ); +} + +#[test] +fn test_very_large_range_offset() { + let num_samples = 10; + let result = new_ucmd!() + .arg("-n") + .arg(&num_samples.to_string()) + .arg("-i1234567890-2147483647") + .succeeds(); + result.no_stderr(); + + let result_seq: Vec = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), num_samples, "Miscounted output length!"); + assert!( + result_seq + .iter() + .all(|x| (1234567890..=2147483647).contains(x)), + "Output includes element not from range: {}", + result.stdout_str() + ); +} + +#[test] +fn test_very_high_range_full() { + let input_seq = vec![ + 2147483641, 2147483642, 2147483643, 2147483644, 2147483645, 2147483646, 2147483647, + ]; + let result = new_ucmd!().arg("-i2147483641-2147483647").succeeds(); + result.no_stderr(); + + let mut result_seq: Vec = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + result_seq.sort_unstable(); + assert_eq!(result_seq, input_seq, "Output is not a permutation"); +} + +#[test] +fn test_range_repeat() { + let num_samples = 500; + let result = new_ucmd!() + .arg("-r") + .arg("-n") + .arg(&num_samples.to_string()) + .arg("-i12-34") + .succeeds(); + result.no_stderr(); + + let result_seq: Vec = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), num_samples, "Miscounted output length!"); + assert!( + result_seq.iter().all(|x| (12..=34).contains(x)), + "Output includes element not from range: {}", + result.stdout_str() + ); +} + #[test] fn test_empty_input() { let result = new_ucmd!().pipe_in(vec![]).succeeds();