Skip to content

Commit

Permalink
fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinManuel committed Sep 11, 2024
1 parent 6ba8423 commit 0ec9af5
Showing 1 changed file with 6 additions and 80 deletions.
86 changes: 6 additions & 80 deletions src/spikeinterface/sorters/tests/test_launcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import shutil
import tempfile
import time

import pytest
from pathlib import Path

Expand Down Expand Up @@ -126,7 +126,6 @@ def test_run_sorter_jobs_slurm(job_list, create_cache_folder):
)


@pytest.mark.skip("Slurm launcher need a machine with slurm")
def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
"""
Mock `subprocess.run()` to check that engine_kwargs are
Expand All @@ -141,92 +140,19 @@ def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
tmp_script_folder = tmp_path / "slurm_scripts"

engine_kwargs = dict(
tmp_script_folder=tmp_script_folder,
sbatch_args={
tmp_script_folder=tmp_script_folder)
slurm_kwargs={
"cpus-per-task": 32,
"mem": "32G",
"gres": "gpu:1",
"any_random_kwarg": 12322,
},
)
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs=engine_kwargs,
)

script_0_path = f"{tmp_script_folder}/si_script_0.py"
script_1_path = f"{tmp_script_folder}/si_script_1.py"

expected_command = [
"sbatch",
"--cpus-per-task",
"32",
"--mem",
"32G",
"--gres",
"gpu:1",
"--any_random_kwarg",
"12322",
script_1_path,
]
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)

# Next, check the fisrt call (which sets up `si_script_0.py`)
# also has the expected arguments.
expected_command[9] = script_0_path
assert mock_subprocess_run.call_args_list[0].args[0] == expected_command

# Next, check that defaults are used properly when no kwargs are
# passed. This will default to `_default_engine_kwargs` as
# set in `launcher.py`
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs={"tmp_script_folder": tmp_script_folder},
)
expected_command = ["sbatch", "--cpus-per-task", "1", "--mem", "1G", script_1_path]
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)

# Finally, check that the `tmp_script_folder` is generated on the
# fly as expected. A random foldername is generated, just check that
# the folder to which the scripts are saved is in the `tempfile` format.
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs=None, # TODO: test defaults
)
tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1])
assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5]


@pytest.mark.skip("Slurm launcher need a machine with slurm")
def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
"""
Mock `subprocess.run()` to check that engine_kwargs are
propagated to the call as expected.
"""
# First, mock `subprocess.run()`, set up a call to `run_sorter_jobs`
# then check the mocked `subprocess.run()` was called with the
# expected signature. Two jobs are passed in `jobs_list`, first
# check the most recent call.
mock_subprocess_run = mocker.patch("spikeinterface.sorters.launcher.subprocess.run")

tmp_script_folder = tmp_path / "slurm_scripts"
}

engine_kwargs = dict(
tmp_script_folder=tmp_script_folder,
sbatch_args={
"cpus-per-task": 32,
"mem": "32G",
"gres": "gpu:1",
"any_random_kwarg": 12322,
},
)
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs=engine_kwargs,
slurm_kwargs=slurm_kwargs
)

script_0_path = f"{tmp_script_folder}/si_script_0.py"
Expand Down Expand Up @@ -268,7 +194,7 @@ def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs=None, # TODO: test defaults
engine_kwargs=None,
)
tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1])
assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5]
Expand Down

0 comments on commit 0ec9af5

Please sign in to comment.