From 971140dfed8b4dfa7143752f49fc622b63019d71 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 10 May 2024 08:41:44 -0400 Subject: [PATCH 1/4] check for float > 0 for n_jobs --- src/spikeinterface/core/job_tools.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index db8d8f6339..257def062c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -87,7 +87,12 @@ def fix_job_kwargs(runtime_job_kwargs): n_jobs = job_kwargs["n_jobs"] assert isinstance(n_jobs, (float, np.integer, int)) if isinstance(n_jobs, float): - n_jobs = int(n_jobs * os.cpu_count()) + # if n_jobs is > 1 then it is something like 4.0 cast to int + if n_jobs > 1: + n_jobs = int(n_jobs) + # otherwise it is a fraction so we should do the fraction of os.cpu_count() + else: + n_jobs = int(n_jobs * os.cpu_count()) elif n_jobs < 0: n_jobs = os.cpu_count() + 1 + n_jobs job_kwargs["n_jobs"] = max(n_jobs, 1) From 173873715e24cfe94734ab9673fb210205bac607 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 10 May 2024 08:48:06 -0400 Subject: [PATCH 2/4] two more instances of preventing problem with float --- src/spikeinterface/core/job_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 257def062c..06e60468d3 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -88,13 +88,13 @@ def fix_job_kwargs(runtime_job_kwargs): assert isinstance(n_jobs, (float, np.integer, int)) if isinstance(n_jobs, float): # if n_jobs is > 1 then it is something like 4.0 cast to int - if n_jobs > 1: + if n_jobs >= 1.0: n_jobs = int(n_jobs) # otherwise it is a fraction so we should do the fraction of os.cpu_count() else: n_jobs = int(n_jobs * os.cpu_count()) elif n_jobs < 0: - n_jobs = os.cpu_count() + 1 + n_jobs + n_jobs = int(os.cpu_count() + 1 + n_jobs) job_kwargs["n_jobs"] = max(n_jobs, 1) return job_kwargs From 57974f4d05308cba6a693316240edee0ecbbe46e Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 10 May 2024 09:26:24 -0400 Subject: [PATCH 3/4] simplify if-else --- src/spikeinterface/core/job_tools.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 06e60468d3..f1454d9023 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -85,16 +85,18 @@ def fix_job_kwargs(runtime_job_kwargs): # if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs) n_jobs = job_kwargs["n_jobs"] - assert isinstance(n_jobs, (float, np.integer, int)) - if isinstance(n_jobs, float): - # if n_jobs is > 1 then it is something like 4.0 cast to int - if n_jobs >= 1.0: - n_jobs = int(n_jobs) - # otherwise it is a fraction so we should do the fraction of os.cpu_count() - else: - n_jobs = int(n_jobs * os.cpu_count()) + assert isinstance(n_jobs, (float, np.integer, int)) and n_jobs != 0, "n_jobs must be a non-zero int or float" + + # for a fraction we do fraction of total cores + if isinstance(n_jobs, float) and 0 < n_jobs <= 1: + n_jobs = int(n_jobs * os.cpu_count()) + # for negative numbers we count down from total cores (with -1 being all) elif n_jobs < 0: n_jobs = int(os.cpu_count() + 1 + n_jobs) + # otherwise we just take the value given + else: + n_jobs = int(n_jobs) + job_kwargs["n_jobs"] = max(n_jobs, 1) return job_kwargs From 7c41dba226b36c42eaec7f39c6e29e473c7b5ae2 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 10 May 2024 09:36:32 -0400 Subject: [PATCH 4/4] update test error on 0 and handle float --- src/spikeinterface/core/tests/test_job_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 1bfe3a5e79..26ca8e7e70 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -180,10 +180,10 @@ def test_fix_job_kwargs(): else: assert fixed_job_kwargs["n_jobs"] == 1 - # test minimum n_jobs - job_kwargs = dict(n_jobs=0, progress_bar=False, chunk_duration="1s") + # test float value > 1 is cast to correct int + job_kwargs = dict(n_jobs=4.0, progress_bar=False, chunk_duration="1s") fixed_job_kwargs = fix_job_kwargs(job_kwargs) - assert fixed_job_kwargs["n_jobs"] == 1 + assert fixed_job_kwargs["n_jobs"] == 4 # test wrong keys with pytest.raises(AssertionError):