Skip to content

Commit

Permalink
add params to mountainsort5.py
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Nov 20, 2023
1 parent dc86474 commit fd2258a
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/spikeinterface/sorters/external/mountainsort5.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from packaging.version import parse

from tempfile import TemporaryDirectory

import numpy as np

from spikeinterface.preprocessing import bandpass_filter, whiten
Expand Down Expand Up @@ -39,6 +41,8 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": 6000,
"filter": True,
"whiten": True, # Important to do whitening
"temporary_base_dir": None,
"n_jobs_for_preprocessing": 1,
}

_params_description = {
Expand All @@ -62,6 +66,8 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": "Low-pass filter cutoff frequency",
"filter": "Enable or disable filter",
"whiten": "Enable or disable whitening",
"temporary_base_dir": "Temporary directory base directory for storing cached recording",
"n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording",
}

sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030"
Expand Down Expand Up @@ -113,7 +119,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
import mountainsort5 as ms5
from mountainsort5.util import TemporaryDirectory, create_cached_recording
from mountainsort5.util import create_cached_recording

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
if recording is None:
Expand Down Expand Up @@ -178,9 +184,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"]
)

with TemporaryDirectory() as tmpdir:
with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir:
# cache the recording to a temporary directory for efficient reading (so we don't have to re-filter)
recording_cached = create_cached_recording(recording=recording, folder=tmpdir)
recording_cached = create_cached_recording(
recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"]
)

scheme = p["scheme"]
if scheme == "1":
Expand Down

0 comments on commit fd2258a

Please sign in to comment.