Skip to content

Commit

Permalink
cleaned up code
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinManuel committed Jun 29, 2024
1 parent ebabfec commit 4e59f23
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions src/spikeinterface/sorters/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
joblib=dict(n_jobs=-1, backend="loky"),
processpoolexecutor=dict(max_workers=2, mp_context=None),
dask=dict(client=None),
slurm=dict(tmp_script_folder=None, sbatch_args=dict(cpus_per_task=1, mem="1G")),
slurm={"tmp_script_folder": None, "sbatch_executable_path": "sbatch", "cpus-per-task": 1, "mem": "1G"},
)


Expand Down Expand Up @@ -67,10 +67,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F
The engine to run the list.
* "loop" : a simple loop. This engine is
engine_kwargs : dict
In the case of engine="slum", arguments to sbatch can be passed via sbatch_args, which is a dictionary whose
keys correspond to the --args to be passed to sbatch.
return_output : bool, dfault False
In the case of engine="slum", possible kwargs are:
- tmp_script_folder: str, default None
the folder in which the job scripts are created. Default: directory created by
the `tempfile` library
- sbatch_executable_path: str, default 'sbatch'
the path to the `sbatch` executable
- other kwargs are interpreted as arguments to sbatch, and are translated to the --args to be passed to sbatch.
see the [documentation for `sbatch`](https://slurm.schedmd.com/sbatch.html) for a list of possible arguments
return_output : bool, default False
Return a sortings or None.
This also overwrite kwargs in run_sorter(with_sorting=True/False)
Expand All @@ -82,7 +88,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F

assert engine in _implemented_engine, f"engine must be in {_implemented_engine}"

engine_kwargs = {} if None else engine_kwargs
if engine_kwargs is None:
engine_kwargs = dict()
engine_kwargs_ = dict()
engine_kwargs_.update(_default_engine_kwargs[engine])
engine_kwargs_.update(engine_kwargs)
Expand Down Expand Up @@ -148,10 +155,18 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F

elif engine == "slurm":
# generate python script for slurm
tmp_script_folder = engine_kwargs["tmp_script_folder"]
tmp_script_folder = engine_kwargs.pop("tmp_script_folder")
if tmp_script_folder is None:
tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_")
tmp_script_folder = Path(tmp_script_folder)
sbatch_executable = engine_kwargs.pop("sbatch_executable_path")

# for backward compatibility with previous version
if "cpus_per_task" in engine_kwargs:
warnings.warn("cpus_per_task is deprecated, use cpus-per-task instead", DeprecationWarning)
cpus_per_task = engine_kwargs.pop("cpus_per_task")
if "cpus-per-task" not in engine_kwargs:
engine_kwargs["cpus-per-task"] = cpus_per_task

tmp_script_folder.mkdir(exist_ok=True, parents=True)

Expand Down Expand Up @@ -180,8 +195,14 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=F
)
f.write(slurm_script)
os.fchmod(f.fileno(), mode=stat.S_IRWXU)
sbatch_args = ' '.join(['--{k}={v}' for k, v in engine_kwargs['sbatch_args'].items()])
subprocess.Popen("sbatch", str(script_name.absolute()), sbatch_args)

progr = [sbatch_executable]
for k, v in engine_kwargs.items():
progr.append(f"--{k}")
progr.append(f"{v}")
progr.append(str(script_name.absolute()))
p = subprocess.run(progr, capture_output=True, text=True)
print(p.stdout)

return out

Expand Down

0 comments on commit 4e59f23

Please sign in to comment.