Skip to content

Commit

Permalink
Merge pull request #2829 from zm711/n_jobs_check
Browse files Browse the repository at this point in the history
Add extra check in fix_job_kwargs
  • Loading branch information
alejoe91 authored May 10, 2024
2 parents ba3dfb4 + 7c41dba commit 707f78a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,18 @@ def fix_job_kwargs(runtime_job_kwargs):

# if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs)
n_jobs = job_kwargs["n_jobs"]
assert isinstance(n_jobs, (float, np.integer, int))
if isinstance(n_jobs, float):
assert isinstance(n_jobs, (float, np.integer, int)) and n_jobs != 0, "n_jobs must be a non-zero int or float"

# for a fraction we do fraction of total cores
if isinstance(n_jobs, float) and 0 < n_jobs <= 1:
n_jobs = int(n_jobs * os.cpu_count())
# for negative numbers we count down from total cores (with -1 being all)
elif n_jobs < 0:
n_jobs = os.cpu_count() + 1 + n_jobs
n_jobs = int(os.cpu_count() + 1 + n_jobs)
# otherwise we just take the value given
else:
n_jobs = int(n_jobs)

job_kwargs["n_jobs"] = max(n_jobs, 1)

return job_kwargs
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/tests/test_job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ def test_fix_job_kwargs():
else:
assert fixed_job_kwargs["n_jobs"] == 1

# test minimum n_jobs
job_kwargs = dict(n_jobs=0, progress_bar=False, chunk_duration="1s")
# test float value > 1 is cast to correct int
job_kwargs = dict(n_jobs=4.0, progress_bar=False, chunk_duration="1s")
fixed_job_kwargs = fix_job_kwargs(job_kwargs)
assert fixed_job_kwargs["n_jobs"] == 1
assert fixed_job_kwargs["n_jobs"] == 4

# test wrong keys
with pytest.raises(AssertionError):
Expand Down

0 comments on commit 707f78a

Please sign in to comment.