Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to kilosort 4: version >= 4.0.16, bad_channels, clear_cache, use_binary_file #3339

Merged
merged 22 commits into from
Sep 6, 2024
Merged
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0ee0c43
Add bad channels and do version check
chrishalcrow Aug 26, 2024
bc290ff
remove comment about preprocesed spelling
chrishalcrow Aug 27, 2024
0df7141
Add use_binary_file argument and logic to KS4
alejoe91 Aug 27, 2024
a51a5b4
Merge branch 'main' into only-allow-above-ks-4-16
alejoe91 Aug 27, 2024
7638935
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Aug 29, 2024
227b0e7
Update KS4 versions
alejoe91 Aug 31, 2024
2b70788
Merge branch 'only-allow-above-ks-4-16' of github.com:chrishalcrow/sp…
alejoe91 Aug 31, 2024
c23d530
Update actions and always use binary if recording is binary
alejoe91 Sep 2, 2024
f9dfa04
Add highpass_cutoff and fix KS tests
alejoe91 Sep 3, 2024
1964f86
test ks4 on ks4 changes
alejoe91 Sep 3, 2024
8e9995d
Move testing scripts into scripts folder
alejoe91 Sep 3, 2024
219bee4
change trigger
alejoe91 Sep 3, 2024
e26e143
Remove last conditions on prior ks versions
alejoe91 Sep 3, 2024
9c338dd
Fix KS parameters in tests
alejoe91 Sep 3, 2024
87fbe55
More cleanup of KS4 tests
alejoe91 Sep 3, 2024
10b7e1a
Remove last change_nothing
alejoe91 Sep 3, 2024
007b64d
Allow use_binary_file=None (default) and add delete_recording_dat param
alejoe91 Sep 3, 2024
f399f6e
Update .github/scripts/test_kilosort4_ci.py
alejoe91 Sep 4, 2024
464c6e3
Update src/spikeinterface/sorters/external/kilosort4.py
alejoe91 Sep 4, 2024
0ed4876
Extend check on clus
alejoe91 Sep 4, 2024
8fbf100
Expose clear_Cache argument in KS4
alejoe91 Sep 5, 2024
fd61bb6
Explicitly add (spikeinterface parameter) to KS4 param description
alejoe91 Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Kilosort4Sorter(BaseSorter):
"scaleproc": None,
"save_preprocessed_copy": False,
"torch_device": "auto",
"bad_channels": None,
}

_params_description = {
Expand Down Expand Up @@ -101,6 +102,7 @@ class Kilosort4Sorter(BaseSorter):
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data",
"torch_device": "Select the torch device auto/cuda/cpu",
"bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.",
}

sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching.
Expand All @@ -110,7 +112,7 @@ class Kilosort4Sorter(BaseSorter):
For more information see https://github.com/MouseLand/Kilosort"""

installation_mesg = """\nTo use Kilosort4 run:\n
>>> pip install kilosort==4.0
>>> pip install kilosort --upgrade

More information on Kilosort4 at:
https://github.com/MouseLand/Kilosort
Expand All @@ -134,6 +136,25 @@ def get_sorter_version(cls):
"""kilosort.__version__ <4.0.10 is always '4'"""
return importlib_version("kilosort")

@classmethod
def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder):
if not cls.is_installed():
raise Exception(
f"The sorter {cls.sorter_name} is not installed. Please install it with:\n{cls.installation_mesg}"
)
cls.check_sorter_version()
return super(Kilosort4Sorter, cls).initialize_folder(recording, output_folder, verbose, remove_existing_folder)

@classmethod
def check_sorter_version(cls):
kilosort_version = version.parse(cls.get_sorter_version())
if kilosort_version < version.parse("4.0.16"):
raise Exception(
f"""SpikeInterface only supports kilosort versions 4.0.16 and above. You are running version {kilosort_version}. To install the latest version, run:
>>> pip install kilosort --upgrade
"""
)

@classmethod
def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
from probeinterface import write_prb
Expand Down Expand Up @@ -214,6 +235,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# NOTE: Also modifies settings in-place
data_dir = ""
results_dir = sorter_output_folder
bad_channels = params["bad_channels"]

filename, data_dir, results_dir, probe = set_files(
settings=settings,
Expand All @@ -222,36 +244,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
probe_name=probe_name,
data_dir=data_dir,
results_dir=results_dir,
bad_channels=bad_channels,
)

if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"):
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo)
)
else:
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
)
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
)

if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"):
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)
else:
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = (
get_run_parameters(ops)
)
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)

# Set preprocessing and drift correction parameters
if not params["skip_kilosort_preprocessing"]:
Expand Down
Loading