1
Fork 0
mirror of https://github.com/RGBCube/uutils-coreutils synced 2025-07-28 11:37:44 +00:00

shred: remove usage of Vec::set_len() (#1738)

* shred: use a fixed-size array for BytesGenerator
This commit is contained in:
Alex Lyon 2021-02-23 03:34:49 -08:00 committed by GitHub
parent bb54669a5d
commit 7341a1a033
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -119,6 +119,7 @@ struct BytesGenerator<'a> {
exact: bool, // if false, every block's size is block_size exact: bool, // if false, every block's size is block_size
gen_type: PassType<'a>, gen_type: PassType<'a>,
rng: Option<RefCell<ThreadRng>>, rng: Option<RefCell<ThreadRng>>,
bytes: [u8; BLOCK_SIZE],
} }
impl<'a> BytesGenerator<'a> { impl<'a> BytesGenerator<'a> {
@ -128,6 +129,8 @@ impl<'a> BytesGenerator<'a> {
_ => None, _ => None,
}; };
let bytes = [0; BLOCK_SIZE];
BytesGenerator { BytesGenerator {
total_bytes, total_bytes,
bytes_generated: Cell::new(0u64), bytes_generated: Cell::new(0u64),
@ -135,25 +138,35 @@ impl<'a> BytesGenerator<'a> {
exact, exact,
gen_type, gen_type,
rng, rng,
bytes,
} }
} }
}
impl<'a> Iterator for BytesGenerator<'a> { pub fn reset(&mut self, total_bytes: u64, gen_type: PassType<'a>) {
type Item = Box<[u8]>; if let PassType::Random = gen_type {
if self.rng.is_none() {
self.rng = Some(RefCell::new(rand::thread_rng()));
}
}
fn next(&mut self) -> Option<Box<[u8]>> { self.total_bytes = total_bytes;
self.gen_type = gen_type;
self.bytes_generated.set(0);
}
pub fn next(&mut self) -> Option<&[u8]> {
// We go over the total_bytes limit when !self.exact and total_bytes isn't a multiple // We go over the total_bytes limit when !self.exact and total_bytes isn't a multiple
// of self.block_size // of self.block_size
if self.bytes_generated.get() >= self.total_bytes { if self.bytes_generated.get() >= self.total_bytes {
return None; return None;
} }
let this_block_size: usize = { let this_block_size = {
if !self.exact { if !self.exact {
self.block_size self.block_size
} else { } else {
let bytes_left: u64 = self.total_bytes - self.bytes_generated.get(); let bytes_left = self.total_bytes - self.bytes_generated.get();
if bytes_left >= self.block_size as u64 { if bytes_left >= self.block_size as u64 {
self.block_size self.block_size
} else { } else {
@ -162,17 +175,12 @@ impl<'a> Iterator for BytesGenerator<'a> {
} }
}; };
let mut bytes: Vec<u8> = Vec::with_capacity(this_block_size); let bytes = &mut self.bytes[..this_block_size];
match self.gen_type { match self.gen_type {
PassType::Random => { PassType::Random => {
// This is ok because the vector was
// allocated with the same capacity
unsafe {
bytes.set_len(this_block_size);
}
let mut rng = self.rng.as_ref().unwrap().borrow_mut(); let mut rng = self.rng.as_ref().unwrap().borrow_mut();
rng.fill(&mut bytes[..]); rng.fill(bytes);
} }
PassType::Pattern(pattern) => { PassType::Pattern(pattern) => {
let skip = { let skip = {
@ -182,10 +190,17 @@ impl<'a> Iterator for BytesGenerator<'a> {
(pattern.len() as u64 % self.bytes_generated.get()) as usize (pattern.len() as u64 % self.bytes_generated.get()) as usize
} }
}; };
// Same range as 0..this_block_size but we start with the right index
for i in skip..this_block_size + skip { // Copy the pattern in chunks rather than simply one byte at a time
let index = i % pattern.len(); let mut i = 0;
bytes.push(pattern[index]); while i < this_block_size {
let start = (i + skip) % pattern.len();
let end = (this_block_size - i).min(pattern.len());
let len = end - start;
bytes[i..i + len].copy_from_slice(&pattern[start..end]);
i += len;
} }
} }
}; };
@ -193,7 +208,7 @@ impl<'a> Iterator for BytesGenerator<'a> {
let new_bytes_generated = self.bytes_generated.get() + this_block_size as u64; let new_bytes_generated = self.bytes_generated.get() + this_block_size as u64;
self.bytes_generated.set(new_bytes_generated); self.bytes_generated.set(new_bytes_generated);
Some(bytes.into_boxed_slice()) Some(bytes)
} }
} }
@ -443,6 +458,10 @@ fn wipe_file(
.open(path) .open(path)
.expect("Failed to open file for writing"); .expect("Failed to open file for writing");
// NOTE: it does not really matter what we set for total_bytes and gen_type here, so just
// use bogus values
let mut generator = BytesGenerator::new(0, PassType::Pattern(&[]), exact);
for (i, pass_type) in pass_sequence.iter().enumerate() { for (i, pass_type) in pass_sequence.iter().enumerate() {
if verbose { if verbose {
let pass_name: String = pass_name(*pass_type); let pass_name: String = pass_name(*pass_type);
@ -467,7 +486,8 @@ fn wipe_file(
} }
} }
// size is an optional argument for exactly how many bytes we want to shred // size is an optional argument for exactly how many bytes we want to shred
do_pass(&mut file, path, *pass_type, size, exact).expect("File write pass failed"); do_pass(&mut file, path, &mut generator, *pass_type, size)
.expect("File write pass failed");
// Ignore failed writes; just keep trying // Ignore failed writes; just keep trying
} }
} }
@ -477,22 +497,22 @@ fn wipe_file(
} }
} }
fn do_pass( fn do_pass<'a>(
file: &mut File, file: &mut File,
path: &Path, path: &Path,
generator_type: PassType, generator: &mut BytesGenerator<'a>,
generator_type: PassType<'a>,
given_file_size: Option<u64>, given_file_size: Option<u64>,
exact: bool,
) -> Result<(), io::Error> { ) -> Result<(), io::Error> {
file.seek(SeekFrom::Start(0))?; file.seek(SeekFrom::Start(0))?;
// Use the given size or the whole file if not specified // Use the given size or the whole file if not specified
let size: u64 = given_file_size.unwrap_or(get_file_size(path)?); let size: u64 = given_file_size.unwrap_or(get_file_size(path)?);
let generator = BytesGenerator::new(size, generator_type, exact); generator.reset(size, generator_type);
for block in generator { while let Some(block) = generator.next() {
file.write_all(&*block)?; file.write_all(block)?;
} }
file.sync_data()?; file.sync_data()?;