Skip to content

Commit

Permalink
Merge pull request #2225 from magland/main
Browse files Browse the repository at this point in the history
A couple updates to mountainsort5 sorter
  • Loading branch information
alejoe91 authored Jan 22, 2024
2 parents 581d8d1 + 171329e commit fb3c9c9
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions src/spikeinterface/sorters/external/mountainsort5.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from pathlib import Path
from tempfile import tempdir
from packaging.version import parse

from tempfile import TemporaryDirectory

import numpy as np

from spikeinterface.preprocessing import bandpass_filter, whiten

from spikeinterface.core.baserecording import BaseRecording
from ..basesorter import BaseSorter
from spikeinterface.core.old_api_utils import NewToOldRecording
from spikeinterface.core import load_extractor

from spikeinterface.extractors import NpzSortingExtractor, NumpySorting

from packaging.version import parse
from spikeinterface.extractors import NpzSortingExtractor


class Mountainsort5Sorter(BaseSorter):
Expand Down Expand Up @@ -42,6 +41,8 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": 6000,
"filter": True,
"whiten": True, # Important to do whitening
"temporary_base_dir": None,
"n_jobs_for_preprocessing": -1,
}

_params_description = {
Expand All @@ -65,6 +66,8 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": "Low-pass filter cutoff frequency",
"filter": "Enable or disable filter",
"whiten": "Enable or disable whitening",
"temporary_base_dir": "Temporary directory base directory for storing cached recording",
"n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording",
}

sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030"
Expand All @@ -84,8 +87,10 @@ def is_installed(cls):
HAVE_MS5 = True
except ImportError:
HAVE_MS5 = False
mountainsort5 = None

if HAVE_MS5:
assert mountainsort5
vv = parse(mountainsort5.__version__)
if vv < parse("0.3"):
print(
Expand Down Expand Up @@ -114,8 +119,15 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
import mountainsort5 as ms5
from mountainsort5.util import create_cached_recording

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
if recording is None:
raise Exception("Unable to load recording from folder.")
if not isinstance(recording, BaseRecording):
raise TypeError(
f"Unexpected: recording extracted from folder is not a BaseRecording, but is of type: {type(recording)}"
)

# alias to params
p = params
Expand All @@ -124,7 +136,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if p["filter"] and p["freq_min"] is not None and p["freq_max"] is not None:
if verbose:
print("filtering")
recording = bandpass_filter(recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"])
# important to use dtype=float here
recording = bandpass_filter(
recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"], dtype=np.float32
)

# Whiten
if p["whiten"]:
Expand Down Expand Up @@ -169,13 +184,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"]
)

scheme = p["scheme"]
if scheme == "1":
sorting = ms5.sorting_scheme1(recording=recording, sorting_parameters=scheme1_sorting_parameters)
elif p["scheme"] == "2":
sorting = ms5.sorting_scheme2(recording=recording, sorting_parameters=scheme2_sorting_parameters)
elif p["scheme"] == "3":
sorting = ms5.sorting_scheme3(recording=recording, sorting_parameters=scheme3_sorting_parameters)
with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir:
# cache the recording to a temporary directory for efficient reading (so we don't have to re-filter)
recording_cached = create_cached_recording(
recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"]
)

scheme = p["scheme"]
if scheme == "1":
sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters)
elif p["scheme"] == "2":
sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters)
elif p["scheme"] == "3":
sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters)
else:
raise ValueError(f"Invalid scheme: {scheme} given. scheme must be one of '1', '2' or '3'")

NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz"))

Expand Down

0 comments on commit fb3c9c9

Please sign in to comment.