From f5a29faca37af98c840b331b56b9198ce8f2bbe4 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 3 Aug 2023 13:09:36 +0100 Subject: [PATCH 01/25] Change the signature on kilosort's delete intermediate files. --- .../sorters/external/kilosort.py | 7 +++--- .../sorters/external/kilosort2.py | 7 +++--- .../sorters/external/kilosort2_5.py | 7 +++--- .../sorters/external/kilosort3.py | 7 +++--- .../sorters/external/kilosortbase.py | 24 ++++++++++++------- 5 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index 62b1e8b9e2..7a71761b21 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -42,8 +42,7 @@ class KilosortSorter(KilosortBase, BaseSorter): "Nfilt": None, "NT": None, "wave_length": 61, - "delete_tmp_files": True, - "delete_recording_dat": False, + "delete_intermediate_files": ("matlab_files",), } _params_description = { @@ -56,8 +55,8 @@ class KilosortSorter(KilosortBase, BaseSorter): "Nfilt": "Number of clusters to use (if None it is automatically computed)", "NT": "Batch size (if None it is automatically computed)", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", - "delete_tmp_files": "Whether to delete all temporary files after a successful run", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", + "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" } sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter. diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 267ff38e36..c46b3bc9c7 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -50,8 +50,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_tmp_files": True, - "delete_recording_dat": False, + "delete_intermediate_files": ("matlab_files",), } _params_description = { @@ -73,8 +72,8 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Whether to delete all temporary files after a successful run", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", + "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" } sorter_description = """Kilosort2 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index 0c9e36177e..65c21e8b1f 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -57,8 +57,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_tmp_files": True, - "delete_recording_dat": False, + "delete_intermediate_files": ("matlab_files",), } _params_description = { @@ -83,8 +82,8 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Whether to delete all temporary files after a successful run", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", + "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" } sorter_description = """Kilosort2_5 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index 77e83e35b9..d6f9b4639a 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -54,8 +54,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_tmp_files": True, - "delete_recording_dat": False, + "delete_intermediate_files": ("matlab_files",), } _params_description = { @@ -80,8 +79,8 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Whether to delete all temporary files after a successful run", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", + "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" } sorter_description = """Kilosort3 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 9918d73edc..137d3e8f11 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -215,15 +215,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): raise Exception(f"{cls.sorter_name} returned a non-zero exit code") # Clean-up temporary files - if params["delete_tmp_files"]: - for temp_file in sorter_output_folder.glob("*.m"): - temp_file.unlink() - for temp_file in sorter_output_folder.glob("*.mat"): - temp_file.unlink() - if (sorter_output_folder / "temp_wh.dat").exists(): - (sorter_output_folder / "temp_wh.dat").unlink() - if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): - recording_file.unlink() + print(f"Cleaning up temporary files created during sorting: " + f"{params['delete_intermediate_files']}") + + if "recording.dat" in params["delete_intermediate_files"]: + if (recording_file := sorter_output_folder / "recording.dat").exists(): + recording_file.unlink() + + if "temp_wh.dat" in params["delete_intermediate_files"]: + if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): + temp_wh_file.unlink() + + if "matlab_files" in params["delete_intermediate_files"]: + for ext in ["*.m", "*.mat"]: + for temp_file in sorter_output_folder.glob(ext): + temp_file.unlink() @classmethod def _get_result_from_folder(cls, sorter_output_folder): From 245a63227da8835a76ba4e7bc9ccf91fc9e4c1e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:50:05 +0000 Subject: [PATCH 02/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/external/kilosort.py | 2 +- src/spikeinterface/sorters/external/kilosort2.py | 2 +- src/spikeinterface/sorters/external/kilosort2_5.py | 2 +- src/spikeinterface/sorters/external/kilosort3.py | 2 +- src/spikeinterface/sorters/external/kilosortbase.py | 3 +-- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index 7a71761b21..41e1cfd0fd 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -56,7 +56,7 @@ class KilosortSorter(KilosortBase, BaseSorter): "NT": "Batch size (if None it is automatically computed)", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", } sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter. diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index c46b3bc9c7..8e88503d02 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -73,7 +73,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", } sorter_description = """Kilosort2 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index 65c21e8b1f..d0818727d6 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -83,7 +83,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", } sorter_description = """Kilosort2_5 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index d6f9b4639a..a5a382d115 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -80,7 +80,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')" + "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", } sorter_description = """Kilosort3 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 137d3e8f11..f1b32f74c3 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -215,8 +215,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): raise Exception(f"{cls.sorter_name} returned a non-zero exit code") # Clean-up temporary files - print(f"Cleaning up temporary files created during sorting: " - f"{params['delete_intermediate_files']}") + print(f"Cleaning up temporary files created during sorting: " f"{params['delete_intermediate_files']}") if "recording.dat" in params["delete_intermediate_files"]: if (recording_file := sorter_output_folder / "recording.dat").exists(): From 0c9b2127c994be07fb7a59c86e0236553cf50e46 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:16:22 +0100 Subject: [PATCH 03/25] Revert some changes to ensure backwards compatability. --- .../sorters/external/kilosort.py | 9 ++++--- .../sorters/external/kilosort2.py | 9 ++++--- .../sorters/external/kilosort2_5.py | 9 ++++--- .../sorters/external/kilosort3.py | 9 ++++--- .../sorters/external/kilosortbase.py | 24 ++++++++++--------- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index 41e1cfd0fd..e657157064 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -42,7 +42,8 @@ class KilosortSorter(KilosortBase, BaseSorter): "Nfilt": None, "NT": None, "wave_length": 61, - "delete_intermediate_files": ("matlab_files",), + "delete_tmp_files": ("matlab_files",), + "delete_recording_dat": False, } _params_description = { @@ -55,8 +56,10 @@ class KilosortSorter(KilosortBase, BaseSorter): "Nfilt": "Number of clusters to use (if None it is automatically computed)", "NT": "Batch size (if None it is automatically computed)", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", - "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", } sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter. diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 8e88503d02..a96bc87990 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -50,7 +50,8 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_intermediate_files": ("matlab_files",), + "delete_tmp_files": ("matlab_files",), + "delete_recording_dat": False, } _params_description = { @@ -72,8 +73,10 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", } sorter_description = """Kilosort2 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index d0818727d6..a7a8f8092e 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -57,7 +57,8 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_intermediate_files": ("matlab_files",), + "delete_tmp_files": ("matlab_files",), + "delete_recording_dat": False, } _params_description = { @@ -82,8 +83,10 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", } sorter_description = """Kilosort2_5 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index a5a382d115..b234b86025 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -54,7 +54,8 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "save_rez_to_mat": False, - "delete_intermediate_files": ("matlab_files",), + "delete_tmp_files": ("matlab_files",), + "delete_recording_dat": False, } _params_description = { @@ -79,8 +80,10 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_intermediate_files": "Delete intermediate files created during sorting. Tuple indicating the " - "files to delete. Options are: ('recording.dat', 'temp_wh.dat', 'matlab_files')", + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", } sorter_description = """Kilosort3 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index f1b32f74c3..b6cd348d3c 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -215,20 +215,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): raise Exception(f"{cls.sorter_name} returned a non-zero exit code") # Clean-up temporary files - print(f"Cleaning up temporary files created during sorting: " f"{params['delete_intermediate_files']}") + if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): + recording_file.unlink() - if "recording.dat" in params["delete_intermediate_files"]: - if (recording_file := sorter_output_folder / "recording.dat").exists(): - recording_file.unlink() + if params["delete_tmp_files"]: + tmp_to_remove = ( + ("matlab_files", "temp_wh.dat") if params["delete_tmp_files"] is True else params["delete_tmp_files"] + ) - if "temp_wh.dat" in params["delete_intermediate_files"]: - if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): - temp_wh_file.unlink() + if "temp_wh.dat" in tmp_to_remove: + if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): + temp_wh_file.unlink() - if "matlab_files" in params["delete_intermediate_files"]: - for ext in ["*.m", "*.mat"]: - for temp_file in sorter_output_folder.glob(ext): - temp_file.unlink() + if "matlab_files" in tmp_to_remove: + for ext in ["*.m", "*.mat"]: + for temp_file in sorter_output_folder.glob(ext): + temp_file.unlink() @classmethod def _get_result_from_folder(cls, sorter_output_folder): From abc6d973e1425fc72c4e01f3c2b77e8cff0ff71a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Aug 2023 20:07:04 +0100 Subject: [PATCH 04/25] Minor changes to `_param_description` string formatting. --- src/spikeinterface/sorters/external/kilosort.py | 2 +- src/spikeinterface/sorters/external/kilosort2.py | 2 +- src/spikeinterface/sorters/external/kilosort2_5.py | 2 +- src/spikeinterface/sorters/external/kilosort3.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index e657157064..e5bc57a097 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -59,7 +59,7 @@ class KilosortSorter(KilosortBase, BaseSorter): "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter. diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index a96bc87990..1f66263e39 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -76,7 +76,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } sorter_description = """Kilosort2 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index a7a8f8092e..6abe3e1e18 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -86,7 +86,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } sorter_description = """Kilosort2_5 is a GPU-accelerated and efficient template-matching spike sorter. On top of its diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index b234b86025..6da08c27e7 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -83,7 +83,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", - "delete_recording_dat": "Whether to delete the 'recording.dat' file after a " "successful run", + "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } sorter_description = """Kilosort3 is a GPU-accelerated and efficient template-matching spike sorter. On top of its From 859b0b3204a21d5bf2c4e93c06de19f2da195004 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Aug 2023 20:12:57 +0100 Subject: [PATCH 05/25] Fix end-of-line spaces on `_params_description`. --- src/spikeinterface/sorters/external/kilosort.py | 4 ++-- src/spikeinterface/sorters/external/kilosort2.py | 4 ++-- src/spikeinterface/sorters/external/kilosort2_5.py | 6 +++--- src/spikeinterface/sorters/external/kilosort3.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index e5bc57a097..008b49a33d 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -56,8 +56,8 @@ class KilosortSorter(KilosortBase, BaseSorter): "Nfilt": "Number of clusters to use (if None it is automatically computed)", "NT": "Batch size (if None it is automatically computed)", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", - "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 1f66263e39..d42e1c9ab5 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -73,8 +73,8 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index 6abe3e1e18..1aa69565f5 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -83,9 +83,9 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" - "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " + "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files') ", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index 6da08c27e7..2f4f3e75b6 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -80,8 +80,8 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", - "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that" - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files)" + "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } From a299fa3378522b857b42c24f95c4a24fb86c3b93 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:28:17 +0100 Subject: [PATCH 06/25] Update src/spikeinterface/sorters/external/kilosort2.py Co-authored-by: Alessio Buccino --- src/spikeinterface/sorters/external/kilosort2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index d42e1c9ab5..00ab3fbde5 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -74,7 +74,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } From 1d743905f72122970595948bc68f7f63bb01ccf8 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:36:51 +0100 Subject: [PATCH 07/25] Update src/spikeinterface/sorters/external/kilosortbase.py Co-authored-by: Alessio Buccino --- src/spikeinterface/sorters/external/kilosortbase.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index b6cd348d3c..58947a2617 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -218,7 +218,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): recording_file.unlink() - if params["delete_tmp_files"]: + if isinstance(params["delete_tmp_files"], bool): + if params["delete_tmp_files"]: + tmp_to_remove = ("matlab_files", "temp_wh.dat") + else: + tmp_to_remove = () + else: + assert isinstance(params["delete_tmp_files"], (tuple, list)), "..." + + if "temp_wh.dat" in tmp_to_remove: ... tmp_to_remove = ( ("matlab_files", "temp_wh.dat") if params["delete_tmp_files"] is True else params["delete_tmp_files"] ) From f8d60262485c1965d5a6a7b1a98923c82bfd4fba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 07:37:07 +0000 Subject: [PATCH 08/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/external/kilosortbase.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 58947a2617..efce2b0efb 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -225,8 +225,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): tmp_to_remove = () else: assert isinstance(params["delete_tmp_files"], (tuple, list)), "..." - - if "temp_wh.dat" in tmp_to_remove: ... + + if "temp_wh.dat" in tmp_to_remove: ... tmp_to_remove = ( ("matlab_files", "temp_wh.dat") if params["delete_tmp_files"] is True else params["delete_tmp_files"] ) From b4a1a7ec0842d86cded8e0525401b2909c4a010e Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:46:36 +0100 Subject: [PATCH 09/25] Small fixups, add assert message, variable rename. --- .../sorters/external/kilosortbase.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 58947a2617..2a490d4646 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -220,25 +220,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if isinstance(params["delete_tmp_files"], bool): if params["delete_tmp_files"]: - tmp_to_remove = ("matlab_files", "temp_wh.dat") + tmp_files_to_remove = ("matlab_files", "temp_wh.dat") else: - tmp_to_remove = () + tmp_files_to_remove = () else: - assert isinstance(params["delete_tmp_files"], (tuple, list)), "..." - - if "temp_wh.dat" in tmp_to_remove: ... - tmp_to_remove = ( - ("matlab_files", "temp_wh.dat") if params["delete_tmp_files"] is True else params["delete_tmp_files"] - ) - - if "temp_wh.dat" in tmp_to_remove: - if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): - temp_wh_file.unlink() - - if "matlab_files" in tmp_to_remove: - for ext in ["*.m", "*.mat"]: - for temp_file in sorter_output_folder.glob(ext): - temp_file.unlink() + assert isinstance( + params["delete_tmp_files"], (tuple, list) + ), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`." + tmp_files_to_remove = params["delete_tmp_files"] + + if "temp_wh.dat" in tmp_files_to_remove: + if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): + temp_wh_file.unlink() + + if "matlab_files" in tmp_files_to_remove: + for ext in ["*.m", "*.mat"]: + for temp_file in sorter_output_folder.glob(ext): + temp_file.unlink() @classmethod def _get_result_from_folder(cls, sorter_output_folder): From a34cfbd7f51c534f1a58632f7420269ea49c9336 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:58:59 +0100 Subject: [PATCH 10/25] Add another check on `delete_tmp_files` values. --- src/spikeinterface/sorters/external/kilosortbase.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 2a490d4646..2b8eb621b9 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -218,15 +218,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): recording_file.unlink() + all_temp_files = ("matlab_files", "temp_wh.dat") + if isinstance(params["delete_tmp_files"], bool): if params["delete_tmp_files"]: - tmp_files_to_remove = ("matlab_files", "temp_wh.dat") + tmp_files_to_remove = all_tmp_files else: tmp_files_to_remove = () else: assert isinstance( params["delete_tmp_files"], (tuple, list) ), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`." + + for name in params["delete_tmp_files"]: + assert name in all_tmp_files, f"{name} is not a valid option, must be one of: {all_tmp_files}" + tmp_files_to_remove = params["delete_tmp_files"] if "temp_wh.dat" in tmp_files_to_remove: From ce9a72c33823eaa51254f77607cd2c0c15691a53 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 28 Sep 2023 21:05:42 +0200 Subject: [PATCH 11/25] Move UnitProbeMapWidget to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../tests/test_widgets_legacy.py | 6 +- .../widgets/tests/test_widgets.py | 10 ++- .../unitprobemap.py => unit_probe_map.py} | 81 +++++++++---------- src/spikeinterface/widgets/widget_list.py | 3 + 5 files changed, 54 insertions(+), 49 deletions(-) rename src/spikeinterface/widgets/{_legacy_mpl_widgets/unitprobemap.py => unit_probe_map.py} (65%) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c10c78cbfc..ff144a9943 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -7,9 +7,6 @@ # waveform/PC related from .principalcomponent import plot_principal_component -# units on probe -from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget - from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..9aeb08698e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,9 +43,9 @@ def setUp(self): def tearDown(self): pass - def test_plot_unit_probe_map(self): - sw.plot_unit_probe_map(self._we, with_channel_ids=True) - sw.plot_unit_probe_map(self._we, animated=True) + # def test_plot_unit_probe_map(self): + # sw.plot_unit_probe_map(self._we, with_channel_ids=True) + # sw.plot_unit_probe_map(self._we, animated=True) # def test_plot_units_depth_vs_amplitude(self): # sw.plot_units_depth_vs_amplitude(self._we) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..f1c3456305 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -348,6 +348,13 @@ def test_plot_rasters(self): if backend not in self.skip_backends: sw.plot_rasters(self.sorting) + def test_plot_unit_probe_map(self): + possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_unit_probe_map(self.we) + + if __name__ == "__main__": # unittest.main() @@ -372,7 +379,8 @@ def test_plot_rasters(self): # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_rasters() + mytest.test_plot_unit_probe_map() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py b/src/spikeinterface/widgets/unit_probe_map.py similarity index 65% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py rename to src/spikeinterface/widgets/unit_probe_map.py index 6522c736ea..66b7ff3126 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -1,6 +1,11 @@ import numpy as np +from typing import Union -from .basewidget import BaseWidget +# from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr +# from .utils import get_unit_colors +from ..core.waveform_extractor import WaveformExtractor class UnitProbeMapWidget(BaseWidget): @@ -21,7 +26,6 @@ class UnitProbeMapWidget(BaseWidget): with_channel_ids: bool False default add channel ids text on the probe """ - def __init__( self, waveform_extractor, @@ -30,14 +34,10 @@ def __init__( animated=None, with_channel_ids=False, colorbar=True, - ncols=5, - axes=None, + backend=None, + **backend_kwargs, ): - from matplotlib.animation import FuncAnimation - from matplotlib import pyplot as plt - from probeinterface.plotting import plot_probe - self.waveform_extractor = waveform_extractor if unit_ids is None: unit_ids = waveform_extractor.sorting.unit_ids self.unit_ids = unit_ids @@ -45,44 +45,50 @@ def __init__( channel_ids = waveform_extractor.recording.channel_ids self.channel_ids = channel_ids - self.animated = animated - self.with_channel_ids = with_channel_ids - self.colorbar = colorbar - probes = waveform_extractor.recording.get_probes() - assert len(probes) == 1, ( - "Unit probe map is only available for a single probe. If you have a probe group, " - "consider splitting the recording from different probes" + data_plot = dict( + waveform_extractor=waveform_extractor, + unit_ids=unit_ids, + channel_ids=channel_ids, + animated=animated, + with_channel_ids=with_channel_ids, + colorbar=colorbar, ) - # layout - n = len(unit_ids) - if n < ncols: - ncols = n - nrows = int(np.ceil(n / ncols)) - if axes is None: - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - we = self.waveform_extractor + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import plot_probe + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + if backend_kwargs.get("axes", None) is None: + backend_kwargs["num_axes"] = len(dp.unit_ids) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + + we = dp.waveform_extractor probe = we.get_probe() probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) all_poly_contact = [] - for i, unit_id in enumerate(self.unit_ids): + for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] template = we.get_template(unit_id) # static - if self.animated: + if dp.animated: contacts_values = np.zeros(template.shape[1]) else: contacts_values = np.max(np.abs(template), axis=0) text_on_contact = None - if self.with_channel_ids: - text_on_contact = self.channel_ids - from probeinterface.plotting import plot_probe + if dp.with_channel_ids: + text_on_contact = dp.channel_ids poly_contact, poly_contour = plot_probe( probe, @@ -96,7 +102,7 @@ def plot(self): if poly_contour is not None: poly_contour.set_zorder(1) - if self.colorbar: + if dp.colorbar: self.figure.colorbar(poly_contact, ax=ax) poly_contact.set_clim(0, np.max(np.abs(template))) @@ -104,7 +110,7 @@ def plot(self): ax.set_title(str(unit_id)) - if self.animated: + if dp.animated: num_frames = template.shape[0] def animate_func(frame): @@ -118,12 +124,3 @@ def animate_func(frame): from matplotlib.animation import FuncAnimation self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=20, blit=True) - - -def plot_unit_probe_map(*args, **kwargs): - W = UnitProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_probe_map.__doc__ = UnitProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ed77de6128..525227a2e1 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -20,6 +20,7 @@ from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget +from .unit_probe_map import UnitProbeMapWidget from .unit_summary import UnitSummaryWidget from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget @@ -46,6 +47,7 @@ TracesWidget, UnitDepthsWidget, UnitLocationsWidget, + UnitProbeMapWidget, UnitSummaryWidget, UnitTemplatesWidget, UnitWaveformDensityMapWidget, @@ -107,6 +109,7 @@ plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget +plot_unit_probe_map = UnitProbeMapWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget From e4144c589ab42ccab18e67931d39919a220e85b1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 14:13:29 +0200 Subject: [PATCH 12/25] Move isi distribution to new widget API. --- .../widgets/_legacy_mpl_widgets/__init__.py | 4 - .../tests/test_widgets_legacy.py | 11 +-- .../widgets/isi_distribution.py | 75 +++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 src/spikeinterface/widgets/isi_distribution.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index ff144a9943..061fc55339 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,11 +1,7 @@ -# isi/ccg/acg -from .isidistribution import plot_isi_distribution, ISIDistributionWidget # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget -# waveform/PC related -from .principalcomponent import plot_principal_component from .multicompgraph import ( plot_multicomp_graph, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 4e1bf445fc..9cd321db3c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -58,9 +58,6 @@ def tearDown(self): # def test_amplitudes_distribution(self): # sw.plot_amplitudes_distribution(self._we) - def test_principal_component(self): - sw.plot_principal_component(self._we) - # def test_plot_unit_localization(self): # sw.plot_unit_localization(self._we, with_channel_ids=True) # sw.plot_unit_localization(self._we, method='monopolar_triangulation') @@ -73,10 +70,10 @@ def test_principal_component(self): # unit_ids = self._sorting.unit_ids[:4] # sw.plot_crosscorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - def test_isi_distribution(self): - sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_isi_distribution(self._sorting, axes=axes) + # def test_isi_distribution(self): + # sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) + # fig, axes = plt.subplots(self.num_units, 1) + # sw.plot_isi_distribution(self._sorting, axes=axes) def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) diff --git a/src/spikeinterface/widgets/isi_distribution.py b/src/spikeinterface/widgets/isi_distribution.py new file mode 100644 index 0000000000..2d92d1daf7 --- /dev/null +++ b/src/spikeinterface/widgets/isi_distribution.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + + +class ISIDistributionWidget(BaseWidget): + """ + Plots spike train ISI distribution. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + unit_ids: list + List of unit ids + bins_ms: int + Bin size in ms + window_ms: float + Window size in ms + + """ + + def __init__( + self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs + ): + + if unit_ids is None: + unit_ids = sorting.get_unit_ids() + + plot_data = dict( + sorting=sorting, + unit_ids=unit_ids, + window_ms=window_ms, + bin_ms=bin_ms, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs.get("axes", None) is None: + backend_kwargs["num_axes"] = len(dp.unit_ids) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting = dp.sorting + num_segments = sorting.get_num_segments() + fs = sorting.sampling_frequency + + for i, unit_id in enumerate(dp.unit_ids): + ax = self.axes.flatten()[i] + + bins = np.arange(0, dp.window_ms, dp.bin_ms) + bin_counts = None + for segment_index in range(num_segments): + times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000. + isi = np.diff(times_ms) + + bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True) + if segment_index == 0: + bin_counts = bin_counts_ + else: + bin_counts += bin_counts_ + # TODO handle sensity when several segments + + ax.bar(x=bin_edges[:-1], height=bin_counts, width=dp.bin_ms, color="gray", align="edge") + + ax.set_ylabel(f"{unit_id}") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 525227a2e1..cec4b5ce53 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -8,6 +8,7 @@ from .autocorrelograms import AutoCorrelogramsWidget from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget +from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -35,6 +36,7 @@ AutoCorrelogramsWidget, ConfusionMatrixWidget, CrossCorrelogramsWidget, + ISIDistributionWidget, MotionWidget, ProbeMapWidget, QualityMetricsWidget, @@ -97,6 +99,7 @@ plot_autocorrelograms = AutoCorrelogramsWidget plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 86d073930c385efd9df991883dafde3a4897d2d9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 16 Oct 2023 17:32:01 +0200 Subject: [PATCH 13/25] oups --- src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 92ef4aa6c3..4443ef7b03 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -376,7 +376,7 @@ def test_plot_unit_probe_map(self): possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_probe_map(self.we) + sw.plot_unit_probe_map(self.we_dense) From 4da65edd3ed0dbdc72cdb2e45c65163f3d59db55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:33:17 +0000 Subject: [PATCH 14/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../widgets/_legacy_mpl_widgets/__init__.py | 1 - src/spikeinterface/widgets/isi_distribution.py | 12 ++++-------- src/spikeinterface/widgets/tests/test_widgets.py | 2 -- src/spikeinterface/widgets/unit_probe_map.py | 5 ++--- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 061fc55339..53c2a5c79e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,4 +1,3 @@ - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget diff --git a/src/spikeinterface/widgets/isi_distribution.py b/src/spikeinterface/widgets/isi_distribution.py index 2d92d1daf7..4256efd403 100644 --- a/src/spikeinterface/widgets/isi_distribution.py +++ b/src/spikeinterface/widgets/isi_distribution.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class ISIDistributionWidget(BaseWidget): """ Plots spike train ISI distribution. @@ -20,13 +19,10 @@ class ISIDistributionWidget(BaseWidget): Bin size in ms window_ms: float Window size in ms - - """ - def __init__( - self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs - ): + """ + def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs): if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -53,14 +49,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = dp.sorting num_segments = sorting.get_num_segments() fs = sorting.sampling_frequency - + for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] bins = np.arange(0, dp.window_ms, dp.bin_ms) bin_counts = None for segment_index in range(num_segments): - times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000. + times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000.0 isi = np.diff(times_ms) bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4443ef7b03..bc3ab4272a 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -245,7 +245,6 @@ def test_isi_distribution(self): **self.backend_kwargs[backend], ) - def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: @@ -377,7 +376,6 @@ def test_plot_unit_probe_map(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) - if __name__ == "__main__": diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 66b7ff3126..4068c1c530 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -4,6 +4,7 @@ # from probeinterface import ProbeGroup from .base import BaseWidget, to_attr + # from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -26,6 +27,7 @@ class UnitProbeMapWidget(BaseWidget): with_channel_ids: bool False default add channel ids text on the probe """ + def __init__( self, waveform_extractor, @@ -37,7 +39,6 @@ def __init__( backend=None, **backend_kwargs, ): - if unit_ids is None: unit_ids = waveform_extractor.sorting.unit_ids self.unit_ids = unit_ids @@ -45,7 +46,6 @@ def __init__( channel_ids = waveform_extractor.recording.channel_ids self.channel_ids = channel_ids - data_plot = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, @@ -71,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - we = dp.waveform_extractor probe = we.get_probe() From 25ad1638b11da16d1ab9d3a72c0709fe7e81e320 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:05:29 +0100 Subject: [PATCH 15/25] Update src/spikeinterface/sorters/external/kilosort3.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/kilosort3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index 2f4f3e75b6..77267620fa 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -81,7 +81,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } From c412bd474565320e8fcc2db29ba82593135624d2 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:05:40 +0100 Subject: [PATCH 16/25] Update src/spikeinterface/sorters/external/kilosort2_5.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/kilosort2_5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index 1aa69565f5..dd9130b9ae 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -84,7 +84,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files') ", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } From 80b69fa52f7779b71053c851beffd58919a8075d Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:05:48 +0100 Subject: [PATCH 17/25] Update src/spikeinterface/sorters/external/kilosort.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/kilosort.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index 008b49a33d..f1d656644b 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -57,7 +57,7 @@ class KilosortSorter(KilosortBase, BaseSorter): "NT": "Batch size (if None it is automatically computed)", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " - "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deltes all files) " + "contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) " "or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')", "delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run", } From 285343af338c0337fcc17312632c94c8179d8a14 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 08:55:27 -0400 Subject: [PATCH 18/25] add additional assert info --- src/spikeinterface/curation/curation_tools.py | 22 ++++++++++++++----- .../curation/curationsorting.py | 12 +++++----- .../curation/remove_redundant.py | 12 ++++++---- .../curation/splitunitsorting.py | 11 +++++----- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 38ff1f62c5..ddf7d4dc9d 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional import numpy as np @@ -9,9 +10,15 @@ except ModuleNotFoundError as err: HAVE_NUMBA = False +_methods = ("keep_first", "random", "keep_last", "keep_first_iterative", "keep_last_iterative") +_methods_numpy = ("keep_first", "random", "keep_last") + def _find_duplicated_spikes_numpy( - spike_train: np.ndarray, censored_period: int, seed: Optional[int] = None, method: str = "keep_first" + spike_train: np.ndarray, + censored_period: int, + seed: Optional[int] = None, + method: "keep_first" | "random" | "keep_last" = "keep_first", ) -> np.ndarray: (indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period) @@ -29,7 +36,9 @@ def _find_duplicated_spikes_numpy( (indices_of_duplicates,) = np.where(~mask) elif method != "keep_last": - raise ValueError(f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy.") + raise ValueError( + f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy use one of {_methods_numpy}." + ) return indices_of_duplicates @@ -84,7 +93,10 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): def find_duplicated_spikes( - spike_train, censored_period: int, method: str = "random", seed: Optional[int] = None + spike_train, + censored_period: int, + method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random", + seed: Optional[int] = None, ) -> np.ndarray: """ Finds the indices where spikes should be considered duplicates. @@ -97,7 +109,7 @@ def find_duplicated_spikes( The spike train on which to look for duplicated spikes. censored_period: int The censored period for duplicates (in sample time). - method: str in ("keep_first", "keep_last", "keep_first_iterative', 'keep_last_iterative", random") + method: "keep_first" |"keep_last" | "keep_first_iterative' | 'keep_last_iterative" |random" Method used to remove the duplicated spikes. seed: int | None The seed to use if method="random". @@ -120,4 +132,4 @@ def find_duplicated_spikes( assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba" return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period) else: - raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes.") + raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}") diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index f2776bafe6..bdb33e9eb1 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -148,24 +148,24 @@ def remove_empty_units(self): edges = None self._add_new_stage(new_sorting, edges) - def redo_avaiable(self): + def redo_available(self): # useful function for a gui return self._sorting_stages_i < len(self._sorting_stages) - def undo_avaiable(self): + def undo_available(self): # useful function for a gui return self._sorting_stages_i > 0 def undo(self): - if self.undo_avaiable(): + if self.undo_available(): self._sorting_stages_i -= 1 def redo(self): - if self.redo_avaiable(): + if self.redo_available(): self._sorting_stages_i += 1 def draw_graph(self, **kwargs): - assert self._make_graph, "to make a graph make_graph=True" + assert self._make_graph, "to make a graph use make_graph=True" graph = self.graph ids = [c.unit_id for c in graph.nodes] pos = {n: (n.stage_id, -ids.index(n.unit_id)) for n in graph.nodes} @@ -174,7 +174,7 @@ def draw_graph(self, **kwargs): @property def graph(self): - assert self._make_graph, "to have a graph make_graph=True" + assert self._make_graph, "to have a graph use make_graph=True" return self._graphs[self._sorting_stages_i] @property diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index c2617d5b52..e13f83550a 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from spikeinterface import WaveformExtractor @@ -6,6 +7,9 @@ from ..postprocessing import align_sorting +_remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") + + def remove_redundant_units( sorting_or_waveform_extractor, align=True, @@ -42,7 +46,7 @@ def remove_redundant_units( duplicate_threshold : float, optional Final threshold on the portion of coincident events over the number of spikes above which the unit is removed, by default 0.8 - remove_strategy: str + remove_strategy: 'minimum_shift' | 'highest_amplitude' | 'max_spikes', default: 'minimum_shift' Which strategy to remove one of the two duplicated units: * 'minimum_shift': keep the unit with best peak alignment (minimum shift) @@ -50,7 +54,7 @@ def remove_redundant_units( * 'highest_amplitude': keep the unit with the best amplitude on unshifted max. * 'max_spikes': keep the unit with more spikes - peak_sign: str ('neg', 'pos', 'both') + peak_sign: 'neg' |'pos' | 'both', default: 'neg' Used when remove_strategy='highest_amplitude' extra_outputs: bool If True, will return the redundant pairs. @@ -93,7 +97,7 @@ def remove_redundant_units( peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": - assert align, "remove_strategy with minimum_shift need align=True" + assert align, "remove_strategy with minimum_shift needs align=True" for u1, u2 in redundant_unit_pairs: if np.abs(unit_peak_shifts[u1]) > np.abs(unit_peak_shifts[u2]): remove_unit_ids.append(u1) @@ -125,7 +129,7 @@ def remove_redundant_units( # this will be implemented in a futur PR by the first who need it! raise NotImplementedError() else: - raise ValueError(f"remove_strategy : {remove_strategy} is not implemented!") + raise ValueError(f"remove_strategy : {remove_strategy} is not implemented! Options are {_remove_strategies}") sorting_clean = sorting.remove_units(remove_unit_ids) diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 816d62cf9f..23863a85e5 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -21,11 +21,10 @@ class SplitUnitSorting(BaseSorting): be the same length as the spike train (for each segment) new_unit_ids: int Unit ids of the new units to be created. - properties_policy: str + properties_policy: 'keep' | 'remove', default: 'keep' Policy used to propagate properties. If 'keep' the properties will be passed to the new units (if the units_to_merge have the same value). If 'remove' the new units will have an empty value for all the properties of the new unit. - Default: 'keep' Returns ------- sorting: Sorting @@ -48,19 +47,19 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype) else: new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype) - assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids should be unique" - assert len(new_unit_ids) <= tot_splits, "indices_list have more ids indices than the length of new_unit_ids" + assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique" + assert len(new_unit_ids) <= tot_splits, "indices_list has more id indices than the length of new_unit_ids" assert parent_sorting.get_num_segments() == len( indices_list ), "The length of indices_list must be the same as parent_sorting.get_num_segments" - assert split_unit_id in parents_unit_ids, "Unit to split should be in parent sorting" + assert split_unit_id in parents_unit_ids, "Unit to split must be in parent sorting" assert properties_policy == "keep" or properties_policy == "remove", ( "properties_policy must be " "keep" " or " "remove" "" ) assert not any( np.isin(new_unit_ids, unchanged_units) - ), "new_unit_ids should be new units or one could be equal to split_unit_id" + ), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id" sampling_frequency = parent_sorting.get_sampling_frequency() units_ids = np.concatenate([unchanged_units, new_unit_ids]) From 34854aa0cd115618ec3a99ce6503d3f28569cfe9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 09:00:28 -0400 Subject: [PATCH 19/25] add segment number to assert message --- src/spikeinterface/exporters/to_phy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 31a452f389..2c916d33b5 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -78,7 +78,7 @@ def export_to_phy( ), "waveform_extractor must be a WaveformExtractor object" sorting = waveform_extractor.sorting - assert waveform_extractor.get_num_segments() == 1, "Export to phy only works with one segment" + assert waveform_extractor.get_num_segments() == 1, f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" num_chans = waveform_extractor.get_num_channels() fs = waveform_extractor.sampling_frequency From 11677223154481d275ff3029af74d00c7d723f1c Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 09:07:27 -0400 Subject: [PATCH 20/25] working on assert messaging --- src/spikeinterface/sorters/basesorter.py | 6 +++--- src/spikeinterface/sorters/launcher.py | 4 ++-- src/spikeinterface/sorters/sorterlist.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index a956f8c811..139f15bf12 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -103,7 +103,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) if not isinstance(recording, BaseRecordingSnippets): - raise ValueError("recording must be a Recording or Snippets!!") + raise ValueError("recording must be a Recording or a Snippets!!") if cls.requires_locations: locations = recording.get_channel_locations() @@ -133,7 +133,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if recording.get_num_segments() > 1: if not cls.handle_multi_segment: raise ValueError( - f"This sorter {cls.sorter_name} do not handle multi segment, use si.concatenate_recordings(...)" + f"This sorter {cls.sorter_name} does not handle multi-segment recordings, use si.concatenate_recordings(...)" ) rec_file = output_folder / "spikeinterface_recording.json" @@ -299,7 +299,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_ # check errors in log file log_file = output_folder / "spikeinterface_log.json" if not log_file.is_file(): - raise SpikeSortingError("get result error: the folder does not contain the `spikeinterface_log.json` file") + raise SpikeSortingError("Get result error: the folder does not contain the `spikeinterface_log.json` file") with log_file.open("r", encoding="utf8") as f: log = json.load(f) diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 704f6843f2..e7fdedcfe7 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -374,7 +374,7 @@ def run_sorters( mode_if_folder_exists in ("raise", "keep", "overwrite") if mode_if_folder_exists == "raise" and working_folder.is_dir(): - raise Exception("working_folder already exists, please remove it") + raise Exception(f"working_folder {working_folder} already exists, please remove it") assert engine in _implemented_engine, f"engine must be in {_implemented_engine}" @@ -390,7 +390,7 @@ def run_sorters( elif isinstance(recording_dict_or_list, dict): recording_dict = recording_dict_or_list else: - raise ValueError("bad recording dict") + raise ValueError("Wrong format for recording_dict_or_list") dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 40b5cdebaa..761bb6d716 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -89,7 +89,7 @@ def get_default_sorter_params(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.default_params() @@ -113,7 +113,7 @@ def get_sorter_params_description(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.params_description() @@ -137,6 +137,6 @@ def get_sorter_description(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.sorter_description From 44ad0ef0f0e29973e6e6c05fc2b992ee755db89b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:11:14 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/to_phy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 2c916d33b5..0529c99d12 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -78,7 +78,9 @@ def export_to_phy( ), "waveform_extractor must be a WaveformExtractor object" sorting = waveform_extractor.sorting - assert waveform_extractor.get_num_segments() == 1, f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" + assert ( + waveform_extractor.get_num_segments() == 1 + ), f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" num_chans = waveform_extractor.get_num_channels() fs = waveform_extractor.sampling_frequency From fec5316f68dd463693bcaf7383d89beaae99041c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Oct 2023 15:25:56 +0200 Subject: [PATCH 22/25] Fix full tests on codecov --- src/spikeinterface/sortingcomponents/clustering/merge.py | 4 ++-- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- src/spikeinterface/widgets/metrics.py | 2 +- src/spikeinterface/widgets/tests/test_widgets.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index d35b562298..4c79383542 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -298,8 +298,8 @@ def find_merge_pairs( indices0, indices1 = np.nonzero(pair_mask) n_jobs = job_kwargs["n_jobs"] - mp_context = job_kwargs["mp_context"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + mp_context = job_kwargs.get("mp_context", None) + max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index a31e7d62fc..5ea9fb7bb2 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -61,7 +61,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + max_threads_per_process = job_kwargs.get("max_threads_per_process", None) original_labels = peak_labels peak_labels = peak_labels.copy() diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index c7b701c8b0..bc44e58a33 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -224,7 +224,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): metrics_sv = [] for col in metric_names: - dtype = metrics.iloc[0][col].dtype + dtype = np.array(metrics.iloc[0][col]).dtype metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) metrics_sv.append(metric) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 1a2fdf38d9..f60346ade0 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -376,9 +376,9 @@ def test_plot_rasters(self): # mytest.test_plot_unit_summary() # mytest.test_unit_locations() # mytest.test_quality_metrics() - # mytest.test_template_metrics() + mytest.test_template_metrics() # mytest.test_amplitudes() - mytest.test_plot_agreement_matrix() + # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() # mytest.test_plot_rasters() From fb5b1ef9c07bf7eccd3f37b014b59b5ffef1ce81 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Oct 2023 16:23:53 +0200 Subject: [PATCH 23/25] Fix default --- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 5ea9fb7bb2..48ec26679e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -61,7 +61,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", None) + max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) original_labels = peak_labels peak_labels = peak_labels.copy() From 155610550142d4602bb37b2f9984932595c95fa0 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Oct 2023 19:05:41 +0200 Subject: [PATCH 24/25] add typing import --- .../extractors/neoextractors/neuroscope.py | 10 +--------- src/spikeinterface/postprocessing/template_metrics.py | 2 ++ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 801b9c1928..c652ce4fb9 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -9,12 +9,6 @@ from .neobaseextractor import NeoBaseRecordingExtractor -try: - from lxml import etree as et - - HAVE_LXML = True -except ImportError: - HAVE_LXML = False PathType = Union[str, Path] OptionalPathType = Optional[PathType] @@ -108,8 +102,6 @@ class NeuroScopeSortingExtractor(BaseSorting): """ extractor_name = "NeuroscopeSortingExtractor" - installed = HAVE_LXML - installation_mesg = "Please install lxml to use this extractor!" name = "neuroscope" def __init__( @@ -121,7 +113,7 @@ def __init__( exclude_shanks: Optional[list] = None, xml_file_path: OptionalPathType = None, ): - assert self.installed, self.installation_mesg + from lxml import etree as et assert not ( folder_path is None and resfile_path is None and clufile_path is None diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 3f47c505ad..bd6d1eff2a 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -3,6 +3,8 @@ https://github.com/AllenInstitute/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py 22/04/2020 """ +from __future__ import annotations + import numpy as np import warnings from typing import Optional From fedc2ce5c4578eaa3a757ed54b8dc70f79ed4d3e Mon Sep 17 00:00:00 2001 From: Henry Skelton Date: Wed, 18 Oct 2023 09:42:03 -0400 Subject: [PATCH 25/25] fixed typo in variable name --- src/spikeinterface/sorters/external/kilosortbase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 2b8eb621b9..67ddb52ab4 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -218,7 +218,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): recording_file.unlink() - all_temp_files = ("matlab_files", "temp_wh.dat") + all_tmp_files = ("matlab_files", "temp_wh.dat") if isinstance(params["delete_tmp_files"], bool): if params["delete_tmp_files"]: