mirror of
https://github.com/RGBCube/uutils-coreutils
synced 2025-07-27 11:07:44 +00:00
Merge pull request #5980 from BenWiederhake/dev-shuf-number-speed
shuf: Fix OOM crash for huge number ranges
This commit is contained in:
commit
8301a8e5be
3 changed files with 368 additions and 20 deletions
|
@ -28,11 +28,11 @@ a range of numbers to randomly sample from. An example of a command that works
|
|||
well for testing:
|
||||
|
||||
```shell
|
||||
hyperfine --warmup 10 "target/release/shuf -i 0-10000000"
|
||||
hyperfine --warmup 10 "target/release/shuf -i 0-10000000 > /dev/null"
|
||||
```
|
||||
|
||||
To measure the time taken by shuffling an input file, the following command can
|
||||
be used::
|
||||
be used:
|
||||
|
||||
```shell
|
||||
hyperfine --warmup 10 "target/release/shuf input.txt > /dev/null"
|
||||
|
@ -49,5 +49,14 @@ should be benchmarked separately. In this case, we have to pass the `-n` flag or
|
|||
the command will run forever. An example of a hyperfine command is
|
||||
|
||||
```shell
|
||||
hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000"
|
||||
hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000 > /dev/null"
|
||||
```
|
||||
|
||||
## With huge interval ranges
|
||||
|
||||
When `shuf` runs with huge interval ranges, special care must be taken, so it
|
||||
should be benchmarked separately also. An example of a hyperfine command is
|
||||
|
||||
```shell
|
||||
hyperfine --warmup 10 "target/release/shuf -n 100 -i 1000-2000000000 > /dev/null"
|
||||
```
|
||||
|
|
|
@ -3,14 +3,15 @@
|
|||
// For the full copyright and license information, please view the LICENSE
|
||||
// file that was distributed with this source code.
|
||||
|
||||
// spell-checker:ignore (ToDO) cmdline evec seps rvec fdata
|
||||
// spell-checker:ignore (ToDO) cmdline evec nonrepeating seps shufable rvec fdata
|
||||
|
||||
use clap::{crate_version, Arg, ArgAction, Command};
|
||||
use memchr::memchr_iter;
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::RngCore;
|
||||
use rand::{Rng, RngCore};
|
||||
use std::collections::HashSet;
|
||||
use std::fs::File;
|
||||
use std::io::{stdin, stdout, BufReader, BufWriter, Read, Write};
|
||||
use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write};
|
||||
use uucore::display::Quotable;
|
||||
use uucore::error::{FromIo, UResult, USimpleError, UUsageError};
|
||||
use uucore::{format_usage, help_about, help_usage};
|
||||
|
@ -116,18 +117,16 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> {
|
|||
Mode::Echo(args) => {
|
||||
let mut evec = args.iter().map(String::as_bytes).collect::<Vec<_>>();
|
||||
find_seps(&mut evec, options.sep);
|
||||
shuf_bytes(&mut evec, options)?;
|
||||
shuf_exec(&mut evec, options)?;
|
||||
}
|
||||
Mode::InputRange((b, e)) => {
|
||||
let rvec = (b..e).map(|x| format!("{x}")).collect::<Vec<String>>();
|
||||
let mut rvec = rvec.iter().map(String::as_bytes).collect::<Vec<&[u8]>>();
|
||||
shuf_bytes(&mut rvec, options)?;
|
||||
shuf_exec(&mut (b, e), options)?;
|
||||
}
|
||||
Mode::Default(filename) => {
|
||||
let fdata = read_input_file(&filename)?;
|
||||
let mut fdata = vec![&fdata[..]];
|
||||
find_seps(&mut fdata, options.sep);
|
||||
shuf_bytes(&mut fdata, options)?;
|
||||
shuf_exec(&mut fdata, options)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -251,7 +250,173 @@ fn find_seps(data: &mut Vec<&[u8]>, sep: u8) {
|
|||
}
|
||||
}
|
||||
|
||||
fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> {
|
||||
trait Shufable {
|
||||
type Item: Writable;
|
||||
fn is_empty(&self) -> bool;
|
||||
fn choose(&self, rng: &mut WrappedRng) -> Self::Item;
|
||||
// This type shouldn't even be known. However, because we want to support
|
||||
// Rust 1.70, it is not possible to return "impl Iterator".
|
||||
// TODO: When the MSRV is raised, rewrite this to return "impl Iterator".
|
||||
type PartialShuffleIterator<'b>: Iterator<Item = Self::Item>
|
||||
where
|
||||
Self: 'b;
|
||||
fn partial_shuffle<'b>(
|
||||
&'b mut self,
|
||||
rng: &'b mut WrappedRng,
|
||||
amount: usize,
|
||||
) -> Self::PartialShuffleIterator<'b>;
|
||||
}
|
||||
|
||||
impl<'a> Shufable for Vec<&'a [u8]> {
|
||||
type Item = &'a [u8];
|
||||
fn is_empty(&self) -> bool {
|
||||
(**self).is_empty()
|
||||
}
|
||||
fn choose(&self, rng: &mut WrappedRng) -> Self::Item {
|
||||
// Note: "copied()" only copies the reference, not the entire [u8].
|
||||
// Returns None if the slice is empty. We checked this before, so
|
||||
// this is safe.
|
||||
(**self).choose(rng).unwrap()
|
||||
}
|
||||
type PartialShuffleIterator<'b> = std::iter::Copied<std::slice::Iter<'b, &'a [u8]>> where Self: 'b;
|
||||
fn partial_shuffle<'b>(
|
||||
&'b mut self,
|
||||
rng: &'b mut WrappedRng,
|
||||
amount: usize,
|
||||
) -> Self::PartialShuffleIterator<'b> {
|
||||
// Note: "copied()" only copies the reference, not the entire [u8].
|
||||
(**self).partial_shuffle(rng, amount).0.iter().copied()
|
||||
}
|
||||
}
|
||||
|
||||
impl Shufable for (usize, usize) {
|
||||
type Item = usize;
|
||||
fn is_empty(&self) -> bool {
|
||||
// Note: This is an inclusive range, so equality means there is 1 element.
|
||||
self.0 > self.1
|
||||
}
|
||||
fn choose(&self, rng: &mut WrappedRng) -> usize {
|
||||
rng.gen_range(self.0..self.1)
|
||||
}
|
||||
type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b;
|
||||
fn partial_shuffle<'b>(
|
||||
&'b mut self,
|
||||
rng: &'b mut WrappedRng,
|
||||
amount: usize,
|
||||
) -> Self::PartialShuffleIterator<'b> {
|
||||
NonrepeatingIterator::new(self.0, self.1, rng, amount)
|
||||
}
|
||||
}
|
||||
|
||||
enum NumberSet {
|
||||
AlreadyListed(HashSet<usize>),
|
||||
Remaining(Vec<usize>),
|
||||
}
|
||||
|
||||
struct NonrepeatingIterator<'a> {
|
||||
begin: usize,
|
||||
end: usize, // exclusive
|
||||
rng: &'a mut WrappedRng,
|
||||
remaining_count: usize,
|
||||
buf: NumberSet,
|
||||
}
|
||||
|
||||
impl<'a> NonrepeatingIterator<'a> {
|
||||
fn new(
|
||||
begin: usize,
|
||||
end: usize,
|
||||
rng: &'a mut WrappedRng,
|
||||
amount: usize,
|
||||
) -> NonrepeatingIterator {
|
||||
let capped_amount = if begin > end {
|
||||
0
|
||||
} else {
|
||||
amount.min(end - begin)
|
||||
};
|
||||
NonrepeatingIterator {
|
||||
begin,
|
||||
end,
|
||||
rng,
|
||||
remaining_count: capped_amount,
|
||||
buf: NumberSet::AlreadyListed(HashSet::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn produce(&mut self) -> usize {
|
||||
debug_assert!(self.begin <= self.end);
|
||||
match &mut self.buf {
|
||||
NumberSet::AlreadyListed(already_listed) => {
|
||||
let chosen = loop {
|
||||
let guess = self.rng.gen_range(self.begin..self.end);
|
||||
let newly_inserted = already_listed.insert(guess);
|
||||
if newly_inserted {
|
||||
break guess;
|
||||
}
|
||||
};
|
||||
// Once a significant fraction of the interval has already been enumerated,
|
||||
// the number of attempts to find a number that hasn't been chosen yet increases.
|
||||
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
|
||||
let range_size = self.end - self.begin;
|
||||
if number_set_should_list_remaining(already_listed.len(), range_size) {
|
||||
let mut remaining = (self.begin..self.end)
|
||||
.filter(|n| !already_listed.contains(n))
|
||||
.collect::<Vec<_>>();
|
||||
assert!(remaining.len() >= self.remaining_count);
|
||||
remaining.partial_shuffle(&mut self.rng, self.remaining_count);
|
||||
remaining.truncate(self.remaining_count);
|
||||
self.buf = NumberSet::Remaining(remaining);
|
||||
}
|
||||
chosen
|
||||
}
|
||||
NumberSet::Remaining(remaining_numbers) => {
|
||||
debug_assert!(!remaining_numbers.is_empty());
|
||||
// We only enter produce() when there is at least one actual element remaining, so popping must always return an element.
|
||||
remaining_numbers.pop().unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for NonrepeatingIterator<'a> {
|
||||
type Item = usize;
|
||||
|
||||
fn next(&mut self) -> Option<usize> {
|
||||
if self.begin > self.end || self.remaining_count == 0 {
|
||||
return None;
|
||||
}
|
||||
self.remaining_count -= 1;
|
||||
Some(self.produce())
|
||||
}
|
||||
}
|
||||
|
||||
// This could be a method, but it is much easier to test as a stand-alone function.
|
||||
fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool {
|
||||
// Arbitrarily determine the switchover point to be around 25%. This is because:
|
||||
// - HashSet has a large space overhead for the hash table load factor.
|
||||
// - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same.
|
||||
// - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway.
|
||||
// - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet.
|
||||
// - Finally, "24%" is computationally the simplest:
|
||||
listed_count >= range_size / 4
|
||||
}
|
||||
|
||||
trait Writable {
|
||||
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
impl<'a> Writable for &'a [u8] {
|
||||
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
|
||||
output.write_all(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Writable for usize {
|
||||
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
|
||||
output.write_all(format!("{self}").as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> {
|
||||
let mut output = BufWriter::new(match opts.output {
|
||||
None => Box::new(stdout()) as Box<dyn Write>,
|
||||
Some(s) => {
|
||||
|
@ -276,22 +441,18 @@ fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> {
|
|||
|
||||
if opts.repeat {
|
||||
for _ in 0..opts.head_count {
|
||||
// Returns None is the slice is empty. We checked this before, so
|
||||
// this is safe.
|
||||
let r = input.choose(&mut rng).unwrap();
|
||||
let r = input.choose(&mut rng);
|
||||
|
||||
output
|
||||
.write_all(r)
|
||||
r.write_all_to(&mut output)
|
||||
.map_err_context(|| "write failed".to_string())?;
|
||||
output
|
||||
.write_all(&[opts.sep])
|
||||
.map_err_context(|| "write failed".to_string())?;
|
||||
}
|
||||
} else {
|
||||
let (shuffled, _) = input.partial_shuffle(&mut rng, opts.head_count);
|
||||
let shuffled = input.partial_shuffle(&mut rng, opts.head_count);
|
||||
for r in shuffled {
|
||||
output
|
||||
.write_all(r)
|
||||
r.write_all_to(&mut output)
|
||||
.map_err_context(|| "write failed".to_string())?;
|
||||
output
|
||||
.write_all(&[opts.sep])
|
||||
|
@ -361,3 +522,88 @@ impl RngCore for WrappedRng {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// Since the computed value is a bool, it is more readable to write the expected value out:
|
||||
#[allow(clippy::bool_assert_comparison)]
|
||||
mod test_number_set_decision {
|
||||
use super::number_set_should_list_remaining;
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_large_remaining_first() {
|
||||
assert_eq!(false, number_set_should_list_remaining(0, std::usize::MAX));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_large_remaining_second() {
|
||||
assert_eq!(false, number_set_should_list_remaining(1, std::usize::MAX));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_large_remaining_tenth() {
|
||||
assert_eq!(false, number_set_should_list_remaining(9, std::usize::MAX));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_smallish_range_first() {
|
||||
assert_eq!(false, number_set_should_list_remaining(0, 12345));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_smallish_range_second() {
|
||||
assert_eq!(false, number_set_should_list_remaining(1, 12345));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_smallish_range_tenth() {
|
||||
assert_eq!(false, number_set_should_list_remaining(9, 12345));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stay_positive_small_range_not_too_early() {
|
||||
assert_eq!(false, number_set_should_list_remaining(1, 10));
|
||||
}
|
||||
|
||||
// Don't want to test close to the border, in case we decide to change the threshold.
|
||||
// However, at 50% coverage, we absolutely should switch:
|
||||
#[test]
|
||||
fn test_switch_half() {
|
||||
assert_eq!(true, number_set_should_list_remaining(1234, 2468));
|
||||
}
|
||||
|
||||
// Ensure that the decision is monotonous:
|
||||
#[test]
|
||||
fn test_switch_late1() {
|
||||
assert_eq!(true, number_set_should_list_remaining(12340, 12345));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_switch_late2() {
|
||||
assert_eq!(true, number_set_should_list_remaining(12344, 12345));
|
||||
}
|
||||
|
||||
// Ensure that we are overflow-free:
|
||||
#[test]
|
||||
fn test_no_crash_exceed_max_size1() {
|
||||
assert_eq!(
|
||||
false,
|
||||
number_set_should_list_remaining(12345, std::usize::MAX)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_crash_exceed_max_size2() {
|
||||
assert_eq!(
|
||||
true,
|
||||
number_set_should_list_remaining(std::usize::MAX - 1, std::usize::MAX)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_crash_exceed_max_size3() {
|
||||
assert_eq!(
|
||||
true,
|
||||
number_set_should_list_remaining(std::usize::MAX, std::usize::MAX)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,6 +88,99 @@ fn test_zero_termination_multi() {
|
|||
assert_eq!(result_seq, input_seq, "Output is not a permutation");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_large_range() {
|
||||
let num_samples = 10;
|
||||
let result = new_ucmd!()
|
||||
.arg("-n")
|
||||
.arg(&num_samples.to_string())
|
||||
.arg("-i0-1234567890")
|
||||
.succeeds();
|
||||
result.no_stderr();
|
||||
|
||||
let result_seq: Vec<isize> = result
|
||||
.stdout_str()
|
||||
.split('\n')
|
||||
.filter(|x| !x.is_empty())
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
|
||||
assert!(
|
||||
result_seq.iter().all(|x| (0..=1234567890).contains(x)),
|
||||
"Output includes element not from range: {}",
|
||||
result.stdout_str()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_large_range_offset() {
|
||||
let num_samples = 10;
|
||||
let result = new_ucmd!()
|
||||
.arg("-n")
|
||||
.arg(&num_samples.to_string())
|
||||
.arg("-i1234567890-2147483647")
|
||||
.succeeds();
|
||||
result.no_stderr();
|
||||
|
||||
let result_seq: Vec<isize> = result
|
||||
.stdout_str()
|
||||
.split('\n')
|
||||
.filter(|x| !x.is_empty())
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
|
||||
assert!(
|
||||
result_seq
|
||||
.iter()
|
||||
.all(|x| (1234567890..=2147483647).contains(x)),
|
||||
"Output includes element not from range: {}",
|
||||
result.stdout_str()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_high_range_full() {
|
||||
let input_seq = vec![
|
||||
2147483641, 2147483642, 2147483643, 2147483644, 2147483645, 2147483646, 2147483647,
|
||||
];
|
||||
let result = new_ucmd!().arg("-i2147483641-2147483647").succeeds();
|
||||
result.no_stderr();
|
||||
|
||||
let mut result_seq: Vec<isize> = result
|
||||
.stdout_str()
|
||||
.split('\n')
|
||||
.filter(|x| !x.is_empty())
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
result_seq.sort_unstable();
|
||||
assert_eq!(result_seq, input_seq, "Output is not a permutation");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range_repeat() {
|
||||
let num_samples = 500;
|
||||
let result = new_ucmd!()
|
||||
.arg("-r")
|
||||
.arg("-n")
|
||||
.arg(&num_samples.to_string())
|
||||
.arg("-i12-34")
|
||||
.succeeds();
|
||||
result.no_stderr();
|
||||
|
||||
let result_seq: Vec<isize> = result
|
||||
.stdout_str()
|
||||
.split('\n')
|
||||
.filter(|x| !x.is_empty())
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
assert_eq!(result_seq.len(), num_samples, "Miscounted output length!");
|
||||
assert!(
|
||||
result_seq.iter().all(|x| (12..=34).contains(x)),
|
||||
"Output includes element not from range: {}",
|
||||
result.stdout_str()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_input() {
|
||||
let result = new_ucmd!().pipe_in(vec![]).succeeds();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue