From 8fbf100dfc0be3032a85f02a0cd857a42edea53a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:50:29 +0200 Subject: [PATCH] Expose clear_Cache argument in KS4 --- .github/scripts/test_kilosort4_ci.py | 24 +++++++++++++++++++ .../sorters/external/kilosort4.py | 23 +++++++++++++++--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 1da2f2ba92..6eeb71f1dd 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -368,6 +368,30 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp with pytest.raises(AssertionError): check_sortings_equal(default_kilosort_sorting, sorting_si) + def test_clear_cache(self,recording_and_paths, tmp_path): + """ + Test clear_cache parameter in kilosort4.run_kilosort + """ + recording, paths = recording_and_paths + + spikeinterface_output_dir = tmp_path / "spikeinterface_output_clear" + sorting_si_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=True + ) + spikeinterface_output_dir = tmp_path / "spikeinterface_output_no_clear" + sorting_si_no_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=False + ) + check_sortings_equal(sorting_si_clear, sorting_si_no_clear) + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. We set diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 183f26d86c..4a8c9d1782 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -65,6 +65,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, + "clear_cache": False, "use_binary_file": None, "delete_recording_dat": True, } @@ -111,6 +112,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", + "clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " "Default is None.", @@ -284,6 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): data_dir = "" results_dir = sorter_output_folder bad_channels = params["bad_channels"] + clear_cache = params["clear_cache"] filename, data_dir, results_dir, probe = set_files( settings=settings, @@ -347,17 +350,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, + device=device, + tic0=tic0, + progress_bar=progress_bar, + file_object=file_object, + clear_cache=clear_cache, ) if save_preprocessed_copy: save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes( + ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache + ) clu, Wall = cluster_spikes( - st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + st=st, + tF=tF, + ops=ops, + device=device, + bfile=bfile, + tic0=tic0, + progress_bar=progress_bar, + clear_cache=clear_cache, ) if params["skip_kilosort_preprocessing"]: