From 33c49666c35171afcd9158446b4aa762f58ea1dd Mon Sep 17 00:00:00 2001 From: Sylvestre Ledru Date: Wed, 23 Mar 2022 12:12:54 +0100 Subject: [PATCH] nproc: make tests/misc/nproc-override.sh pass by implementing OMP_NUM_THREADS=X,Y,Z (#3296) + nproc tests: use assert_eq when comparing the two values Co-authored-by: jfinkels --- src/uu/nproc/src/nproc.rs | 21 +++++++++- tests/by-util/test_nproc.rs | 80 ++++++++++++++++++++++++++++++++----- 2 files changed, 89 insertions(+), 12 deletions(-) diff --git a/src/uu/nproc/src/nproc.rs b/src/uu/nproc/src/nproc.rs index 2615e0c66..87fe9a4e7 100644 --- a/src/uu/nproc/src/nproc.rs +++ b/src/uu/nproc/src/nproc.rs @@ -50,7 +50,11 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { let limit = match env::var("OMP_THREAD_LIMIT") { // Uses the OpenMP variable to limit the number of threads // If the parsing fails, returns the max size (so, no impact) - Ok(threadstr) => threadstr.parse().unwrap_or(usize::MAX), + // If OMP_THREAD_LIMIT=0, rejects the value + Ok(threadstr) => match threadstr.parse() { + Ok(0) | Err(_) => usize::MAX, + Ok(n) => n, + }, // the variable 'OMP_THREAD_LIMIT' doesn't exist // fallback to the max Err(_) => usize::MAX, @@ -63,12 +67,25 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { match env::var("OMP_NUM_THREADS") { // Uses the OpenMP variable to force the number of threads // If the parsing fails, returns the number of CPU - Ok(threadstr) => threadstr.parse().unwrap_or_else(|_| num_cpus::get()), + Ok(threadstr) => { + // In some cases, OMP_NUM_THREADS can be "x,y,z" + // In this case, only take the first one (like GNU) + // If OMP_NUM_THREADS=0, rejects the value + let thread: Vec<&str> = threadstr.split_terminator(',').collect(); + match &thread[..] { + [] => num_cpus::get(), + [s, ..] => match s.parse() { + Ok(0) | Err(_) => num_cpus::get(), + Ok(n) => n, + }, + } + } // the variable 'OMP_NUM_THREADS' doesn't exist // fallback to the regular CPU detection Err(_) => num_cpus::get(), } }; + cores = std::cmp::min(limit, cores); if cores <= ignore { cores = 1; diff --git a/tests/by-util/test_nproc.rs b/tests/by-util/test_nproc.rs index 6d3fb1fd0..330f327cb 100644 --- a/tests/by-util/test_nproc.rs +++ b/tests/by-util/test_nproc.rs @@ -20,7 +20,7 @@ fn test_nproc_all_omp() { .succeeds(); let nproc_omp: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc_omp == 60); + assert_eq!(nproc_omp, 60); let result = TestScenario::new(util_name!()) .ucmd_keepenv() @@ -28,7 +28,7 @@ fn test_nproc_all_omp() { .arg("--all") .succeeds(); let nproc_omp: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == nproc_omp); + assert_eq!(nproc, nproc_omp); // If the parsing fails, returns the number of CPU let result = TestScenario::new(util_name!()) @@ -36,7 +36,7 @@ fn test_nproc_all_omp() { .env("OMP_NUM_THREADS", "incorrectnumber") // returns the number CPU .succeeds(); let nproc_omp: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == nproc_omp); + assert_eq!(nproc, nproc_omp); } #[test] @@ -51,14 +51,14 @@ fn test_nproc_ignore() { .arg((nproc_total - 1).to_string()) .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 1); + assert_eq!(nproc, 1); // Ignore all CPU but one with a string let result = TestScenario::new(util_name!()) .ucmd_keepenv() .arg("--ignore= 1") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc_total - 1 == nproc); + assert_eq!(nproc_total - 1, nproc); } } @@ -70,7 +70,7 @@ fn test_nproc_ignore_all_omp() { .arg("--ignore=40") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 2); + assert_eq!(nproc, 2); } #[test] @@ -81,7 +81,7 @@ fn test_nproc_omp_limit() { .env("OMP_THREAD_LIMIT", "0") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 1); + assert_eq!(nproc, 42); let result = TestScenario::new(util_name!()) .ucmd_keepenv() @@ -89,7 +89,7 @@ fn test_nproc_omp_limit() { .env("OMP_THREAD_LIMIT", "2") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 2); + assert_eq!(nproc, 2); let result = TestScenario::new(util_name!()) .ucmd_keepenv() @@ -97,12 +97,72 @@ fn test_nproc_omp_limit() { .env("OMP_THREAD_LIMIT", "2bad") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 42); + assert_eq!(nproc, 42); + + let result = new_ucmd!().arg("--all").succeeds(); + let nproc_system: u8 = result.stdout_str().trim().parse().unwrap(); + assert!(nproc_system > 0); let result = TestScenario::new(util_name!()) .ucmd_keepenv() .env("OMP_THREAD_LIMIT", "1") .succeeds(); let nproc: u8 = result.stdout_str().trim().parse().unwrap(); - assert!(nproc == 1); + assert_eq!(nproc, 1); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "0") + .env("OMP_THREAD_LIMIT", "") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(nproc, nproc_system); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "") + .env("OMP_THREAD_LIMIT", "") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(nproc, nproc_system); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "2,2,1") + .env("OMP_THREAD_LIMIT", "") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(2, nproc); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "2,ignored") + .env("OMP_THREAD_LIMIT", "") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(2, nproc); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "2,2,1") + .env("OMP_THREAD_LIMIT", "0") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(2, nproc); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "2,2,1") + .env("OMP_THREAD_LIMIT", "1bad") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(2, nproc); + + let result = TestScenario::new(util_name!()) + .ucmd_keepenv() + .env("OMP_NUM_THREADS", "29,2,1") + .env("OMP_THREAD_LIMIT", "1bad") + .succeeds(); + let nproc: u8 = result.stdout_str().trim().parse().unwrap(); + assert_eq!(29, nproc); }