From 4942e519fa104d2d3a7e5ae1257326c448a8eb27 Mon Sep 17 00:00:00 2001 From: Sylvestre Ledru Date: Sun, 20 Mar 2022 23:16:03 +0100 Subject: [PATCH] nproc: add the full support of OMP_THREAD_LIMIT --- src/uu/nproc/src/nproc.rs | 17 ++++++++++++++--- tests/by-util/test_nproc.rs | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/uu/nproc/src/nproc.rs b/src/uu/nproc/src/nproc.rs index fdebea65a..2615e0c66 100644 --- a/src/uu/nproc/src/nproc.rs +++ b/src/uu/nproc/src/nproc.rs @@ -25,7 +25,9 @@ pub const _SC_NPROCESSORS_CONF: libc::c_int = 1001; static OPT_ALL: &str = "all"; static OPT_IGNORE: &str = "ignore"; -static ABOUT: &str = "Print the number of cores available to the current process."; +static ABOUT: &str = r#"Print the number of cores available to the current process. +If the OMP_NUM_THREADS or OMP_THREAD_LIMIT environment variables are set, then +they will determine the minimum and maximum returned value respectively."#; const USAGE: &str = "{} [OPTIONS]..."; #[uucore::main] @@ -45,6 +47,15 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { None => 0, }; + let limit = match env::var("OMP_THREAD_LIMIT") { + // Uses the OpenMP variable to limit the number of threads + // If the parsing fails, returns the max size (so, no impact) + Ok(threadstr) => threadstr.parse().unwrap_or(usize::MAX), + // the variable 'OMP_THREAD_LIMIT' doesn't exist + // fallback to the max + Err(_) => usize::MAX, + }; + let mut cores = if matches.is_present(OPT_ALL) { num_cpus_all() } else { @@ -53,12 +64,12 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { // Uses the OpenMP variable to force the number of threads // If the parsing fails, returns the number of CPU Ok(threadstr) => threadstr.parse().unwrap_or_else(|_| num_cpus::get()), - // the variable 'OMP_NUM_THREADS' doesn't exit + // the variable 'OMP_NUM_THREADS' doesn't exist // fallback to the regular CPU detection Err(_) => num_cpus::get(), } }; - + cores = std::cmp::min(limit, cores); if cores <= ignore { cores = 1; } else { diff --git a/tests/by-util/test_nproc.rs b/tests/by-util/test_nproc.rs index 5657e6b7e..6d3fb1fd0 100644 --- a/tests/by-util/test_nproc.rs +++ b/tests/by-util/test_nproc.rs @@ -72,3 +72,37 @@ fn test_nproc_ignore_all_omp() { let nproc: u8 = result.stdout_str().trim().parse().unwrap(); assert!(nproc == 2); } + +#[test] +fn test_nproc_omp_limit() { + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "42") + .env("OMP_THREAD_LIMIT", "0") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert!(nproc == 1); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "42") + .env("OMP_THREAD_LIMIT", "2") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert!(nproc == 2); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "42") + .env("OMP_THREAD_LIMIT", "2bad") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert!(nproc == 42); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_THREAD_LIMIT", "1") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert!(nproc == 1); +}