From cf5041062a73b1c61d8f15200ade73ea1f1d8bae Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 19:14:51 +0100 Subject: [PATCH 01/29] Run checks for singularity, docker and related python module installations. --- src/spikeinterface/sorters/runsorter.py | 18 ++++++++++++++ src/spikeinterface/sorters/utils/misc.py | 31 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index baec6aaac3..44a08a34a7 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -169,6 +169,15 @@ def run_sorter( container_image = None else: container_image = docker_image + + if not has_docker(): + raise RuntimeError("Docker is not installed. Install docker " + "on this machine to run sorting with docker.") + + if not has_docker_python(): + raise RuntimeError("The python `docker` package must be installed." + "Install with `pip install docker`") + else: mode = "singularity" assert not docker_image @@ -176,6 +185,15 @@ def run_sorter( container_image = None else: container_image = singularity_image + + if not has_singularity(): + raise RuntimeError("Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity.") + + if not has_spython(): + raise RuntimeError("The python singularity package must be installed." + "Install with `pip install spython`") + return run_sorter_container( container_image=container_image, mode=mode, diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 0a6b4a986c..a1cf34f059 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import subprocess # TODO: decide best format for this from subprocess import check_output, CalledProcessError from typing import List, Union @@ -80,3 +81,33 @@ def has_nvidia(): return device_count > 0 except RuntimeError: # Failed to dlopen libcuda.so return False + +def _run_subprocess_silently(command): + output = subprocess.run( + command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return output + + +def has_docker(): + return self._run_subprocess_silently("docker --version").returncode == 0 + + +def has_singularity(): + return self._run_subprocess_silently("singularity --version").returncode == 0 + + +def has_docker_python(): + try: + import docker + return True + except ImportError: + return False + + +def has_spython(): + try: + import spython + return True + except ImportError: + return False From e49521939f2023c50943afad21a663c3d7822011 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 20:09:03 +0100 Subject: [PATCH 02/29] Add nvidia dependency checks, tidy up. 
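The detection logic here reduces to silently running `<tool> --version` and
treating a zero exit code as "installed". A minimal standalone sketch of the
probe, using only the standard library (`probe` is an illustrative name, not
part of this diff):

    import subprocess

    def probe(command: str) -> bool:
        # True when the command exists and exits with status 0,
        # e.g. probe("docker --version") or probe("singularity --version").
        result = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.DEVNULL,  # discard normal output
            stderr=subprocess.DEVNULL,  # discard error output
        )
        return result.returncode == 0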
--- src/spikeinterface/sorters/runsorter.py | 17 +++++++++++---- src/spikeinterface/sorters/utils/__init__.py | 2 +- src/spikeinterface/sorters/utils/misc.py | 22 +++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 44a08a34a7..884cba590f 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia +from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies from .container_tools import ( find_recording_folders, path_to_unix, @@ -175,7 +175,7 @@ def run_sorter( "on this machine to run sorting with docker.") if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed." + raise RuntimeError("The python `docker` package must be installed. " "Install with `pip install docker`") else: @@ -191,8 +191,8 @@ def run_sorter( "on this machine to run sorting with singularity.") if not has_spython(): - raise RuntimeError("The python singularity package must be installed." - "Install with `pip install spython`") + raise RuntimeError("The python `spython` package must be installed to " + "run singularity. Install with `pip install spython`") return run_sorter_container( container_image=container_image, @@ -480,6 +480,15 @@ def run_sorter_container( if gpu_capability == "nvidia-required": assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available" extra_kwargs["container_requires_gpu"] = True + + if platform.system() == "Linux" and has_docker_nvidia_installed(): + warn( + f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. " + f"This may result in an error being raised during sorting. Try " + "installing `nvidia-container-toolkit`, including setting the " + "configuration steps, if running into errors." 
+                )
+
         elif gpu_capability == "nvidia-optional":
             if has_nvidia():
                 extra_kwargs["container_requires_gpu"] = True
diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py
index 6cad10b211..7f6f3089d4 100644
--- a/src/spikeinterface/sorters/utils/__init__.py
+++ b/src/spikeinterface/sorters/utils/__init__.py
@@ -1,2 +1,2 @@
 from .shellscript import ShellScript
-from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path
+from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies
diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py
index a1cf34f059..4a900f4485 100644
--- a/src/spikeinterface/sorters/utils/misc.py
+++ b/src/spikeinterface/sorters/utils/misc.py
@@ -82,6 +82,7 @@ def has_nvidia():
     except RuntimeError:  #  Failed to dlopen libcuda.so
         return False
 
+
 def _run_subprocess_silently(command):
     output = subprocess.run(
         command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
@@ -90,12 +91,27 @@ def has_nvidia():
 
 
 def has_docker():
-    return self._run_subprocess_silently("docker --version").returncode == 0
+    return _run_subprocess_silently("docker --version").returncode == 0
 
 
 def has_singularity():
-    return self._run_subprocess_silently("singularity --version").returncode == 0
-
+    return _run_subprocess_silently("singularity --version").returncode == 0
+
+def get_nvidia_docker_dependecies():
+    return [
+        "nvidia-docker",
+        "nvidia-docker2",
+        "nvidia-container-toolkit",
+    ]
+
+def has_docker_nvidia_installed():
+    all_dependencies = get_nvidia_docker_dependecies()
+    has_dep = []
+    for dep in all_dependencies:
+        has_dep.append(
+            _run_subprocess_silently(f"{dep} --version").returncode == 0
+        )
+    return not any(has_dep)
 
 def has_docker_python():
     try:
From e0656bb86901127c8b1c0f708e4970584e79a40d Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 20:15:48 +0100
Subject: [PATCH 03/29] Add docstrings.

---
 src/spikeinterface/sorters/utils/misc.py | 44 ++++++++++++++++++------
 1 file changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py
index 4a900f4485..66744fbab1 100644
--- a/src/spikeinterface/sorters/utils/misc.py
+++ b/src/spikeinterface/sorters/utils/misc.py
@@ -84,9 +84,10 @@ def has_nvidia():
 
 
 def _run_subprocess_silently(command):
-    output = subprocess.run(
-        command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-    )
+    """
+    Run a subprocess command without outputting to stderr or stdout.
+    """
+    output = subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     return output
 
 
@@ -97,25 +98,45 @@ def has_docker():
 def has_singularity():
     return _run_subprocess_silently("singularity --version").returncode == 0
 
+
+def has_docker_nvidia_installed():
+    """
+    On Linux, nvidia has a set of container dependencies
+    that are required for running GPU in docker. This is a little
+    complex and is described in more detail in the links below.
+    To summarise briefly, at least one of the dependencies listed in `get_nvidia_docker_dependecies()`
+    is almost certainly required to run docker with GPU.
+ + https://github.com/NVIDIA/nvidia-docker/issues/1268 + https://www.howtogeek.com/devops/how-to-use-an-nvidia-gpu-with-docker-containers/ + + Returns + ------- + Whether at least one of the dependencies listed in + `get_nvidia_docker_dependecies()` is installed. + """ + all_dependencies = get_nvidia_docker_dependecies() + has_dep = [] + for dep in all_dependencies: + has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0) + return not any(has_dep) + + def get_nvidia_docker_dependecies(): + """ + See `has_docker_nvidia_installed()` + """ return [ "nvidia-docker", "nvidia-docker2", "nvidia-container-toolkit", ] -def has_docker_nvidia_installed(): - all_dependencies = get_nvidia_docker_dependecies() - has_dep = [] - for dep in all_dependencies: - has_dep.append( - _run_subprocess_silently(f"{dep} --version").returncode == 0 - ) - return not any(has_dep) def has_docker_python(): try: import docker + return True except ImportError: return False @@ -124,6 +145,7 @@ def has_docker_python(): def has_spython(): try: import spython + return True except ImportError: return False From b145b04ac31a8de3d9c9fbfc56b4a9974ce0eb3a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:21:51 +0100 Subject: [PATCH 04/29] Add tests for runsorter dependencies. --- src/spikeinterface/sorters/runsorter.py | 35 +++-- .../tests/test_runsorter_dependency_checks.py | 144 ++++++++++++++++++ 2 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 884cba590f..5b2e80b83d 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,18 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies + +# full import required for monkeypatch testing. +from spikeinterface.sorters.utils import ( + SpikeSortingError, + has_nvidia, + has_docker, + has_docker_python, + has_singularity, + has_spython, + has_docker_nvidia_installed, + get_nvidia_docker_dependecies, +) from .container_tools import ( find_recording_folders, path_to_unix, @@ -171,12 +182,14 @@ def run_sorter( container_image = docker_image if not has_docker(): - raise RuntimeError("Docker is not installed. Install docker " - "on this machine to run sorting with docker.") + raise RuntimeError( + "Docker is not installed. Install docker " "on this machine to run sorting with docker." + ) if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed. " - "Install with `pip install docker`") + raise RuntimeError( + "The python `docker` package must be installed. " "Install with `pip install docker`" + ) else: mode = "singularity" @@ -187,12 +200,16 @@ def run_sorter( container_image = singularity_image if not has_singularity(): - raise RuntimeError("Singularity is not installed. Install singularity " - "on this machine to run sorting with singularity.") + raise RuntimeError( + "Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity." + ) if not has_spython(): - raise RuntimeError("The python `spython` package must be installed to " - "run singularity. 
Install with `pip install spython`")
+            raise RuntimeError(
+                "The python `spython` package must be installed to "
+                "run singularity. Install with `pip install spython`"
+            )
 
     return run_sorter_container(
         container_image=container_image,
diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
new file mode 100644
index 0000000000..8dbb1b20f6
--- /dev/null
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -0,0 +1,144 @@
+import os
+import pytest
+from pathlib import Path
+import shutil
+import platform
+from spikeinterface import generate_ground_truth_recording
+from spikeinterface.sorters.utils import has_spython, has_docker_python
+from spikeinterface.sorters import run_sorter
+import subprocess
+import sys
+import copy
+
+
+def _monkeypatch_return_false():
+    return False
+
+
+class TestRunersorterDependencyChecks:
+    """
+    This class performs tests to check whether expected
+    dependency checks prior to sorting are run. The
+    run_sorter function should raise an error if:
+    - singularity is not installed
+    - spython is not installed (python package)
+    - docker is not installed
+    - docker is not installed (python package)
+    when running singularity / docker respectively.
+
+    Two separate checks should be run. First, that the
+    relevant `has_` function (indicating if the dependency
+    is installed) is working. Unfortunately it is not possible to
+    easily test this for core singularity and docker installs, so this is not done.
+    `uninstall_python_dependency()` allows a test to check if
+    `has_spython()` and `has_docker_python()` return `False` as expected
+    when these python modules are not installed.
+
+    Second, the `run_sorter()` function should raise the appropriate error
+    when these functions report that the dependency is not available. This is
+    easier to test as these `has_` reporting functions can be
+    monkeypatched to return False at runtime. This is done for these 4
+    dependency checks, and tests check the expected error is raised.
+
+    Notes
+    -----
+    `has_nvidia()` and `has_docker_nvidia_installed()` are not tested
+    as these are complex GPU-related dependencies which are difficult to mock.
+    """
+
+    @pytest.fixture(scope="function")
+    def uninstall_python_dependency(self, request):
+        """
+        This python fixture mocks python modules not been importable
+        by setting the relevant `sys.modules` dict entry to `None`.
+        It uses `yeild` so that the function can tear-down the test
+        (even if it failed) and replace the patched `sys.module` entry.
+
+        This function uses an `indirect` parameterisation, meaning the
+        `request.param` is passed to the fixture at the start of the
+        test function. This is used to reuse code for nearly identical
+        `spython` and `docker` python dependency tests.
+        """
+        dep_name = request.param
+        assert dep_name in ["spython", "docker"]
+
+        try:
+            if dep_name == "spython":
+                import spython
+            else:
+                import docker
+            dependency_installed = True
+        except:
+            dependency_installed = False
+
+        if dependency_installed:
+            copy_import = sys.modules[dep_name]
+            sys.modules[dep_name] = None
+        yield
+        if dependency_installed:
+            sys.modules[dep_name] = copy_import
+
+    @pytest.fixture(scope="session")
+    def recording(self):
+        """
+        Make a small recording to have something to pass to the sorter.
+ """ + recording, _ = generate_ground_truth_recording(durations=[10]) + return recording + + @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") + @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) + def test_has_spython(self, recording, uninstall_python_dependency): + """ + Test the `has_spython()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_spython() is False + + @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) + def test_has_docker_python(self, recording, uninstall_python_dependency): + """ + Test the `has_docker_python()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_docker_python() is False + + @pytest.mark.parametrize("dependency", ["singularity", "spython"]) + def test_has_singularity_and_spython(self, recording, monkeypatch, dependency): + """ + When running a sorting, if singularity dependencies (singularity + itself or the `spython` package`) are not installed, an error is raised. + Beacause it is hard to actually uninstall these dependencies, the + `has_` functions that let `run_sorter` know if the dependency + are installed are monkeypatched. This is done so at runtime these always + return False. Then, test the expected error is raised when the dependency + is not found. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + if dependency == "spython": + assert "The python `spython` package must be installed" in str(e) + else: + assert "Singularity is not installed." in str(e) + + @pytest.mark.parametrize("dependency", ["docker", "docker_python"]) + def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency): + """ + See `test_has_singularity_and_spython()` for details. This test + is almost identical, but with some key changes for Docker. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + if dependency == "docker_python": + assert "The python `docker` package must be installed" in str(e) + else: + assert "Docker is not installed." in str(e) From 78ccc2719676b238dbd92d2ad5384786ca0724e0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:24:29 +0100 Subject: [PATCH 05/29] Remove unnecessary non-relative import. --- src/spikeinterface/sorters/runsorter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 5b2e80b83d..c16435cdb5 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,9 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict - -# full import required for monkeypatch testing. 
-from spikeinterface.sorters.utils import (
+from .utils import (
     SpikeSortingError,
     has_nvidia,
     has_docker,
From f1438c4ce20bbd7ae3c910b793f92ebb4d723253 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 21:27:03 +0100
Subject: [PATCH 06/29] Fix some string formatting, add docstring to monkeypatch function.

---
 src/spikeinterface/sorters/runsorter.py               | 6 ++----
 .../sorters/tests/test_runsorter_dependency_checks.py | 4 ++++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index c16435cdb5..f9994dd38d 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -181,13 +181,11 @@ def run_sorter(
 
         if not has_docker():
             raise RuntimeError(
-                "Docker is not installed. Install docker " "on this machine to run sorting with docker."
+                "Docker is not installed. Install docker on this machine to run sorting with docker."
             )
 
         if not has_docker_python():
-            raise RuntimeError(
-                "The python `docker` package must be installed. " "Install with `pip install docker`"
-            )
+            raise RuntimeError("The python `docker` package must be installed. Install with `pip install docker`")
 
     else:
         mode = "singularity"
diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index 8dbb1b20f6..c81593b7db 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -12,6 +12,10 @@
 
 
 def _monkeypatch_return_false():
+    """
+    A function to monkeypatch the `has_` functions, 
+    ensuring they always return `False` at runtime.
+    """
     return False
From fd4406e0826f80329614e3b59388e9640c00fe3e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 12 Jun 2024 20:27:36 +0000
Subject: [PATCH 07/29] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sorters/utils/__init__.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py
index 7f6f3089d4..62317be6f2 100644
--- a/src/spikeinterface/sorters/utils/__init__.py
+++ b/src/spikeinterface/sorters/utils/__init__.py
@@ -1,2 +1,14 @@
 from .shellscript import ShellScript
-from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies
+from .misc import (
+    SpikeSortingError,
+    get_git_commit,
+    has_nvidia,
+    get_matlab_shell_name,
+    get_bash_path,
+    has_docker,
+    has_docker_python,
+    has_singularity,
+    has_spython,
+    has_docker_nvidia_installed,
+    get_nvidia_docker_dependecies,
+)
From 7af611ba289e220c4bf36f4b62ae26efe94f93b1 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 21:42:26 +0100
Subject: [PATCH 08/29] Mock all has functions to ensure tests do not depend on actual dependencies.
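A note on the patching target used below: `monkeypatch.setattr` must replace
the name as bound inside `runsorter`, not the original definition in
`utils.misc`, because `runsorter` imports each `has_` function into its own
namespace. A minimal sketch of the pattern, with hypothetical module names:

    # checks.py
    def has_tool():
        return True

    # consumer.py
    from checks import has_tool

    def run():
        if not has_tool():
            raise RuntimeError("tool missing")

    # test_consumer.py
    import pytest
    import consumer

    def test_run_raises(monkeypatch):
        # Patching "checks.has_tool" would leave consumer.run() unaffected,
        # since consumer holds its own reference; patch consumer's binding.
        monkeypatch.setattr("consumer.has_tool", lambda: False)
        with pytest.raises(RuntimeError):
            consumer.run()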
---
 .../tests/test_runsorter_dependency_checks.py | 58 +++++++++++++------
 1 file changed, 39 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index c81593b7db..a248033089 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -4,7 +4,7 @@
 import shutil
 import platform
 from spikeinterface import generate_ground_truth_recording
-from spikeinterface.sorters.utils import has_spython, has_docker_python
+from spikeinterface.sorters.utils import has_spython, has_docker_python, has_docker, has_singularity
 from spikeinterface.sorters import run_sorter
 import subprocess
 import sys
@@ -19,6 +19,10 @@ def _monkeypatch_return_false():
     return False
 
 
+def _monkeypatch_return_true():
+    return True
+
+
 class TestRunersorterDependencyChecks:
     """
     This class performs tests to check whether expected
@@ -91,6 +95,7 @@ def recording(self):
         return recording
 
     @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.")
+    @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.")
    @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True)
     def test_has_spython(self, recording, uninstall_python_dependency):
         """
@@ -100,6 +105,7 @@ def test_has_spython(self, recording, uninstall_python_dependency):
         assert has_spython() is False
 
     @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True)
+    @pytest.mark.skipif(not has_docker(), reason="docker required for this test.")
     def test_has_docker_python(self, recording, uninstall_python_dependency):
         """
         Test the `has_docker_python()` function, see class docstring and
@@ -107,8 +113,7 @@ def test_has_docker_python(self, recording, uninstall_python_dependency):
         """
         assert has_docker_python() is False
 
-    @pytest.mark.parametrize("dependency", ["singularity", "spython"])
-    def test_has_singularity_and_spython(self, recording, monkeypatch, dependency):
+    def test_no_singularity_error_raised(self, recording, monkeypatch):
         """
         When running a sorting, if singularity dependencies (singularity
         itself or the `spython` package) are not installed, an error is raised.
         Because it is hard to actually uninstall these dependencies, the
         `has_` functions that let `run_sorter` know if the dependencies
         are installed are monkeypatched. This is done so at runtime these always
         return False. Then, test the expected error is raised when the dependency
         is not found.
         """
-        test_func = f"has_{dependency}"
+        monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_false)
 
-        monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false)
         with pytest.raises(RuntimeError) as e:
             run_sorter("kilosort2_5", recording, singularity_image=True)
 
-        if dependency == "spython":
-            assert "The python `spython` package must be installed" in str(e)
-        else:
-            assert "Singularity is not installed." in str(e)
+        assert "Singularity is not installed." in str(e)
 
-    @pytest.mark.parametrize("dependency", ["docker", "docker_python"])
-    def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency):
+    def test_no_spython_error_raised(self, recording, monkeypatch):
         """
-        See `test_has_singularity_and_spython()` for details. This test
-        is almost identical, but with some key changes for Docker.
+        See `test_no_singularity_error_raised()`.
""" - test_func = f"has_{dependency}" + # make sure singularity test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_spython", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + assert "The python `spython` package must be installed" in str(e) - monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + def test_no_docker_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + assert "Docker is not installed." in str(e) + + def test_as_no_docker_python_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + # make sure docker test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker_python", _monkeypatch_return_false) with pytest.raises(RuntimeError) as e: run_sorter("kilosort2_5", recording, docker_image=True) - if dependency == "docker_python": - assert "The python `docker` package must be installed" in str(e) - else: - assert "Docker is not installed." in str(e) + assert "The python `docker` package must be installed" in str(e) From 0c0b1f908d8e356b9a58cacd4524ace871ff93b3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:43:10 +0100 Subject: [PATCH 09/29] Remove unecessary skips. --- .../sorters/tests/test_runsorter_dependency_checks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index a248033089..741fe4ae0e 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -95,7 +95,6 @@ def recording(self): return recording @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") - @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.") @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) def test_has_spython(self, recording, uninstall_python_dependency): """ @@ -105,7 +104,6 @@ def test_has_spython(self, recording, uninstall_python_dependency): assert has_spython() is False @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) - @pytest.mark.skipif(not has_docker(), reason="docker required for this test.") def test_has_docker_python(self, recording, uninstall_python_dependency): """ Test the `has_docker_python()` function, see class docstring and From 1be1dbd39a339ff56c0803ff7a59e5650d95b781 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 13 Jun 2024 09:04:04 +0100 Subject: [PATCH 10/29] Update docstrings. 
---
 .../sorters/tests/test_runsorter_dependency_checks.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index 741fe4ae0e..c4beaba072 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -20,14 +20,18 @@ def _monkeypatch_return_false():
 
 
 def _monkeypatch_return_true():
+    """
+    Monkeypatch for some `has_` functions to
+    return `True` so functions that are later in the
+    `runsorter` code can be checked.
+    """
     return True
 
 
 class TestRunersorterDependencyChecks:
     """
-    This class performs tests to check whether expected
-    dependency checks prior to sorting are run. The
-    run_sorter function should raise an error if:
+    This class tests whether expected dependency checks prior to sorting are run.
+    The run_sorter function should raise an error if:
     - singularity is not installed
     - spython is not installed (python package)
     - docker is not installed
From 00663080b03f7933d37ba4ff2ee32e3402aa200e Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Thu, 13 Jun 2024 09:10:30 +0100
Subject: [PATCH 11/29] Swap return bool to match function name.

---
 src/spikeinterface/sorters/runsorter.py  | 2 +-
 src/spikeinterface/sorters/utils/misc.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index f9994dd38d..80608f8973 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -494,7 +494,7 @@ def run_sorter_container(
         assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available"
         extra_kwargs["container_requires_gpu"] = True
 
-        if platform.system() == "Linux" and has_docker_nvidia_installed():
+        if platform.system() == "Linux" and not has_docker_nvidia_installed():
             warn(
                 f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. "
                 f"This may result in an error being raised during sorting. Try "
                "installing `nvidia-container-toolkit`, including setting the "
                "configuration steps, if running into errors."
diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py
index 66744fbab1..1e01b9c052 100644
--- a/src/spikeinterface/sorters/utils/misc.py
+++ b/src/spikeinterface/sorters/utils/misc.py
@@ -119,7 +119,7 @@ def has_docker_nvidia_installed():
     has_dep = []
     for dep in all_dependencies:
         has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0)
-    return not any(has_dep)
+    return any(has_dep)
From 9664f69c4bcdd24e20584f601bcbd6a9ae79e174 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 19 Jun 2024 11:49:32 +0200
Subject: [PATCH 12/29] Apply suggestions from code review

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 .../sorters/tests/test_runsorter_dependency_checks.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index c4beaba072..83d6ec3161 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -13,7 +13,7 @@
 
 def _monkeypatch_return_false():
     """
-    A function to monkeypatch the `has_` functions, 
+    A function to monkeypatch the `has_` functions,
     ensuring they always return `False` at runtime.
     """
     return False
@@ -61,12 +61,12 @@ class TestRunersorterDependencyChecks:
     @pytest.fixture(scope="function")
     def uninstall_python_dependency(self, request):
         """
-        This python fixture mocks python modules not been importable
+        This python fixture mocks python modules not being importable
         by setting the relevant `sys.modules` dict entry to `None`.
-        It uses `yeild` so that the function can tear-down the test
+        It uses `yield` so that the function can tear-down the test
         (even if it failed) and replace the patched `sys.module` entry.
 
-        This function uses an `indirect` parameterisation, meaning the
+        This function uses an `indirect` parameterization, meaning the
         `request.param` is passed to the fixture at the start of the
         test function. This is used to reuse code for nearly identical
         `spython` and `docker` python dependency tests.
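As the corrected docstring explains, `indirect=True` routes each parametrized
value through the fixture via `request.param` before the test body runs. A
compact standalone sketch of the mechanism (not part of the diff above):

    import pytest

    @pytest.fixture
    def dep_name(request):
        # request.param carries the value from the parametrize mark
        yield request.param

    @pytest.mark.parametrize("dep_name", ["spython", "docker"], indirect=True)
    def test_dependency_name(dep_name):
        assert dep_name in ("spython", "docker")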
From 543cc8f2a67719e4ae8b5b64a198a6c7256406e4 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 19 Jun 2024 18:12:31 +0100 Subject: [PATCH 13/29] Add apptainer case to 'has_singularity()' Co-authored-by: Alessio Buccino --- src/spikeinterface/sorters/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 1e01b9c052..82480ffe0a 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,7 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 + return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 def has_docker_nvidia_installed(): From dceb08070af9954b25c99c82ed2df314ef924aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:12:51 +0000 Subject: [PATCH 14/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/utils/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 82480ffe0a..9c8c3bba89 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,10 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 + return ( + _run_subprocess_silently("singularity --version").returncode == 0 + or _run_subprocess_silently("apptainer --version").returncode == 0 + ) def has_docker_nvidia_installed(): From 617649569e147f8a530d6cfd0c0637857481e367 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 13:14:37 -0600 Subject: [PATCH 15/29] improve error log to json in run_sorter --- src/spikeinterface/sorters/basesorter.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8c52626703..799444ddbd 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -262,7 +262,12 @@ def run_from_folder(cls, output_folder, raise_error, verbose): has_error = True run_time = None log["error"] = True - log["error_trace"] = traceback.format_exc() + error_log_to_display = traceback.format_exc() + trace_lines = error_log_to_display.strip().split("\n") + error_to_json = ["Traceback (most recent call last):"] + [ + f" {line}" if not line.startswith(" ") else line for line in trace_lines[1:] + ] + log["error_trace"] = error_to_json log["error"] = has_error log["run_time"] = run_time @@ -290,7 +295,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): if has_error and raise_error: raise SpikeSortingError( - f"Spike sorting error trace:\n{log['error_trace']}\n" + f"Spike sorting error trace:\n{error_log_to_display}\n" f"Spike sorting failed. You can inspect the runtime trace in {output_folder}/spikeinterface_log.json." 
) From a166e5a3d419c49aa6afc69f0e2f98ea7eb9d0c3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 15:33:51 -0600 Subject: [PATCH 16/29] add recording iterator --- src/spikeinterface/core/core_tools.py | 58 +++++++++++++++++-- src/spikeinterface/sorters/container_tools.py | 11 +--- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index f3d8b3df7f..3fe4939524 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from pathlib import Path, WindowsPath -from typing import Union +from typing import Union, Generator import os import sys import datetime @@ -8,6 +8,7 @@ from copy import deepcopy import importlib from math import prod +from collections import namedtuple import numpy as np @@ -183,6 +184,50 @@ def is_dict_extractor(d: dict) -> bool: return is_extractor +recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) + + +def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: + """ + Iterator for recursive traversal of a dictionary. + This function explores the dictionary recursively and yields the path to each value along with the value itself. + + By path here we mean the keys that lead to the value in the dictionary: + e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b'). + + See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure. + + Parameters + ---------- + extractor_dict : dict + Input dictionary + + Yields + ------ + recording_dict_element + Named tuple containing the value, the name, and the access_path to the value in the dictionary. + + """ + + def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): + if isinstance(dict_list_or_value, dict): + for k, v in dict_list_or_value.items(): + yield from _recording_dict_iterator(v, access_path + (k,), name=k) + elif isinstance(dict_list_or_value, list): + for i, v in enumerate(dict_list_or_value): + yield from _recording_dict_iterator( + v, access_path + (i,), name=name + ) # Propagate name of list to children + else: + yield recording_dict_element( + value=dict_list_or_value, + name=name, + access_path=access_path, + ) + + yield from _recording_dict_iterator(extractor_dict) + + def recursive_path_modifier(d, func, target="path", copy=True) -> dict: """ Generic function for recursive modification of paths in an extractor dict. 
@@ -250,15 +295,16 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d): +def _get_paths_list(d: dict) -> list[str | Path]: # this explore a dict and get all paths flatten in a list # the trick is to use a closure func called by recursive_path_modifier() - path_list = [] - def append_to_path(p): - path_list.append(p) + element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)] + + # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks + # path_list = [p for p in path_list if Path(p).exists()] - recursive_path_modifier(d, append_to_path, target="path", copy=True) return path_list diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 60eb080ae5..8e03090eaf 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -9,19 +9,14 @@ # TODO move this inside functions -from spikeinterface.core.core_tools import recursive_path_modifier +from spikeinterface.core.core_tools import recursive_path_modifier, _get_paths_list def find_recording_folders(d): """Finds all recording folders 'paths' in a dict""" - folders_to_mount = [] - def append_parent_folder(p): - p = Path(p) - folders_to_mount.append(p.resolve().absolute().parent) - return p - - _ = recursive_path_modifier(d, append_parent_folder, target="path", copy=True) + path_list = _get_paths_list(d=d) + folders_to_mount = [Path(p).resolve().parent for p in path_list] try: # this will fail if on different drives (Windows) base_folders_to_mount = [Path(os.path.commonpath(folders_to_mount))] From 27a7c9a96c2e8f008109c99d8dd90ac52ac5fd3e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 16:58:39 -0600 Subject: [PATCH 17/29] add and fix tests --- src/spikeinterface/core/core_tools.py | 83 ++++++++-- .../core/tests/test_core_tools.py | 153 ++++++++++++------ 2 files changed, 170 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 3fe4939524..9e90b56c8d 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -187,7 +187,7 @@ def is_dict_extractor(d: dict) -> bool: recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) -def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. 
@@ -209,13 +209,13 @@ def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el """ - def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): + def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): if isinstance(dict_list_or_value, dict): for k, v in dict_list_or_value.items(): - yield from _recording_dict_iterator(v, access_path + (k,), name=k) + yield from _extractor_dict_iterator(v, access_path + (k,), name=k) elif isinstance(dict_list_or_value, list): for i, v in enumerate(dict_list_or_value): - yield from _recording_dict_iterator( + yield from _extractor_dict_iterator( v, access_path + (i,), name=name ) # Propagate name of list to children else: @@ -225,7 +225,32 @@ def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): access_path=access_path, ) - yield from _recording_dict_iterator(extractor_dict) + yield from _extractor_dict_iterator(extractor_dict) + + +def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value): + """ + In place modification of a value in a nested dictionary given its access path. + + Parameters + ---------- + extractor_dict : dict + The dictionary to modify + access_path : tuple + The path to the value in the dictionary + new_value : object + The new value to set + + Returns + ------- + dict + The modified dictionary + """ + + current = extractor_dict + for key in access_path[:-1]: + current = current[key] + current[access_path[-1]] = new_value def recursive_path_modifier(d, func, target="path", copy=True) -> dict: @@ -295,12 +320,13 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d: dict) -> list[str | Path]: - # this explore a dict and get all paths flatten in a list - # the trick is to use a closure func called by recursive_path_modifier() +# This is the current definition that an element in a recording_dict is a path +# This is shared across a couple of definition so it is here for DNRY +element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + - element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) - path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)] +def _get_paths_list(d: dict) -> list[str | Path]: + path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)] # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks # path_list = [p for p in path_list if Path(p).exists()] @@ -364,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool: return len(not_possible) == 0 -def make_paths_relative(input_dict, relative_folder) -> dict: +def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: """ Recursively transform a dict describing an BaseExtractor to make every path relative to a folder. @@ -380,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict: output_dict: dict A copy of the input dict with modified paths. 
""" + relative_folder = Path(relative_folder).resolve().absolute() - func = lambda p: _relative_to(p, relative_folder) - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + # Only paths that exist are made relative + path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()] + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + new_value = _relative_to(element.value, relative_folder) + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict @@ -405,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder): base_folder = Path(base_folder) # use as_posix instead of str to make the path unix like even on window func = lambda p: (base_folder / p).resolve().absolute().as_posix() - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + output_dict = deepcopy(input_dict) + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + absolute_path = (base_folder / element.value).resolve() + if Path(absolute_path).exists(): + new_value = absolute_path.as_posix() # Not so sure about this, Sam + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict def recursive_key_finder(d, key): # Find all values for a key on a dictionary, even if nested + # TODO refactor to use extractor_dict_iterator + for k, v in d.items(): if isinstance(v, dict): yield from recursive_key_finder(v, key) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 8e00dcb779..043e0cabf3 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -51,14 +51,9 @@ def test_path_utils_functions(): assert d2["kwargs"]["path"].startswith("/yop") assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") - d3 = make_paths_relative(d, Path("/yep")) - assert d3["kwargs"]["path"] == "sub/path1" - assert d3["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - - d4 = make_paths_absolute(d3, "/yop") - assert d4["kwargs"]["path"].startswith("/yop") - assert d4["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_relative_path_on_windows(): if platform.system() == "Windows": # test for windows Path d = { @@ -74,57 +69,111 @@ def test_path_utils_functions(): } } - d2 = make_paths_relative(d, "c:\\yep") - # the str be must unix like path even on windows for more portability - assert d2["kwargs"]["path"] == "sub/path1" - assert d2["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - # same drive assert check_paths_relative(d, r"c:\yep") # not the same drive assert not check_paths_relative(d, r"d:\yep") - d = { - "kwargs": { - "path": r"\\host\share\yep\sub\path1", - } - } - # UNC cannot be relative to d: drive - assert not check_paths_relative(d, r"d:\yep") - # UNC can be relative to the same UNC - assert check_paths_relative(d, r"\\host\share") - - def test_convert_string_to_bytes(): - # Test SI prefixes - assert convert_string_to_bytes("1k") == 1000 - assert convert_string_to_bytes("1M") == 1000000 - assert 
convert_string_to_bytes("1G") == 1000000000 - assert convert_string_to_bytes("1T") == 1000000000000 - assert convert_string_to_bytes("1P") == 1000000000000000 - # Test IEC prefixes - assert convert_string_to_bytes("1Ki") == 1024 - assert convert_string_to_bytes("1Mi") == 1048576 - assert convert_string_to_bytes("1Gi") == 1073741824 - assert convert_string_to_bytes("1Ti") == 1099511627776 - assert convert_string_to_bytes("1Pi") == 1125899906842624 - # Test mixed values - assert convert_string_to_bytes("1.5k") == 1500 - assert convert_string_to_bytes("2.5M") == 2500000 - assert convert_string_to_bytes("0.5G") == 500000000 - assert convert_string_to_bytes("1.2T") == 1200000000000 - assert convert_string_to_bytes("1.5Pi") == 1688849860263936 - # Test zero values - assert convert_string_to_bytes("0k") == 0 - assert convert_string_to_bytes("0Ki") == 0 - # Test invalid inputs (should raise assertion error) - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Z") - assert str(e.value) == "Unknown suffix: Z" - - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Xi") - assert str(e.value) == "Unknown suffix: Xi" +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_universal_naming_convention(): + d = { + "kwargs": { + "path": r"\\host\share\yep\sub\path1", + } + } + # UNC cannot be relative to d: drive + assert not check_paths_relative(d, r"d:\yep") + + # UNC can be relative to the same UNC + assert check_paths_relative(d, r"\\host\share") + + +def test_make_paths_relative(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + # Create the objects in the path + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + extractor_dict = { + "kwargs": { + "path": str(path_1), # Note this is different in windows and posix + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": str(path_2)}, + }, + } + } + modified_extractor_dict = make_paths_relative(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"] == "sub/path1" + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_make_paths_absolute(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + + extractor_dict = { + "kwargs": { + "path": "sub/path1", + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "sub/path2"}, + }, + } + } + + modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_convert_string_to_bytes(): + # Test SI prefixes + assert convert_string_to_bytes("1k") == 1000 + assert convert_string_to_bytes("1M") == 1000000 + assert 
convert_string_to_bytes("1G") == 1000000000 + assert convert_string_to_bytes("1T") == 1000000000000 + assert convert_string_to_bytes("1P") == 1000000000000000 + # Test IEC prefixes + assert convert_string_to_bytes("1Ki") == 1024 + assert convert_string_to_bytes("1Mi") == 1048576 + assert convert_string_to_bytes("1Gi") == 1073741824 + assert convert_string_to_bytes("1Ti") == 1099511627776 + assert convert_string_to_bytes("1Pi") == 1125899906842624 + # Test mixed values + assert convert_string_to_bytes("1.5k") == 1500 + assert convert_string_to_bytes("2.5M") == 2500000 + assert convert_string_to_bytes("0.5G") == 500000000 + assert convert_string_to_bytes("1.2T") == 1200000000000 + assert convert_string_to_bytes("1.5Pi") == 1688849860263936 + # Test zero values + assert convert_string_to_bytes("0k") == 0 + assert convert_string_to_bytes("0Ki") == 0 + # Test invalid inputs (should raise assertion error) + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Z") + assert str(e.value) == "Unknown suffix: Z" + + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Xi") + assert str(e.value) == "Unknown suffix: Xi" def test_normal_pdf() -> None: From b3b85b2fe5670217d80c4adec1a751d1e1d5d024 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 17:21:45 -0600 Subject: [PATCH 18/29] naming --- src/spikeinterface/core/core_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 9e90b56c8d..d5480d6f00 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -228,7 +228,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): yield from _extractor_dict_iterator(extractor_dict) -def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value): +def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value): """ In place modification of a value in a nested dictionary given its access path. 
@@ -416,7 +416,7 @@ def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: output_dict = deepcopy(input_dict) for element in path_elements_in_dict: new_value = _relative_to(element.value, relative_folder) - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, @@ -453,7 +453,7 @@ def make_paths_absolute(input_dict, base_folder): absolute_path = (base_folder / element.value).resolve() if Path(absolute_path).exists(): new_value = absolute_path.as_posix() # Not so sure about this, Sam - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, From d794c8220e9e2ed2431636e53aee9b7b8d6b998b Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Thu, 27 Jun 2024 00:39:58 -0600 Subject: [PATCH 19/29] windows test remove inner conditional --- .../core/tests/test_core_tools.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 043e0cabf3..ed13bd46fd 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,25 +54,24 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - if platform.system() == "Windows": - # test for windows Path - d = { - "kwargs": { - "path": r"c:\yep\sub\path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": r"c:\yep\sub\path2"}, - }, - } + + d = { + "kwargs": { + "path": r"c:\yep\sub\path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": r"c:\yep\sub\path2"}, + }, } + } - # same drive - assert check_paths_relative(d, r"c:\yep") - # not the same drive - assert not check_paths_relative(d, r"d:\yep") + # same drive + assert check_paths_relative(d, r"c:\yep") + # not the same drive + assert not check_paths_relative(d, r"d:\yep") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") @@ -139,8 +138,8 @@ def test_make_paths_absolute(tmp_path): } modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) - assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) - assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path.as_posix())) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path.as_posix())) assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" From c1e4eee519c289899f2650d98e6210d631ae42f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:41:00 +0000 Subject: [PATCH 20/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_core_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index ed13bd46fd..724517577c 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ 
b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,7 +54,7 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - + d = { "kwargs": { "path": r"c:\yep\sub\path1", From d1d65f6ca6338ac2dd8d6f9c99ee657f0db76d21 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 11:58:23 +0100 Subject: [PATCH 21/29] estimate_sparsity arg ordering --- src/spikeinterface/core/sortinganalyzer.py | 2 +- src/spikeinterface/core/sparsity.py | 6 +++--- src/spikeinterface/core/tests/test_sparsity.py | 4 ++-- .../postprocessing/tests/common_extension_tests.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 53e060262b..62b7f9e7c0 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -127,7 +127,7 @@ def create_sorting_analyzer( recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" elif sparse: - sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) else: sparsity = None diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index cefd7bd950..1cd7822f99 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -539,8 +539,8 @@ def compute_sparsity( def estimate_sparsity( - recording: BaseRecording, sorting: BaseSorting, + recording: BaseRecording, num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, @@ -563,10 +563,10 @@ def estimate_sparsity( Parameters ---------- - recording: BaseRecording - The recording sorting: BaseSorting The sorting + recording: BaseRecording + The recording num_spikes_for_sparsity: int, default: 100 How many spikes per units to compute the sparsity ms_before: float, default: 1.0 diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 98d033d8ea..a192d90502 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -166,8 +166,8 @@ def test_estimate_sparsity(): # small radius should give a very sparse = one channel per unit sparsity = estimate_sparsity( - recording, sorting, + recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, @@ -182,8 +182,8 @@ def test_estimate_sparsity(): # best_channel : the mask should exactly 3 channels per units sparsity = estimate_sparsity( - recording, sorting, + recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bf462a9466..8c46fa5e24 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -79,7 +79,7 @@ class AnalyzerExtensionCommonTestSuite: def setUpClass(cls): cls.recording, cls.sorting = get_dataset() # sparsity is computed once for all cases to save processing time and force a small radius - cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20) + cls.sparsity = estimate_sparsity(cls.sorting, cls.recording, method="radius", radius_um=20) @property def extension_name(self): From 
2cc719986e5d6fceb9ea828206d7cf1d9a3fef9a Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 27 Jun 2024 08:11:55 -0600
Subject: [PATCH 22/29] @alejo91 suggestion

---
 src/spikeinterface/core/core_tools.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py
index d5480d6f00..066ab58d8c 100644
--- a/src/spikeinterface/core/core_tools.py
+++ b/src/spikeinterface/core/core_tools.py
@@ -184,10 +184,10 @@ def is_dict_extractor(d: dict) -> bool:
     return is_extractor


-recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"])
+extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])


-def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]:
+def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
     """
     Iterator for recursive traversal of a dictionary.
     This function explores the dictionary recursively and yields the path to each value along with the value itself.
@@ -204,7 +204,7 @@ def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el

     Yields
     ------
-    recording_dict_element
+    extractor_dict_element
         Named tuple containing the value, the name, and the access_path to the value in the dictionary.

     """
@@ -219,7 +219,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
                     v, access_path + (i,), name=name
                 )  # Propagate name of list to children
         else:
-            yield recording_dict_element(
+            yield extractor_dict_element(
                 value=dict_list_or_value,
                 name=name,
                 access_path=access_path,
@@ -320,7 +320,7 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
         raise ValueError(f"{k} key for path must be str or list[str]")


-# This is the current definition that an element in a recording_dict is a path
+# This is the current definition that an element in an extractor_dict is a path
 # This is shared across a couple of definitions so it is here for DNRY
 element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))

From 61060781eef87597461241aec077aac27baff69b Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 15:15:14 +0100
Subject: [PATCH 23/29] SpikeRetriever arg switch

---
 src/spikeinterface/core/node_pipeline.py      |  16 +--
 .../core/tests/test_node_pipeline.py          |   4 +-
 .../tests/test_train_manual_curation.py       | 120 ++++++++++++++++++
 .../postprocessing/amplitude_scalings.py      |   2 +-
 .../postprocessing/spike_amplitudes.py        |   2 +-
 .../postprocessing/spike_locations.py         |   2 +-
 6 files changed, 133 insertions(+), 13 deletions(-)
 create mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 1c0107d235..0722ede23f 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
     * compute_spike_amplitudes()
     * compute_principal_components()

+    sorting : BaseSorting
+        The sorting object.
     recording : BaseRecording
         The recording object.
-    sorting: BaseSorting
-        The sorting object.
-    channel_from_template: bool, default: True
+    channel_from_template : bool, default: True
         If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
        If False, the max channel is computed for each spike given a radius around the template max channel.
-    extremum_channel_inds: dict of int | None, default: None
+    extremum_channel_inds : dict of int | None, default: None
         The extremum channel index dict given from template.
-    radius_um: float, default: 50
+    radius_um : float, default: 50
         The radius to find the real max channel.
         Used only when channel_from_template=False
-    peak_sign: "neg" | "pos", default: "neg"
+    peak_sign : "neg" | "pos", default: "neg"
         Peak sign to find the max channel.
         Used only when channel_from_template=False
-    include_spikes_in_margin: bool, default False
+    include_spikes_in_margin : bool, default False
         If True then spikes in margin are added and an extra field in dtype is added
     """

     def __init__(
         self,
-        recording,
         sorting,
+        recording,
         channel_from_template=True,
         extremum_channel_inds=None,
         radius_um=50,
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index 03acc9fed1..8d788acbad 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
     spike_retriever_T = SpikeRetriever(
-        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
+        sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
     )
     # channel index is per spike
     spike_retriever_S = SpikeRetriever(
-        recording,
         sorting,
+        recording,
         channel_from_template=False,
         extremum_channel_inds=extremum_channel_inds,
         radius_um=50,
diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py
new file mode 100644
index 0000000000..f0f9ff4d75
--- /dev/null
+++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -0,0 +1,120 @@
+import pytest
+import pandas as pd
+import os
+import shutil
+
+from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model
+
+# Sample data for testing
+data = {
+    'num_spikes': [1, 2, 3, 4, 5, 6],
+    'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4],
+    'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06],
+    'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
+    'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'firing_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+ 'num_positive_peaks': [1, 2, 3, 4, 5, 6], + 'num_negative_peaks': [1, 2, 3, 4, 5, 6], + 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'is_noise': [0, 1, 0, 1, 0, 1], + 'is_sua': [1, 0, 1, 0, 1, 0], + 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] +} + +df = pd.DataFrame(data) + +# Test initialization +def test_initialization(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + assert trainer.output_folder == '/tmp' + assert trainer.curator_column == 'num_spikes' + assert trainer.imputation_strategies is not None + assert trainer.scaling_techniques is not None + +# Test load_data_file +def test_load_data_file(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + df.to_csv('/tmp/test.csv', index=False) + trainer.load_data_file('/tmp/test.csv') + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Test process_test_data_for_classification +def test_process_test_data_for_classification(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + trainer.testing_metrics = {0: df} + trainer.process_test_data_for_classification() + assert trainer.noise_test is not None + assert trainer.sua_mua_test is not None + +# Test apply_scaling_imputation +def test_apply_scaling_imputation(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + y_train = df['is_noise'] + y_val = df['is_noise'] + result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) + assert result is not None + +# Test get_classifier_search_space +def test_get_classifier_search_space(): + from sklearn.linear_model import LogisticRegression + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + model, param_space = trainer.get_classifier_search_space(LogisticRegression) + assert model is not None + assert param_space is not None + +# Test Objective Enum +def test_objective_enum(): + assert Objective.Noise == Objective(1) + assert Objective.SUA == Objective(2) + assert str(Objective.Noise) == "Objective.Noise" + assert str(Objective.SUA) == "Objective.SUA" + +# Test train_model function +def test_train_model(monkeypatch): + output_folder = '/tmp/output' + os.makedirs(output_folder, exist_ok=True) + df.to_csv('/tmp/metrics.csv', index=False) + + def mock_load_and_preprocess_full(self, path): + self.testing_metrics = {0: df} + self.process_test_data_for_classification() + + monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) + + trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') + assert trainer is not None + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Clean up temporary files +@pytest.fixture(scope="module", autouse=True) +def cleanup(request): + def remove_tmp(): + shutil.rmtree('/tmp', ignore_errors=True) + request.addfinalizer(remove_tmp) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 2e544d086b..8ff9cc5666 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ 
b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -170,8 +170,8 @@ def _get_pipeline_nodes(self):
         sparsity_mask = sparsity.mask

         spike_retriever_node = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
             include_spikes_in_margin=True,
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index aebfd1fd78..72cbcb651f 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -95,7 +95,7 @@ def _get_pipeline_nodes(self):
         peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)

         spike_retriever_node = SpikeRetriever(
-            recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
+            sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
         )
         spike_amplitudes_node = SpikeAmplitudeNode(
             recording,
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 52a91342b6..23301292e5 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -103,8 +103,8 @@ def _get_pipeline_nodes(self):
         )

         retriever = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
         )

From 722c313382b6ac225a2c9119c676bc1bcab6e480 Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 15:17:43 +0100
Subject: [PATCH 24/29] has_exceeding_spikes arg switch

---
 src/spikeinterface/core/basesorting.py              | 2 +-
 src/spikeinterface/core/frameslicesorting.py        | 2 +-
 src/spikeinterface/core/waveform_tools.py           | 2 +-
 src/spikeinterface/curation/remove_excess_spikes.py | 2 +-
 .../curation/tests/test_remove_excess_spikes.py     | 4 ++--
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index fd68df9dda..d9a567dedf 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
             self.get_num_segments() == recording.get_num_segments()
         ), "The recording has a different number of segments than the sorting!"
         if check_spike_frames:
-            if has_exceeding_spikes(recording, self):
+            if has_exceeding_spikes(self, recording):
                 warnings.warn(
                     "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py
index ffd8af5fd8..f3ec449ab0 100644
--- a/src/spikeinterface/core/frameslicesorting.py
+++ b/src/spikeinterface/core/frameslicesorting.py
@@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
             assert (
                 start_frame <= parent_n_samples
             ), "`start_frame` should be smaller than the sorting's total number of samples."
-        if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
+        if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
             raise ValueError(
                 "The sorting object has spikes whose times go beyond the recording duration. "
                 "This could indicate a bug in the sorter. "
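For reviewers, a short sketch of the sorting-first call order this commit settles on; `recording` and `sorting` stand for any paired BaseRecording/BaseSorting and are assumed here, not part of the diff:

    from spikeinterface.core.waveform_tools import has_exceeding_spikes
    from spikeinterface.curation import remove_excess_spikes

    # Both helpers now take the sorting before the recording.
    if has_exceeding_spikes(sorting, recording):
        sorting = remove_excess_spikes(sorting, recording)
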
" diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index befc49d034..4543074872 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None return waveforms_by_units -def has_exceeding_spikes(recording, sorting) -> bool: +def has_exceeding_spikes(sorting, recording) -> bool: """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 0ae7a59fc6..d1d6b7f3cb 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording): sorting_without_excess_spikes : Sorting The sorting without any excess spikes. """ - if has_exceeding_spikes(recording=recording, sorting=sorting): + if has_exceeding_spikes(sorting=sorting, recording=recording): return RemoveExcessSpikesSorting(sorting=sorting, recording=recording) else: return sorting diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index 69edbaba4c..141cc4c34e 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -39,10 +39,10 @@ def test_remove_excess_spikes(): labels.append(labels_segment) sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency) - assert has_exceeding_spikes(recording, sorting) + assert has_exceeding_spikes(sorting, recording) sorting_corrected = remove_excess_spikes(sorting, recording) - assert not has_exceeding_spikes(recording, sorting_corrected) + assert not has_exceeding_spikes(sorting_corrected, recording) for u in sorting.unit_ids: for segment_index in range(sorting.get_num_segments()): From d0968c4c941e290488848d14c6881c7a2cdf9c8c Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:19:24 +0100 Subject: [PATCH 25/29] removed accidental commit --- .../tests/test_train_manual_curation.py | 120 ------------------ 1 file changed, 120 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py deleted file mode 100644 index f0f9ff4d75..0000000000 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import pandas as pd -import os -import shutil - -from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model - -# Sample data for testing -data = { - 'num_spikes': [1, 2, 3, 4, 5, 6], - 'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], - 'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06], - 'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], - 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'firing_range': [0.1, 0.2, 0.3, 
0.4, 0.5, 0.6], - 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'num_positive_peaks': [1, 2, 3, 4, 5, 6], - 'num_negative_peaks': [1, 2, 3, 4, 5, 6], - 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'is_noise': [0, 1, 0, 1, 0, 1], - 'is_sua': [1, 0, 1, 0, 1, 0], - 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] -} - -df = pd.DataFrame(data) - -# Test initialization -def test_initialization(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - assert trainer.output_folder == '/tmp' - assert trainer.curator_column == 'num_spikes' - assert trainer.imputation_strategies is not None - assert trainer.scaling_techniques is not None - -# Test load_data_file -def test_load_data_file(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - df.to_csv('/tmp/test.csv', index=False) - trainer.load_data_file('/tmp/test.csv') - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Test process_test_data_for_classification -def test_process_test_data_for_classification(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - trainer.testing_metrics = {0: df} - trainer.process_test_data_for_classification() - assert trainer.noise_test is not None - assert trainer.sua_mua_test is not None - -# Test apply_scaling_imputation -def test_apply_scaling_imputation(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - y_train = df['is_noise'] - y_val = df['is_noise'] - result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) - assert result is not None - -# Test get_classifier_search_space -def test_get_classifier_search_space(): - from sklearn.linear_model import LogisticRegression - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - model, param_space = trainer.get_classifier_search_space(LogisticRegression) - assert model is not None - assert param_space is not None - -# Test Objective Enum -def test_objective_enum(): - assert Objective.Noise == Objective(1) - assert Objective.SUA == Objective(2) - assert str(Objective.Noise) == "Objective.Noise" - assert str(Objective.SUA) == "Objective.SUA" - -# Test train_model function -def test_train_model(monkeypatch): - output_folder = '/tmp/output' - os.makedirs(output_folder, exist_ok=True) - df.to_csv('/tmp/metrics.csv', index=False) - - def mock_load_and_preprocess_full(self, path): - self.testing_metrics = {0: df} - self.process_test_data_for_classification() - - 
monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) - - trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') - assert trainer is not None - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Clean up temporary files -@pytest.fixture(scope="module", autouse=True) -def cleanup(request): - def remove_tmp(): - shutil.rmtree('/tmp', ignore_errors=True) - request.addfinalizer(remove_tmp) From f687c2c2fe9b70a970cfd39d6dd7b134c15e065f Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:20:32 +0100 Subject: [PATCH 26/29] docs --- src/spikeinterface/core/waveform_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 4543074872..98380e955f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -685,10 +685,10 @@ def has_exceeding_spikes(sorting, recording) -> bool: Parameters ---------- - recording : BaseRecording - The recording object sorting : BaseSorting The sorting object + recording : BaseRecording + The recording object Returns ------- From b8c8fa83ba8695545b420d135c92f5167d7d2de1 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:54:59 +0100 Subject: [PATCH 27/29] Missed one --- .../postprocessing/tests/common_extension_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bb2f5aaafd..52dbaf23d4 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -73,7 +73,7 @@ class instance is used for each. 
In this case, we have to set self.__class__.recording, self.__class__.sorting = get_dataset() self.__class__.sparsity = estimate_sparsity( - self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20 ) self.__class__.cache_folder = create_cache_folder From 3eee955a8da3989dda6cbd84b25c0eabc2222527 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 09:01:15 -0600 Subject: [PATCH 28/29] make test skipif --- .../core/tests/test_core_tools.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 724517577c..7153991543 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -31,25 +31,25 @@ def test_add_suffix(): assert str(file_path_with_suffix) == expected_path +@pytest.mark.skipif(platform.system() == "Windows", reason="Runs on posix only") def test_path_utils_functions(): - if platform.system() != "Windows": - # posix path - d = { - "kwargs": { - "path": "/yep/sub/path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": "/yep/sub/path2"}, - }, - } + # posix path + d = { + "kwargs": { + "path": "/yep/sub/path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "/yep/sub/path2"}, + }, } + } - d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) - assert d2["kwargs"]["path"].startswith("/yop") - assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") + d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) + assert d2["kwargs"]["path"].startswith("/yop") + assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") From d5ec1806bf41c27317f60e7c96cf71972400774b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:58:30 -0400 Subject: [PATCH 29/29] get rid of waveform term --- src/spikeinterface/widgets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index b94167d2b7..9566989d31 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -139,7 +139,7 @@ def check_extensions(sorting_analyzer, extensions): if not sorting_analyzer.has_extension(extension): raise_error = True error_msg += ( - f"The {extension} waveform extension is required for this widget. " + f"The {extension} sorting analyzer extension is required for this widget. " f"Run the `sorting_analyzer.compute('{extension}', ...)` to compute it.\n" ) if raise_error:
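To close out the series, a hedged sketch of the workflow the reworded message points users to; the "correlograms" extension name is illustrative, and `sorting`/`recording` are assumed fixtures rather than part of the patch:

    from spikeinterface.core import create_sorting_analyzer

    # Sorting-first argument order, matching the convention adopted above.
    sorting_analyzer = create_sorting_analyzer(sorting, recording)

    # Compute the required sorting analyzer extension before using the widget,
    # as the new error text instructs.
    sorting_analyzer.compute("correlograms")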