From 7364fb7c280c35f110b7995cafe87b3086c2e2d3 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 17 Nov 2023 13:53:28 -0500 Subject: [PATCH 1/8] update mountainsort5 sorter --- .../sorters/external/mountainsort5.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 0162e6c943..0957a2321a 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,17 +1,12 @@ from pathlib import Path -from tempfile import tempdir from packaging.version import parse from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording from ..basesorter import BaseSorter -from spikeinterface.core.old_api_utils import NewToOldRecording -from spikeinterface.core import load_extractor -from spikeinterface.extractors import NpzSortingExtractor, NumpySorting - -from packaging.version import parse +from spikeinterface.extractors import NpzSortingExtractor class Mountainsort5Sorter(BaseSorter): @@ -84,8 +79,10 @@ def is_installed(cls): HAVE_MS5 = True except ImportError: HAVE_MS5 = False + mountainsort5 = None if HAVE_MS5: + assert mountainsort5 vv = parse(mountainsort5.__version__) if vv < parse("0.3"): print( @@ -114,6 +111,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 + from mountainsort5.util import TemporaryDirectory, create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -124,7 +122,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if p["filter"] and p["freq_min"] is not None and p["freq_max"] is not None: if verbose: print("filtering") - recording = bandpass_filter(recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"]) + # important to use dtype=float here + recording = bandpass_filter( + recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"], dtype=float + ) # Whiten if p["whiten"]: @@ -169,13 +170,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - scheme = p["scheme"] - if scheme == "1": - sorting = ms5.sorting_scheme1(recording=recording, sorting_parameters=scheme1_sorting_parameters) - elif p["scheme"] == "2": - sorting = ms5.sorting_scheme2(recording=recording, sorting_parameters=scheme2_sorting_parameters) - elif p["scheme"] == "3": - sorting = ms5.sorting_scheme3(recording=recording, sorting_parameters=scheme3_sorting_parameters) + assert isinstance(recording, BaseRecording) + + with TemporaryDirectory() as tmpdir: + # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) + recording_cached = create_cached_recording(recording=recording, folder=tmpdir) + + scheme = p["scheme"] + if scheme == "1": + sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters) + elif p["scheme"] == "2": + sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters) + elif p["scheme"] == "3": + sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters) + else: + raise Exception(f"Invalid scheme: {scheme}") NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz")) From 18ecd53bf9350f484f4dfb732486b55e6aaefafd Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 17 Nov 2023 14:02:41 -0500 Subject: [PATCH 2/8] improve exception msg in mountainsort5 Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/mountainsort5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 0957a2321a..f8be972a22 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -184,7 +184,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): elif p["scheme"] == "3": sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters) else: - raise Exception(f"Invalid scheme: {scheme}") + raise ValueError(f"Invalid scheme: {scheme} given. scheme must be one of '1', '2' or '3'") NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz")) From 2390dc68cfc7652a13dd53da65a127f68a1d5c5e Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 17 Nov 2023 14:19:57 -0500 Subject: [PATCH 3/8] improve assertions in mountainsort5.py --- src/spikeinterface/sorters/external/mountainsort5.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 0957a2321a..929880a993 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -114,6 +114,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from mountainsort5.util import TemporaryDirectory, create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + if recording is None: + raise Exception("Unable to load recording from folder.") + if not isinstance(recording, BaseRecording): + raise Exception("Unexpected: recording extracted from folder is not a BaseRecording") # alias to params p = params @@ -170,8 +174,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - assert isinstance(recording, BaseRecording) - with TemporaryDirectory() as tmpdir: # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) recording_cached = create_cached_recording(recording=recording, folder=tmpdir) From 5c15a3ddaf1bb65fae486971f590de0c2e861b31 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 17 Nov 2023 14:38:03 -0500 Subject: [PATCH 4/8] improve error message in mountainsort5.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/mountainsort5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index f8f574fa89..2a0441f79c 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -117,7 +117,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if recording is None: raise Exception("Unable to load recording from folder.") if not isinstance(recording, BaseRecording): - raise Exception("Unexpected: recording extracted from folder is not a BaseRecording") + raise TypeError(f"Unexpected: recording extracted from folder is not a BaseRecording, but is of type: {type(recording)}") # alias to params p = params From d395dba4c7b2a442ccb21a1319e681470bc81810 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:38:20 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/external/mountainsort5.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 2a0441f79c..57d8f19cda 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -117,7 +117,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if recording is None: raise Exception("Unable to load recording from folder.") if not isinstance(recording, BaseRecording): - raise TypeError(f"Unexpected: recording extracted from folder is not a BaseRecording, but is of type: {type(recording)}") + raise TypeError( + f"Unexpected: recording extracted from folder is not a BaseRecording, but is of type: {type(recording)}" + ) # alias to params p = params From dc86474ec6f32702421ea91f32091f6386ff192a Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 17 Nov 2023 15:32:06 -0500 Subject: [PATCH 6/8] use np.float32 in mountainsort5.py --- src/spikeinterface/sorters/external/mountainsort5.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 57d8f19cda..2fa860c148 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,6 +1,8 @@ from pathlib import Path from packaging.version import parse +import numpy as np + from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording @@ -130,7 +132,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("filtering") # important to use dtype=float here recording = bandpass_filter( - recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"], dtype=float + recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"], dtype=np.float32 ) # Whiten From fd2258a9e50ad75320b91a039f0d6dc3e9fd2c73 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Mon, 20 Nov 2023 11:17:37 -0500 Subject: [PATCH 7/8] add params to mountainsort5.py --- .../sorters/external/mountainsort5.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 2fa860c148..cbd802b374 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,6 +1,8 @@ from pathlib import Path from packaging.version import parse +from tempfile import TemporaryDirectory + import numpy as np from spikeinterface.preprocessing import bandpass_filter, whiten @@ -39,6 +41,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": 6000, "filter": True, "whiten": True, # Important to do whitening + "temporary_base_dir": None, + "n_jobs_for_preprocessing": 1, } _params_description = { @@ -62,6 +66,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": "Low-pass filter cutoff frequency", "filter": "Enable or disable filter", "whiten": "Enable or disable whitening", + "temporary_base_dir": "Temporary directory base directory for storing cached recording", + "n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording", } sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030" @@ -113,7 +119,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - from mountainsort5.util import TemporaryDirectory, create_cached_recording + from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if recording is None: @@ -178,9 +184,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - with TemporaryDirectory() as tmpdir: + with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir: # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) - recording_cached = create_cached_recording(recording=recording, folder=tmpdir) + recording_cached = create_cached_recording( + recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"] + ) scheme = p["scheme"] if scheme == "1": From 171329ef0ee98c575912830207ddffc7111ecbf8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 19 Jan 2024 11:01:43 +0100 Subject: [PATCH 8/8] Update src/spikeinterface/sorters/external/mountainsort5.py --- src/spikeinterface/sorters/external/mountainsort5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index cbd802b374..1fcd719325 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -42,7 +42,7 @@ class Mountainsort5Sorter(BaseSorter): "filter": True, "whiten": True, # Important to do whitening "temporary_base_dir": None, - "n_jobs_for_preprocessing": 1, + "n_jobs_for_preprocessing": -1, } _params_description = {