diff --git a/src/shuf/deps.mk b/src/shuf/deps.mk new file mode 100644 index 000000000..ea0f8a6b8 --- /dev/null +++ b/src/shuf/deps.mk @@ -0,0 +1 @@ +DEPLIBS += rand diff --git a/src/shuf/shuf.rs b/src/shuf/shuf.rs index c2265f97a..2af249287 100644 --- a/src/shuf/shuf.rs +++ b/src/shuf/shuf.rs @@ -1,5 +1,5 @@ #![crate_name = "shuf"] -#![feature(collections, core, old_io, old_path, rand, rustc_private)] +#![feature(rustc_private)] /* * This file is part of the uutils coreutils package. @@ -12,13 +12,13 @@ extern crate getopts; extern crate libc; +extern crate rand; -use std::cmp; -use std::old_io as io; -use std::old_io::IoResult; -use std::iter::{range_inclusive, RangeInclusive}; -use std::rand::{self, Rng}; -use std::usize; +use rand::read::ReadRng; +use rand::{Rng, ThreadRng}; +use std::fs::File; +use std::io::{stdin, stdout, BufReader, BufWriter, Read, Write}; +use std::usize::MAX as MAX_USIZE; #[path = "../common/util.rs"] #[macro_use] @@ -27,15 +27,13 @@ mod util; enum Mode { Default, Echo, - InputRange(RangeInclusive) + InputRange((usize, usize)) } static NAME: &'static str = "shuf"; static VERSION: &'static str = "0.0.1"; pub fn uumain(args: Vec) -> i32 { - let program = args[0].clone(); - let opts = [ getopts::optflag("e", "echo", "treat each ARG as an input line"), getopts::optopt("i", "input-range", "treat each number LO through HI as an input line", "LO-HI"), @@ -47,7 +45,7 @@ pub fn uumain(args: Vec) -> i32 { getopts::optflag("h", "help", "display this help and exit"), getopts::optflag("V", "version", "output version information and exit") ]; - let mut matches = match getopts::getopts(args.tail(), &opts) { + let mut matches = match getopts::getopts(&args[1..], &opts) { Ok(m) => m, Err(f) => { crash!(1, "{}", f) @@ -62,7 +60,7 @@ Usage: {prog} -i LO-HI [OPTION]...\n {usage} With no FILE, or when FILE is -, read standard input.", - name = NAME, version = VERSION, prog = program, + name = NAME, version = VERSION, prog = &args[0][..], usage = getopts::usage("Write a random permutation of the input lines to standard output.", &opts)); } else if matches.opt_present("version") { println!("{} v{}", NAME, VERSION); @@ -76,10 +74,9 @@ With no FILE, or when FILE is -, read standard input.", } match parse_range(range) { Ok(m) => Mode::InputRange(m), - Err((msg, code)) => { - show_error!("{}", msg); - return code; - } + Err(msg) => { + crash!(1, "{}", msg); + }, } } None => { @@ -88,13 +85,19 @@ With no FILE, or when FILE is -, read standard input.", } else { if matches.free.len() == 0 { matches.free.push("-".to_string()); + } else if matches.free.len() > 1 { + show_error!("extra operand '{}'", &matches.free[1][..]); } Mode::Default } } }; let repeat = matches.opt_present("repeat"); - let zero = matches.opt_present("zero-terminated"); + let sep = if matches.opt_present("zero-terminated") { + 0x00 as u8 + } else { + 0x0a as u8 + }; let count = match matches.opt_str("head-count") { Some(cnt) => match cnt.parse::() { Ok(val) => val, @@ -103,102 +106,166 @@ With no FILE, or when FILE is -, read standard input.", return 1; } }, - None => usize::MAX + None => MAX_USIZE, }; let output = matches.opt_str("output"); let random = matches.opt_str("random-source"); - match shuf(matches.free, mode, repeat, zero, count, output, random) { - Err(f) => { - show_error!("{}", f); - return 1; + + match mode { + Mode::Echo => { + // XXX: this doesn't correctly handle non-UTF-8 cmdline args + let mut evec = matches.free.iter().map(|a| a.as_bytes()).collect::>(); + find_seps(&mut evec, sep); + shuf_bytes(&mut evec, repeat, count, sep, output, random); }, - _ => {} + Mode::InputRange((b, e)) => { + let rvec = (b..e).map(|x| format!("{}", x)).collect::>(); + let mut rvec = rvec.iter().map(|a| a.as_bytes()).collect::>(); + shuf_bytes(&mut rvec, repeat, count, sep, output, random); + }, + Mode::Default => { + let fdata = read_input_file(&matches.free[0][..]); + let mut fdata = vec!(&fdata[..]); + find_seps(&mut fdata, sep); + shuf_bytes(&mut fdata, repeat, count, sep, output, random); + } } } 0 } -fn shuf(input: Vec, mode: Mode, repeat: bool, zero: bool, count: usize, output: Option, random: Option) -> IoResult<()> { - match mode { - Mode::Echo => shuf_lines(input, repeat, zero, count, output, random), - Mode::InputRange(range) => shuf_lines(range.map(|num| num.to_string()).collect(), repeat, zero, count, output, random), - Mode::Default => { - let lines: Vec = input.into_iter().flat_map(|filename| { - let slice = filename.as_slice(); - let mut file_buf; - let mut stdin_buf; - let mut file = io::BufferedReader::new( - if slice == "-" { - stdin_buf = io::stdio::stdin_raw(); - &mut stdin_buf as &mut Reader - } else { - file_buf = crash_if_err!(1, io::File::open(&Path::new(slice))); - &mut file_buf as &mut Reader - } - ); - let mut lines = vec!(); - for line in file.lines() { - let mut line = crash_if_err!(1, line); - line.pop(); - lines.push(line); +fn read_input_file(filename: &str) -> Vec { + let mut file = BufReader::new( + if filename == "-" { + Box::new(stdin()) as Box + } else { + match File::open(filename) { + Ok(f) => Box::new(f) as Box, + Err(e) => crash!(1, "failed to open '{}': {}", filename, e), + } + }); + + let mut data = Vec::new(); + match file.read_to_end(&mut data) { + Err(e) => crash!(1, "failed reading '{}': {}", filename, e), + Ok(_) => (), + }; + + data +} + +fn find_seps(data: &mut Vec<&[u8]>, sep: u8) { + // need to use for loop so we don't borrow the vector as we modify it in place + // basic idea: + // * We don't care about the order of the result. This lets us slice the slices + // without making a new vector. + // * Starting from the end of the vector, we examine each element. + // * If that element contains the separator, we remove it from the vector, + // and then sub-slice it into slices that do not contain the separator. + // * We maintain the invariant throughout that each element in the vector past + // the ith element does not have any separators remaining. + for i in (0..data.len()).rev() { + if data[i].contains(&sep) { + let this = data.swap_remove(i); + let mut p = 0; + let mut i = 1; + loop { + if i == this.len() { + break; } - lines.into_iter() - }).collect(); - shuf_lines(lines, repeat, zero, count, output, random) + + if this[i] == sep { + data.push(&this[p..i]); + p = i + 1; + } + i += 1; + } + if p < this.len() { + data.push(&this[p..i]); + } } } } +fn shuf_bytes(input: &mut Vec<&[u8]>, repeat: bool, count: usize, sep: u8, output: Option, random: Option) { + let mut output = BufWriter::new( + match output { + None => Box::new(stdout()) as Box, + Some(s) => match File::create(&s[..]) { + Ok(f) => Box::new(f) as Box, + Err(e) => crash!(1, "failed to open '{}' for writing: {}", &s[..], e), + }, + }); + + let mut rng = match random { + Some(r) => WrappedRng::RngFile(rand::read::ReadRng::new(match File::open(&r[..]) { + Ok(f) => f, + Err(e) => crash!(1, "failed to open random source '{}': {}", &r[..], e), + })), + None => WrappedRng::RngDefault(rand::thread_rng()), + }; + + // we're generating a random usize. To keep things fair, we take this number mod ceil(log2(length+1)) + let mut len_mod = 1; + let mut len = input.len(); + while len > 0 { + len >>= 1; + len_mod <<= 1; + } + drop(len); + + let mut count = count; + while count > 0 && input.len() > 0 { + let mut r = input.len(); + while r >= input.len() { + r = rng.next_usize() % len_mod; + } + + // write the randomly chosen value and the separator + output.write_all(input[r]).unwrap_or_else(|e| crash!(1, "write failed: {}", e)); + output.write_all(&[sep]).unwrap_or_else(|e| crash!(1, "write failed: {}", e)); + + // if we do not allow repeats, remove the chosen value from the input vector + if !repeat { + // shrink the mask if we will drop below a power of 2 + if input.len() % 2 == 0 && len_mod > 2 { + len_mod >>= 1; + } + input.swap_remove(r); + } + + count -= 1; + } +} + +fn parse_range(input_range: String) -> Result<(usize, usize), String> { + let split: Vec<&str> = input_range.split('-').collect(); + if split.len() != 2 { + Err("invalid range format".to_string()) + } else { + let begin = match split[0].parse::() { + Ok(m) => m, + Err(e)=> return Err(format!("{} is not a valid number: {}", split[0], e)), + }; + let end = match split[1].parse::() { + Ok(m) => m, + Err(e)=> return Err(format!("{} is not a valid number: {}", split[1], e)), + }; + Ok((begin, end + 1)) + } +} + enum WrappedRng { - RngFile(rand::reader::ReaderRng), + RngFile(rand::read::ReadRng), RngDefault(rand::ThreadRng), } impl WrappedRng { - fn next_u32(&mut self) -> u32 { + fn next_usize(&mut self) -> usize { match self { - &mut WrappedRng::RngFile(ref mut r) => r.next_u32(), - &mut WrappedRng::RngDefault(ref mut r) => r.next_u32(), + &mut WrappedRng::RngFile(ref mut r) => r.next_u32() as usize, + &mut WrappedRng::RngDefault(ref mut r) => r.next_u32() as usize, } } } - -fn shuf_lines(mut lines: Vec, repeat: bool, zero: bool, count: usize, outname: Option, random: Option) -> IoResult<()> { - let mut output = match outname { - Some(name) => Box::new(io::BufferedWriter::new(try!(io::File::create(&Path::new(name))))) as Box, - None => Box::new(io::stdout()) as Box - }; - let mut rng = match random { - Some(name) => WrappedRng::RngFile(rand::reader::ReaderRng::new(try!(io::File::open(&Path::new(name))))), - None => WrappedRng::RngDefault(rand::thread_rng()), - }; - let mut len = lines.len(); - let max = if repeat { count } else { cmp::min(count, len) }; - for _ in range(0, max) { - let idx = rng.next_u32() as usize % len; - try!(write!(output, "{}{}", lines[idx], if zero { '\0' } else { '\n' })); - if !repeat { - lines.remove(idx); - len -= 1; - } - } - Ok(()) -} - -fn parse_range(input_range: String) -> Result, (String, i32)> { - let split: Vec<&str> = input_range.as_slice().split('-').collect(); - if split.len() != 2 { - Err(("invalid range format".to_string(), 1)) - } else { - let begin = match split[0].parse::() { - Ok(m) => m, - Err(e)=> return Err((format!("{} is not a valid number: {}", split[0], e), 1)) - }; - let end = match split[1].parse::() { - Ok(m) => m, - Err(e)=> return Err((format!("{} is not a valid number: {}", split[1], e), 1)) - }; - Ok(range_inclusive(begin, end)) - } -}