diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 2fa860c148..cbd802b374 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,6 +1,8 @@ from pathlib import Path from packaging.version import parse +from tempfile import TemporaryDirectory + import numpy as np from spikeinterface.preprocessing import bandpass_filter, whiten @@ -39,6 +41,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": 6000, "filter": True, "whiten": True, # Important to do whitening + "temporary_base_dir": None, + "n_jobs_for_preprocessing": 1, } _params_description = { @@ -62,6 +66,8 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": "Low-pass filter cutoff frequency", "filter": "Enable or disable filter", "whiten": "Enable or disable whitening", + "temporary_base_dir": "Temporary directory base directory for storing cached recording", + "n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording", } sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030" @@ -113,7 +119,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - from mountainsort5.util import TemporaryDirectory, create_cached_recording + from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if recording is None: @@ -178,9 +184,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - with TemporaryDirectory() as tmpdir: + with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir: # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) - recording_cached = create_cached_recording(recording=recording, folder=tmpdir) + recording_cached = create_cached_recording( + recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"] + ) scheme = p["scheme"] if scheme == "1":