1
Fork 0
mirror of https://github.com/RGBCube/uutils-coreutils synced 2025-07-27 19:17:43 +00:00

Merge pull request #5980 from BenWiederhake/dev-shuf-number-speed

shuf: Fix OOM crash for huge number ranges
This commit is contained in:
Sylvestre Ledru 2024-02-25 09:50:49 +01:00 committed by GitHub
commit 8301a8e5be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 368 additions and 20 deletions

View file

@ -28,11 +28,11 @@ a range of numbers to randomly sample from. An example of a command that works
well for testing: well for testing:
```shell ```shell
hyperfine --warmup 10 "target/release/shuf -i 0-10000000" hyperfine --warmup 10 "target/release/shuf -i 0-10000000 > /dev/null"
``` ```
To measure the time taken by shuffling an input file, the following command can To measure the time taken by shuffling an input file, the following command can
be used:: be used:
```shell ```shell
hyperfine --warmup 10 "target/release/shuf input.txt > /dev/null" hyperfine --warmup 10 "target/release/shuf input.txt > /dev/null"
@ -49,5 +49,14 @@ should be benchmarked separately. In this case, we have to pass the `-n` flag or
the command will run forever. An example of a hyperfine command is the command will run forever. An example of a hyperfine command is
```shell ```shell
hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000" hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000 > /dev/null"
```
## With huge interval ranges
When `shuf` runs with huge interval ranges, special care must be taken, so it
should be benchmarked separately also. An example of a hyperfine command is
```shell
hyperfine --warmup 10 "target/release/shuf -n 100 -i 1000-2000000000 > /dev/null"
``` ```

View file

@ -3,14 +3,15 @@
// For the full copyright and license information, please view the LICENSE // For the full copyright and license information, please view the LICENSE
// file that was distributed with this source code. // file that was distributed with this source code.
// spell-checker:ignore (ToDO) cmdline evec seps rvec fdata // spell-checker:ignore (ToDO) cmdline evec nonrepeating seps shufable rvec fdata
use clap::{crate_version, Arg, ArgAction, Command}; use clap::{crate_version, Arg, ArgAction, Command};
use memchr::memchr_iter; use memchr::memchr_iter;
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
use rand::RngCore; use rand::{Rng, RngCore};
use std::collections::HashSet;
use std::fs::File; use std::fs::File;
use std::io::{stdin, stdout, BufReader, BufWriter, Read, Write}; use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write};
use uucore::display::Quotable; use uucore::display::Quotable;
use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::error::{FromIo, UResult, USimpleError, UUsageError};
use uucore::{format_usage, help_about, help_usage}; use uucore::{format_usage, help_about, help_usage};
@ -116,18 +117,16 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> {
Mode::Echo(args) => { Mode::Echo(args) => {
let mut evec = args.iter().map(String::as_bytes).collect::<Vec<_>>(); let mut evec = args.iter().map(String::as_bytes).collect::<Vec<_>>();
find_seps(&mut evec, options.sep); find_seps(&mut evec, options.sep);
shuf_bytes(&mut evec, options)?; shuf_exec(&mut evec, options)?;
} }
Mode::InputRange((b, e)) => { Mode::InputRange((b, e)) => {
let rvec = (b..e).map(|x| format!("{x}")).collect::<Vec<String>>(); shuf_exec(&mut (b, e), options)?;
let mut rvec = rvec.iter().map(String::as_bytes).collect::<Vec<&[u8]>>();
shuf_bytes(&mut rvec, options)?;
} }
Mode::Default(filename) => { Mode::Default(filename) => {
let fdata = read_input_file(&filename)?; let fdata = read_input_file(&filename)?;
let mut fdata = vec![&fdata[..]]; let mut fdata = vec![&fdata[..]];
find_seps(&mut fdata, options.sep); find_seps(&mut fdata, options.sep);
shuf_bytes(&mut fdata, options)?; shuf_exec(&mut fdata, options)?;
} }
} }
@ -251,7 +250,173 @@ fn find_seps(data: &mut Vec<&[u8]>, sep: u8) {
} }
} }
fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> { trait Shufable {
type Item: Writable;
fn is_empty(&self) -> bool;
fn choose(&self, rng: &mut WrappedRng) -> Self::Item;
// This type shouldn't even be known. However, because we want to support
// Rust 1.70, it is not possible to return "impl Iterator".
// TODO: When the MSRV is raised, rewrite this to return "impl Iterator".
type PartialShuffleIterator<'b>: Iterator<Item = Self::Item>
where
Self: 'b;
fn partial_shuffle<'b>(
&'b mut self,
rng: &'b mut WrappedRng,
amount: usize,
) -> Self::PartialShuffleIterator<'b>;
}
impl<'a> Shufable for Vec<&'a [u8]> {
type Item = &'a [u8];
fn is_empty(&self) -> bool {
(**self).is_empty()
}
fn choose(&self, rng: &mut WrappedRng) -> Self::Item {
// Note: "copied()" only copies the reference, not the entire [u8].
// Returns None if the slice is empty. We checked this before, so
// this is safe.
(**self).choose(rng).unwrap()
}
type PartialShuffleIterator<'b> = std::iter::Copied<std::slice::Iter<'b, &'a [u8]>> where Self: 'b;
fn partial_shuffle<'b>(
&'b mut self,
rng: &'b mut WrappedRng,
amount: usize,
) -> Self::PartialShuffleIterator<'b> {
// Note: "copied()" only copies the reference, not the entire [u8].
(**self).partial_shuffle(rng, amount).0.iter().copied()
}
}
impl Shufable for (usize, usize) {
type Item = usize;
fn is_empty(&self) -> bool {
// Note: This is an inclusive range, so equality means there is 1 element.
self.0 > self.1
}
fn choose(&self, rng: &mut WrappedRng) -> usize {
rng.gen_range(self.0..self.1)
}
type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b;
fn partial_shuffle<'b>(
&'b mut self,
rng: &'b mut WrappedRng,
amount: usize,
) -> Self::PartialShuffleIterator<'b> {
NonrepeatingIterator::new(self.0, self.1, rng, amount)
}
}
enum NumberSet {
AlreadyListed(HashSet<usize>),
Remaining(Vec<usize>),
}
struct NonrepeatingIterator<'a> {
begin: usize,
end: usize, // exclusive
rng: &'a mut WrappedRng,
remaining_count: usize,
buf: NumberSet,
}
impl<'a> NonrepeatingIterator<'a> {
fn new(
begin: usize,
end: usize,
rng: &'a mut WrappedRng,
amount: usize,
) -> NonrepeatingIterator {
let capped_amount = if begin > end {
0
} else {
amount.min(end - begin)
};
NonrepeatingIterator {
begin,
end,
rng,
remaining_count: capped_amount,
buf: NumberSet::AlreadyListed(HashSet::default()),
}
}
fn produce(&mut self) -> usize {
debug_assert!(self.begin <= self.end);
match &mut self.buf {
NumberSet::AlreadyListed(already_listed) => {
let chosen = loop {
let guess = self.rng.gen_range(self.begin..self.end);
let newly_inserted = already_listed.insert(guess);
if newly_inserted {
break guess;
}
};
// Once a significant fraction of the interval has already been enumerated,
// the number of attempts to find a number that hasn't been chosen yet increases.
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
let range_size = self.end - self.begin;
if number_set_should_list_remaining(already_listed.len(), range_size) {
let mut remaining = (self.begin..self.end)
.filter(|n| !already_listed.contains(n))
.collect::<Vec<_>>();
assert!(remaining.len() >= self.remaining_count);
remaining.partial_shuffle(&mut self.rng, self.remaining_count);
remaining.truncate(self.remaining_count);
self.buf = NumberSet::Remaining(remaining);
}
chosen
}
NumberSet::Remaining(remaining_numbers) => {
debug_assert!(!remaining_numbers.is_empty());
// We only enter produce() when there is at least one actual element remaining, so popping must always return an element.
remaining_numbers.pop().unwrap()
}
}
}
}
impl<'a> Iterator for NonrepeatingIterator<'a> {
type Item = usize;
fn next(&mut self) -> Option<usize> {
if self.begin > self.end || self.remaining_count == 0 {
return None;
}
self.remaining_count -= 1;
Some(self.produce())
}
}
// This could be a method, but it is much easier to test as a stand-alone function.
fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool {
// Arbitrarily determine the switchover point to be around 25%. This is because:
// - HashSet has a large space overhead for the hash table load factor.
// - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same.
// - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway.
// - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet.
// - Finally, "24%" is computationally the simplest:
listed_count >= range_size / 4
}
trait Writable {
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error>;
}
impl<'a> Writable for &'a [u8] {
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
output.write_all(self)
}
}
impl Writable for usize {
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
output.write_all(format!("{self}").as_bytes())
}
}
fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> {
let mut output = BufWriter::new(match opts.output { let mut output = BufWriter::new(match opts.output {
None => Box::new(stdout()) as Box<dyn Write>, None => Box::new(stdout()) as Box<dyn Write>,
Some(s) => { Some(s) => {
@ -276,22 +441,18 @@ fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> {
if opts.repeat { if opts.repeat {
for _ in 0..opts.head_count { for _ in 0..opts.head_count {
// Returns None is the slice is empty. We checked this before, so let r = input.choose(&mut rng);
// this is safe.
let r = input.choose(&mut rng).unwrap();
output r.write_all_to(&mut output)
.write_all(r)
.map_err_context(|| "write failed".to_string())?; .map_err_context(|| "write failed".to_string())?;
output output
.write_all(&[opts.sep]) .write_all(&[opts.sep])
.map_err_context(|| "write failed".to_string())?; .map_err_context(|| "write failed".to_string())?;
} }
} else { } else {
let (shuffled, _) = input.partial_shuffle(&mut rng, opts.head_count); let shuffled = input.partial_shuffle(&mut rng, opts.head_count);
for r in shuffled { for r in shuffled {
output r.write_all_to(&mut output)
.write_all(r)
.map_err_context(|| "write failed".to_string())?; .map_err_context(|| "write failed".to_string())?;
output output
.write_all(&[opts.sep]) .write_all(&[opts.sep])
@ -361,3 +522,88 @@ impl RngCore for WrappedRng {
} }
} }
} }
#[cfg(test)]
// Since the computed value is a bool, it is more readable to write the expected value out:
#[allow(clippy::bool_assert_comparison)]
mod test_number_set_decision {
use super::number_set_should_list_remaining;
#[test]
fn test_stay_positive_large_remaining_first() {
assert_eq!(false, number_set_should_list_remaining(0, std::usize::MAX));
}
#[test]
fn test_stay_positive_large_remaining_second() {
assert_eq!(false, number_set_should_list_remaining(1, std::usize::MAX));
}
#[test]
fn test_stay_positive_large_remaining_tenth() {
assert_eq!(false, number_set_should_list_remaining(9, std::usize::MAX));
}
#[test]
fn test_stay_positive_smallish_range_first() {
assert_eq!(false, number_set_should_list_remaining(0, 12345));
}
#[test]
fn test_stay_positive_smallish_range_second() {
assert_eq!(false, number_set_should_list_remaining(1, 12345));
}
#[test]
fn test_stay_positive_smallish_range_tenth() {
assert_eq!(false, number_set_should_list_remaining(9, 12345));
}
#[test]
fn test_stay_positive_small_range_not_too_early() {
assert_eq!(false, number_set_should_list_remaining(1, 10));
}
// Don't want to test close to the border, in case we decide to change the threshold.
// However, at 50% coverage, we absolutely should switch:
#[test]
fn test_switch_half() {
assert_eq!(true, number_set_should_list_remaining(1234, 2468));
}
// Ensure that the decision is monotonous:
#[test]
fn test_switch_late1() {
assert_eq!(true, number_set_should_list_remaining(12340, 12345));
}
#[test]
fn test_switch_late2() {
assert_eq!(true, number_set_should_list_remaining(12344, 12345));
}
// Ensure that we are overflow-free:
#[test]
fn test_no_crash_exceed_max_size1() {
assert_eq!(
false,
number_set_should_list_remaining(12345, std::usize::MAX)
);
}
#[test]
fn test_no_crash_exceed_max_size2() {
assert_eq!(
true,
number_set_should_list_remaining(std::usize::MAX - 1, std::usize::MAX)
);
}
#[test]
fn test_no_crash_exceed_max_size3() {
assert_eq!(
true,
number_set_should_list_remaining(std::usize::MAX, std::usize::MAX)
);
}
}

View file

@ -88,6 +88,99 @@ fn test_zero_termination_multi() {
assert_eq!(result_seq, input_seq, "Output is not a permutation"); assert_eq!(result_seq, input_seq, "Output is not a permutation");
} }
#[test]
fn test_very_large_range() {
let num_samples = 10;
let result = new_ucmd!()
.arg("-n")
.arg(&num_samples.to_string())
.arg("-i0-1234567890")
.succeeds();
result.no_stderr();
let result_seq: Vec<isize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
assert!(
result_seq.iter().all(|x| (0..=1234567890).contains(x)),
"Output includes element not from range: {}",
result.stdout_str()
);
}
#[test]
fn test_very_large_range_offset() {
let num_samples = 10;
let result = new_ucmd!()
.arg("-n")
.arg(&num_samples.to_string())
.arg("-i1234567890-2147483647")
.succeeds();
result.no_stderr();
let result_seq: Vec<isize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
assert!(
result_seq
.iter()
.all(|x| (1234567890..=2147483647).contains(x)),
"Output includes element not from range: {}",
result.stdout_str()
);
}
#[test]
fn test_very_high_range_full() {
let input_seq = vec![
2147483641, 2147483642, 2147483643, 2147483644, 2147483645, 2147483646, 2147483647,
];
let result = new_ucmd!().arg("-i2147483641-2147483647").succeeds();
result.no_stderr();
let mut result_seq: Vec<isize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
result_seq.sort_unstable();
assert_eq!(result_seq, input_seq, "Output is not a permutation");
}
#[test]
fn test_range_repeat() {
let num_samples = 500;
let result = new_ucmd!()
.arg("-r")
.arg("-n")
.arg(&num_samples.to_string())
.arg("-i12-34")
.succeeds();
result.no_stderr();
let result_seq: Vec<isize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
assert!(
result_seq.iter().all(|x| (12..=34).contains(x)),
"Output includes element not from range: {}",
result.stdout_str()
);
}
#[test] #[test]
fn test_empty_input() { fn test_empty_input() {
let result = new_ucmd!().pipe_in(vec![]).succeeds(); let result = new_ucmd!().pipe_in(vec![]).succeeds();