From e6f59b12f75bc1b44b86be8cbf70d1c94cc61a6e Mon Sep 17 00:00:00 2001 From: Justin Tracey Date: Sun, 6 Feb 2022 01:59:10 -0500 Subject: [PATCH] join: lock and buffer stdout By abstracting the writer we write to, we can lock stdout once at the beginning, then use buffered writes to it throughout. --- src/uu/join/src/join.rs | 145 ++++++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 41 deletions(-) diff --git a/src/uu/join/src/join.rs b/src/uu/join/src/join.rs index b8c04925d..cb953c133 100644 --- a/src/uu/join/src/join.rs +++ b/src/uu/join/src/join.rs @@ -16,7 +16,7 @@ use std::convert::From; use std::error::Error; use std::fmt::Display; use std::fs::File; -use std::io::{stdin, stdout, BufRead, BufReader, Split, Stdin, Write}; +use std::io::{stdin, stdout, BufRead, BufReader, BufWriter, Split, Stdin, Write}; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; use uucore::display::Quotable; @@ -144,34 +144,43 @@ impl<'a> Repr<'a> { } /// Print the field or empty filler if the field is not set. - fn print_field(&self, field: Option<&Vec<u8>>) -> Result<(), std::io::Error> { + fn print_field( + &self, + writer: &mut impl Write, + field: Option<&Vec<u8>>, + ) -> Result<(), std::io::Error> { let value = match field { Some(field) => field, None => self.empty, }; - stdout().write_all(value) + writer.write_all(value) } /// Print each field except the one at the index. - fn print_fields(&self, line: &Line, index: usize) -> Result<(), std::io::Error> { + fn print_fields( + &self, + writer: &mut impl Write, + line: &Line, + index: usize, + ) -> Result<(), std::io::Error> { for i in 0..line.fields.len() { if i != index { - stdout().write_all(&[self.separator])?; - stdout().write_all(&line.fields[i])?; + writer.write_all(&[self.separator])?; + writer.write_all(&line.fields[i])?; } } Ok(()) } /// Print each field or the empty filler if the field is not set. 
- fn print_format<F>(&self, f: F) -> Result<(), std::io::Error> + fn print_format<F>(&self, writer: &mut impl Write, f: F) -> Result<(), std::io::Error> where F: Fn(&Spec) -> Option<&'a Vec<u8>>, { for i in 0..self.format.len() { if i > 0 { - stdout().write_all(&[self.separator])?; + writer.write_all(&[self.separator])?; } let field = match f(&self.format[i]) { @@ -179,13 +188,13 @@ impl<'a> Repr<'a> { None => self.empty, }; - stdout().write_all(field)?; + writer.write_all(field)?; } Ok(()) } - fn print_line_ending(&self) -> Result<(), std::io::Error> { - stdout().write_all(&[self.line_ending as u8]) + fn print_line_ending(&self, writer: &mut impl Write) -> Result<(), std::io::Error> { + writer.write_all(&[self.line_ending as u8]) } } @@ -342,9 +351,14 @@ impl<'a> State<'a> { } /// Skip the current unpaired line. - fn skip_line(&mut self, input: &Input, repr: &Repr) -> Result<(), JoinError> { + fn skip_line( + &mut self, + writer: &mut impl Write, + input: &Input, + repr: &Repr, + ) -> Result<(), JoinError> { if self.print_unpaired { - self.print_first_line(repr)?; + self.print_first_line(writer, repr)?; } self.reset_next_line(input)?; @@ -368,28 +382,38 @@ impl<'a> State<'a> { } /// Print lines in the buffers as headers. - fn print_headers(&self, other: &State, repr: &Repr) -> Result<(), std::io::Error> { + fn print_headers( + &self, + writer: &mut impl Write, + other: &State, + repr: &Repr, + ) -> Result<(), std::io::Error> { if self.has_line() { if other.has_line() { - self.combine(other, repr)?; + self.combine(writer, other, repr)?; } else { - self.print_first_line(repr)?; + self.print_first_line(writer, repr)?; } } else if other.has_line() { - other.print_first_line(repr)?; + other.print_first_line(writer, repr)?; } Ok(()) } /// Combine two line sequences. 
- fn combine(&self, other: &State, repr: &Repr) -> Result<(), std::io::Error> { + fn combine( + &self, + writer: &mut impl Write, + other: &State, + repr: &Repr, + ) -> Result<(), std::io::Error> { let key = self.get_current_key(); for line1 in &self.seq { for line2 in &other.seq { if repr.uses_format() { - repr.print_format(|spec| match *spec { + repr.print_format(writer, |spec| match *spec { Spec::Key => key, Spec::Field(file_num, field_num) => { if file_num == self.file_num { @@ -404,12 +428,12 @@ impl<'a> State<'a> { } })?; } else { - repr.print_field(key)?; - repr.print_fields(line1, self.key)?; - repr.print_fields(line2, other.key)?; + repr.print_field(writer, key)?; + repr.print_fields(writer, line1, self.key)?; + repr.print_fields(writer, line2, other.key)?; } - repr.print_line_ending()?; + repr.print_line_ending(writer)?; } } @@ -452,16 +476,21 @@ impl<'a> State<'a> { 0 } - fn finalize(&mut self, input: &Input, repr: &Repr) -> Result<(), JoinError> { + fn finalize( + &mut self, + writer: &mut impl Write, + input: &Input, + repr: &Repr, + ) -> Result<(), JoinError> { if self.has_line() { if self.print_unpaired { - self.print_first_line(repr)?; + self.print_first_line(writer, repr)?; } let mut next_line = self.next_line(input)?; while let Some(line) = &next_line { if self.print_unpaired { - self.print_line(line, repr)?; + self.print_line(writer, line, repr)?; } self.reset(next_line); next_line = self.next_line(input)?; @@ -522,9 +551,14 @@ impl<'a> State<'a> { self.seq[0].get_field(self.key) } - fn print_line(&self, line: &Line, repr: &Repr) -> Result<(), std::io::Error> { + fn print_line( + &self, + writer: &mut impl Write, + line: &Line, + repr: &Repr, + ) -> Result<(), std::io::Error> { if repr.uses_format() { - repr.print_format(|spec| match *spec { + repr.print_format(writer, |spec| match *spec { Spec::Key => line.get_field(self.key), Spec::Field(file_num, field_num) => { if file_num == self.file_num { @@ -535,15 +569,15 @@ impl<'a> State<'a> { } })?; } 
else { - repr.print_field(line.get_field(self.key))?; - repr.print_fields(line, self.key)?; + repr.print_field(writer, line.get_field(self.key))?; + repr.print_fields(writer, line, self.key)?; } - repr.print_line_ending() + repr.print_line_ending(writer) } - fn print_first_line(&self, repr: &Repr) -> Result<(), std::io::Error> { - self.print_line(&self.seq[0], repr) + fn print_first_line(&self, writer: &mut impl Write, repr: &Repr) -> Result<(), std::io::Error> { + self.print_line(writer, &self.seq[0], repr) } } @@ -816,8 +850,11 @@ fn exec(file1: &str, file2: &str, settings: Settings) -> Result<(), JoinError> { &settings.empty, ); + let stdout = stdout(); + let mut writer = BufWriter::new(stdout.lock()); + if settings.headers { - state1.print_headers(&state2, &repr)?; + state1.print_headers(&mut writer, &state2, &repr)?; state1.reset_read_line(&input)?; state2.reset_read_line(&input)?; } @@ -827,21 +864,39 @@ fn exec(file1: &str, file2: &str, settings: Settings) -> Result<(), JoinError> { match diff { Ordering::Less => { - state1.skip_line(&input, &repr)?; + if let Err(e) = state1.skip_line(&mut writer, &input, &repr) { + writer.flush()?; + return Err(e); + } state1.has_unpaired = true; state2.has_unpaired = true; } Ordering::Greater => { - state2.skip_line(&input, &repr)?; + if let Err(e) = state2.skip_line(&mut writer, &input, &repr) { + writer.flush()?; + return Err(e); + } state1.has_unpaired = true; state2.has_unpaired = true; } Ordering::Equal => { - let next_line1 = state1.extend(&input)?; - let next_line2 = state2.extend(&input)?; + let next_line1 = match state1.extend(&input) { + Ok(line) => line, + Err(e) => { + writer.flush()?; + return Err(e); + } + }; + let next_line2 = match state2.extend(&input) { + Ok(line) => line, + Err(e) => { + writer.flush()?; + return Err(e); + } + }; if settings.print_joined { - state1.combine(&state2, &repr)?; + state1.combine(&mut writer, &state2, &repr)?; } state1.reset(next_line1); @@ -850,8 +905,16 @@ fn exec(file1: 
&str, file2: &str, settings: Settings) -> Result<(), JoinError> { } } - state1.finalize(&input, &repr)?; - state2.finalize(&input, &repr)?; + if let Err(e) = state1.finalize(&mut writer, &input, &repr) { + writer.flush()?; + return Err(e); + }; + if let Err(e) = state2.finalize(&mut writer, &input, &repr) { + writer.flush()?; + return Err(e); + }; + + writer.flush()?; if state1.has_failed || state2.has_failed { eprintln!(