Skip to content

Commit

Permalink
Allow use_binary_file=None (default) and add delete_recording_dat param
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 3, 2024
1 parent 10b7e1a commit 007b64d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 28 deletions.
31 changes: 25 additions & 6 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,27 +415,46 @@ def test_use_binary_file(self, tmp_path):
sorting_ks4 = si.run_sorter(
"kilosort4",
recording,
folder=tmp_path / "spikeinterface_output_dir_wrapper",
use_binary_file=False,
folder=tmp_path / "ks4_output_si_wrapper_default",
use_binary_file=None,
remove_existing_folder=True,
)
sorting_ks4_bin = si.run_sorter(
"kilosort4",
recording_bin,
folder=tmp_path / "spikeinterface_output_dir_bin",
folder=tmp_path / "ks4_output_bin_default",
use_binary_file=None,
remove_existing_folder=True,
)
sorting_ks4_force_binary = si.run_sorter(
"kilosort4",
recording,
folder=tmp_path / "ks4_output_force_bin",
use_binary_file=True,
remove_existing_folder=True,
)
assert not (tmp_path / "ks4_output_force_bin" / "sorter_output" / "recording.dat").exists()
sorting_ks4_force_non_binary = si.run_sorter(
"kilosort4",
recording_bin,
folder=tmp_path / "ks4_output_force_wrapper",
use_binary_file=False,
remove_existing_folder=True,
)
sorting_ks4_non_bin = si.run_sorter(
# test deleting recording.dat
sorting_ks4_force_binary_keep = si.run_sorter(
"kilosort4",
recording,
folder=tmp_path / "spikeinterface_output_dir_non_bin",
folder=tmp_path / "ks4_output_force_bin_keep",
use_binary_file=True,
delete_recording_dat=False,
remove_existing_folder=True,
)
assert (tmp_path / "ks4_output_force_bin_keep" / "sorter_output" / "recording.dat").exists()

check_sortings_equal(sorting_ks4, sorting_ks4_bin)
check_sortings_equal(sorting_ks4, sorting_ks4_non_bin)
check_sortings_equal(sorting_ks4, sorting_ks4_force_binary)
check_sortings_equal(sorting_ks4, sorting_ks4_force_non_binary)

@pytest.mark.parametrize(
"param_to_test",
Expand Down
65 changes: 43 additions & 22 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class Kilosort4Sorter(BaseSorter):
"save_preprocessed_copy": False,
"torch_device": "auto",
"bad_channels": None,
"use_binary_file": False,
"use_binary_file": None,
"delete_recording_dat": True,
}

_params_description = {
Expand Down Expand Up @@ -110,8 +111,10 @@ class Kilosort4Sorter(BaseSorter):
"save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data",
"torch_device": "Select the torch device auto/cuda/cpu",
"bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.",
"use_binary_file": "If True and the recording is not binary compatible, then Kilosort is written to a binary file in the output folder. If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. "
"If the recording is binary compatible, then the sorter will always use the binary file. Default is False.",
"use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binaru compatible, it is written to a binary file in the output folder. "
"If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. "
"Default is None.",
"delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True.",
}

sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching.
Expand Down Expand Up @@ -172,15 +175,16 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
probe_filename = sorter_output_folder / "probe.prb"
write_prb(probe_filename, pg)

if params["use_binary_file"] and not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
# local copy needed
binary_file_path = sorter_output_folder / "recording.dat"
write_binary_recording(
recording=recording,
file_paths=[binary_file_path],
**get_job_kwargs(params, verbose),
)
params["filename"] = str(binary_file_path)
if params["use_binary_file"]:
if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
# local copy needed
binary_file_path = sorter_output_folder / "recording.dat"
write_binary_recording(
recording=recording,
file_paths=[binary_file_path],
**get_job_kwargs(params, verbose),
)
params["filename"] = str(binary_file_path)

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
Expand Down Expand Up @@ -227,18 +231,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
probe = load_probe(probe_path=probe_filename)
probe_name = ""

if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
# no copy
binary_description = recording.get_binary_description()
filename = str(binary_description["file_paths"][0])
file_object = None
if params["use_binary_file"] is None:
if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
# no copy
binary_description = recording.get_binary_description()
filename = str(binary_description["file_paths"][0])
file_object = None
else:
# the recording is not binary compatible and no binary copy has been written.
# in this case, we use the RecordingExtractorAsArray object
filename = ""
file_object = RecordingExtractorAsArray(recording_extractor=recording)
elif params["use_binary_file"]:
# a local copy has been written
filename = str(sorter_output_folder / "recording.dat")
file_object = None
# here we force the use of a binary file
if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
# no copy
binary_description = recording.get_binary_description()
filename = str(binary_description["file_paths"][0])
file_object = None
else:
# a local copy has been written
filename = str(sorter_output_folder / "recording.dat")
file_object = None
else:
# the recording is not binary compatible and no binary copy has been written.
# in this case, we use the RecordingExtractorAsArray object
# here we force the use of the RecordingExtractorAsArray object
filename = ""
file_object = RecordingExtractorAsArray(recording_extractor=recording)

Expand Down Expand Up @@ -362,6 +378,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
save_preprocessed_copy=save_preprocessed_copy,
)

if params["delete_recording_dat"]:
# only delete dat file if it was created by the wrapper
if (sorter_output_folder / "recording.dat").is_file():
(sorter_output_folder / "recording.dat").unlink()

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
return KilosortBase._get_result_from_folder(sorter_output_folder)

0 comments on commit 007b64d

Please sign in to comment.