diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 53cf0aed71..7882330d08 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -26,7 +26,7 @@ joblib=dict(n_jobs=-1, backend="loky"), processpoolexecutor=dict(max_workers=2, mp_context=None), dask=dict(client=None), - slurm=dict(tmp_script_folder=None, sbatch_args=dict(cpus_per_task=1, mem="1G")), + slurm={"tmp_script_folder": None, "sbatch_executable_path": "sbatch", "cpus-per-task": 1, "mem": "1G"}, ) @@ -67,10 +67,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F The engine to run the list. * "loop" : a simple loop. This engine is engine_kwargs : dict - In the case of engine="slum", arguments to sbatch can be passed via sbatch_args, which is a dictionary whose - keys correspond to the --args to be passed to sbatch. - - return_output : bool, dfault False + In the case of engine="slum", possible kwargs are: + - tmp_script_folder: str, default None + the folder in which the job scripts are created. Default: directory created by + the `tempfile` library + - sbatch_executable_path: str, default 'sbatch' + the path to the `sbatch` executable + - other kwargs are interpreted as arguments to sbatch, and are translated to the --args to be passed to sbatch. + see the [documentation for `sbatch`](https://slurm.schedmd.com/sbatch.html) for a list of possible arguments + + return_output : bool, default False Return a sortings or None. This also overwrite kwargs in run_sorter(with_sorting=True/False) @@ -82,7 +88,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F assert engine in _implemented_engine, f"engine must be in {_implemented_engine}" - engine_kwargs = {} if None else engine_kwargs + if engine_kwargs is None: + engine_kwargs = dict() engine_kwargs_ = dict() engine_kwargs_.update(_default_engine_kwargs[engine]) engine_kwargs_.update(engine_kwargs) @@ -148,10 +155,18 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F elif engine == "slurm": # generate python script for slurm - tmp_script_folder = engine_kwargs["tmp_script_folder"] + tmp_script_folder = engine_kwargs.pop("tmp_script_folder") if tmp_script_folder is None: tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") tmp_script_folder = Path(tmp_script_folder) + sbatch_executable = engine_kwargs.pop("sbatch_executable_path") + + # for backward compatibility with previous version + if "cpus_per_task" in engine_kwargs: + warnings.warn("cpus_per_task is deprecated, use cpus-per-task instead", DeprecationWarning) + cpus_per_task = engine_kwargs.pop("cpus_per_task") + if "cpus-per-task" not in engine_kwargs: + engine_kwargs["cpus-per-task"] = cpus_per_task tmp_script_folder.mkdir(exist_ok=True, parents=True) @@ -180,8 +195,14 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F ) f.write(slurm_script) os.fchmod(f.fileno(), mode=stat.S_IRWXU) - sbatch_args = ' '.join(['--{k}={v}' for k, v in engine_kwargs['sbatch_args'].items()]) - subprocess.Popen("sbatch", str(script_name.absolute()), sbatch_args) + + progr = [sbatch_executable] + for k, v in engine_kwargs.items(): + progr.append(f"--{k}") + progr.append(f"{v}") + progr.append(str(script_name.absolute())) + p = subprocess.run(progr, capture_output=True, text=True) + print(p.stdout) return out