From 0ec9af504a99a56bc1d57dc0bd6ff5a799fb9361 Mon Sep 17 00:00:00 2001 From: MANUEL lab <65401298+MarinManuel@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:14:02 -0400 Subject: [PATCH] fixed test --- .../sorters/tests/test_launcher.py | 86 ++----------------- 1 file changed, 6 insertions(+), 80 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 3177c700d3..66f8b559f4 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,7 +1,7 @@ import sys import shutil +import tempfile import time - import pytest from pathlib import Path @@ -126,7 +126,6 @@ def test_run_sorter_jobs_slurm(job_list, create_cache_folder): ) -@pytest.mark.skip("Slurm launcher need a machine with slurm") def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list): """ Mock `subprocess.run()` to check that engine_kwargs are @@ -141,92 +140,19 @@ def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list): tmp_script_folder = tmp_path / "slurm_scripts" engine_kwargs = dict( - tmp_script_folder=tmp_script_folder, - sbatch_args={ + tmp_script_folder=tmp_script_folder) + slurm_kwargs={ "cpus-per-task": 32, "mem": "32G", "gres": "gpu:1", "any_random_kwarg": 12322, - }, - ) - run_sorter_jobs( - job_list, - engine="slurm", - engine_kwargs=engine_kwargs, - ) - - script_0_path = f"{tmp_script_folder}/si_script_0.py" - script_1_path = f"{tmp_script_folder}/si_script_1.py" - - expected_command = [ - "sbatch", - "--cpus-per-task", - "32", - "--mem", - "32G", - "--gres", - "gpu:1", - "--any_random_kwarg", - "12322", - script_1_path, - ] - mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True) - - # Next, check the fisrt call (which sets up `si_script_0.py`) - # also has the expected arguments. - expected_command[9] = script_0_path - assert mock_subprocess_run.call_args_list[0].args[0] == expected_command - - # Next, check that defaults are used properly when no kwargs are - # passed. This will default to `_default_engine_kwargs` as - # set in `launcher.py` - run_sorter_jobs( - job_list, - engine="slurm", - engine_kwargs={"tmp_script_folder": tmp_script_folder}, - ) - expected_command = ["sbatch", "--cpus-per-task", "1", "--mem", "1G", script_1_path] - mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True) - - # Finally, check that the `tmp_script_folder` is generated on the - # fly as expected. A random foldername is generated, just check that - # the folder to which the scripts are saved is in the `tempfile` format. - run_sorter_jobs( - job_list, - engine="slurm", - engine_kwargs=None, # TODO: test defaults - ) - tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1]) - assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5] - - -@pytest.mark.skip("Slurm launcher need a machine with slurm") -def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list): - """ - Mock `subprocess.run()` to check that engine_kwargs are - propagated to the call as expected. - """ - # First, mock `subprocess.run()`, set up a call to `run_sorter_jobs` - # then check the mocked `subprocess.run()` was called with the - # expected signature. Two jobs are passed in `jobs_list`, first - # check the most recent call. - mock_subprocess_run = mocker.patch("spikeinterface.sorters.launcher.subprocess.run") - - tmp_script_folder = tmp_path / "slurm_scripts" + } - engine_kwargs = dict( - tmp_script_folder=tmp_script_folder, - sbatch_args={ - "cpus-per-task": 32, - "mem": "32G", - "gres": "gpu:1", - "any_random_kwarg": 12322, - }, - ) run_sorter_jobs( job_list, engine="slurm", engine_kwargs=engine_kwargs, + slurm_kwargs=slurm_kwargs ) script_0_path = f"{tmp_script_folder}/si_script_0.py" @@ -268,7 +194,7 @@ def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list): run_sorter_jobs( job_list, engine="slurm", - engine_kwargs=None, # TODO: test defaults + engine_kwargs=None, ) tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1]) assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5]