diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 0162e6c943..1fcd719325 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,17 +1,16 @@ from pathlib import Path -from tempfile import tempdir from packaging.version import parse +from tempfile import TemporaryDirectory + +import numpy as np + from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording from ..basesorter import BaseSorter -from spikeinterface.core.old_api_utils import NewToOldRecording -from spikeinterface.core import load_extractor -from spikeinterface.extractors import NpzSortingExtractor, NumpySorting - -from packaging.version import parse +from spikeinterface.extractors import NpzSortingExtractor class Mountainsort5Sorter(BaseSorter): @@ -42,6 +41,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": 6000, "filter": True, "whiten": True, # Important to do whitening + "temporary_base_dir": None, + "n_jobs_for_preprocessing": -1, } _params_description = { @@ -65,6 +66,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": "Low-pass filter cutoff frequency", "filter": "Enable or disable filter", "whiten": "Enable or disable whitening", + "temporary_base_dir": "Temporary directory base directory for storing cached recording", + "n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording", } sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030" @@ -84,8 +87,10 @@ def is_installed(cls): HAVE_MS5 = True except ImportError: HAVE_MS5 = False + mountainsort5 = None if HAVE_MS5: + assert mountainsort5 vv = parse(mountainsort5.__version__) if vv < parse("0.3"): print( @@ -114,8 +119,15 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 + from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + if recording is None: + raise Exception("Unable to load recording from folder.") + if not isinstance(recording, BaseRecording): + raise TypeError( + f"Unexpected: recording extracted from folder is not a BaseRecording, but is of type: {type(recording)}" + ) # alias to params p = params @@ -124,7 +136,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if p["filter"] and p["freq_min"] is not None and p["freq_max"] is not None: if verbose: print("filtering") - recording = bandpass_filter(recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"]) + # important to use dtype=float here + recording = bandpass_filter( + recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"], dtype=np.float32 + ) # Whiten if p["whiten"]: @@ -169,13 +184,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - scheme = p["scheme"] - if scheme == "1": - sorting = ms5.sorting_scheme1(recording=recording, sorting_parameters=scheme1_sorting_parameters) - elif p["scheme"] == "2": - sorting = ms5.sorting_scheme2(recording=recording, sorting_parameters=scheme2_sorting_parameters) - elif p["scheme"] == "3": - sorting = ms5.sorting_scheme3(recording=recording, sorting_parameters=scheme3_sorting_parameters) + with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir: + # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) + recording_cached = create_cached_recording( + recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"] + ) + + scheme = p["scheme"] + if scheme == "1": + sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters) + elif p["scheme"] == "2": + sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters) + elif p["scheme"] == "3": + sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters) + else: + raise ValueError(f"Invalid scheme: {scheme} given. scheme must be one of '1', '2' or '3'") NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz"))