diff --git a/common/c_types.rs b/common/c_types.rs index e4a891722..a0f9c735d 100644 --- a/common/c_types.rs +++ b/common/c_types.rs @@ -13,7 +13,7 @@ use self::libc::funcs::posix88::unistd::getgroups; use std::vec::Vec; -use std::io::IoError; +use std::os; use std::ptr::read; use std::str::raw::from_c_str; @@ -85,7 +85,7 @@ extern { pub fn getpwnam(login: *c_char) -> *c_passwd; pub fn getgrouplist(name: *c_char, basegid: gid_t, - groups: *gid_t, + groups: *mut gid_t, ngroups: *mut c_int) -> c_int; pub fn getgrgid(gid: gid_t) -> *c_group; pub fn getgrnam(name: *c_char) -> *c_group; @@ -137,46 +137,61 @@ pub fn get_group(groupname: &str) -> Option { } } -static NGROUPS: i32 = 20; +fn get_group_list(name: *c_char, gid: gid_t) -> Result, int> { + let mut ngroups = 0 as c_int; + + unsafe { getgrouplist(name, gid, 0 as *mut gid_t, &mut ngroups) }; + let mut groups = Vec::from_elem(ngroups as uint, 0 as gid_t); + let err = unsafe { getgrouplist(name, gid, groups.as_mut_ptr(), &mut ngroups) }; + if err == -1 { + Err(os::errno()) + } else { + groups.truncate(ngroups as uint); + Ok(groups) + } +} + +fn get_groups() -> Result, int> { + let ngroups = unsafe { getgroups(0, 0 as *mut gid_t) }; + if ngroups == -1 { + return Err(os::errno()); + } + + let mut groups = Vec::from_elem(ngroups as uint, 0 as gid_t); + let ngroups = unsafe { getgroups(ngroups, groups.as_mut_ptr()) }; + if ngroups == -1 { + Err(os::errno()) + } else { + groups.truncate(ngroups as uint); + Ok(groups) + } +} pub fn group(possible_pw: Option, nflag: bool) { - let mut groups = Vec::with_capacity(NGROUPS as uint); - let mut ngroups; - if possible_pw.is_some() { - ngroups = NGROUPS; - unsafe { - getgrouplist( - possible_pw.unwrap().pw_name, - possible_pw.unwrap().pw_gid, - groups.as_ptr(), - &mut ngroups); - } - } else { - ngroups = unsafe { - getgroups(NGROUPS, groups.as_mut_ptr() as *mut gid_t) - }; - } + let groups = match possible_pw { + Some(pw) => get_group_list(pw.pw_name, pw.pw_gid), + None => get_groups(), + }; - if ngroups < 0 { - crash!(1, "{}", IoError::last_error()); - } - - unsafe { groups.set_len(ngroups as uint) }; - - for &g in groups.iter() { - if nflag { - let group = unsafe { getgrgid(g) }; - if group.is_not_null() { - let name = unsafe { - from_c_str(read(group).gr_name) - }; - print!("{:s} ", name); + match groups { + Err(errno) => + crash!(1, "failed to get group list (errno={:d})", errno), + Ok(groups) => { + for &g in groups.iter() { + if nflag { + let group = unsafe { getgrgid(g) }; + if group.is_not_null() { + let name = unsafe { + from_c_str(read(group).gr_name) + }; + print!("{:s} ", name); + } + } else { + print!("{:u} ", g); + } } - } else { - print!("{:u} ", g); + println!(""); } } - - println!(""); }