diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index a5561c2ffc..08e1ee6e1a 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -52,6 +52,8 @@ jobs: - name: Shows installed packages by pip, git-annex and cached testing files uses: ./.github/actions/show-test-environment - name: run tests + env: + HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell run: | source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index dad42e021b..2f5fa02a0f 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -132,6 +132,8 @@ jobs: - name: Test core run: ./.github/run_tests.sh core - name: Test extractors + env: + HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh "extractors and not streaming_extractors" - name: Test preprocessing diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9770856dfa..9cc1129ed2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.11.0 hooks: - id: black files: ^src/ diff --git a/README.md b/README.md index d51f372848..977cf6eba4 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ With SpikeInterface, users can: ## Documentation -Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.0). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.1). Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 67f1e52011..e95edb968c 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -74,6 +74,8 @@ dtype (unless specified otherwise): Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`. +When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer. + Available preprocessing ----------------------- diff --git a/doc/releases/0.99.1.rst b/doc/releases/0.99.1.rst new file mode 100644 index 0000000000..688f9f6a41 --- /dev/null +++ b/doc/releases/0.99.1.rst @@ -0,0 +1,13 @@ +.. _release0.99.1: + +SpikeInterface 0.99.1 release notes +----------------------------------- + +14th November 2023 + +Minor release with some bug fixes. + +* Fix crash when default start / end frame arguments on motion interpolation are used (#2176) +* Fix bug in `make_match_count_matrix()` when computing matching events (#2182, #2191, #2196) +* Fix maxwell tests by setting HDF5_PLUGIN_PATH env in action (#2161) +* Add read_npz_sorting to extractors module (#2183) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 33735a47fd..2232173e5a 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 + releases/0.99.1.rst releases/0.99.0.rst releases/0.98.2.rst releases/0.98.1.rst @@ -32,6 +33,12 @@ Release notes releases/0.9.1.rst +Version 0.99.1 +============== + +* Minor release with some bug fixes + + Version 0.99.0 ============== diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 3cd856d662..19ba6afd27 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -63,6 +63,9 @@ def compute_agreement_score(num_matches, num1, num2): def do_count_event(sorting): """ Count event for each units in a sorting. + + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. + Parameters ---------- sorting: SortingExtractor @@ -75,14 +78,7 @@ def do_count_event(sorting): """ import pandas as pd - unit_ids = sorting.get_unit_ids() - ev_counts = np.zeros(len(unit_ids), dtype="int64") - for segment_index in range(sorting.get_num_segments()): - ev_counts += np.array( - [len(sorting.get_unit_spike_train(u, segment_index=segment_index)) for u in unit_ids], dtype="int64" - ) - event_counts = pd.Series(ev_counts, index=unit_ids) - return event_counts + return pd.Series(sorting.count_num_spikes_per_unit()) def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, @@ -133,11 +129,9 @@ def compute_matching_matrix( delta_frames, ): """ - Compute a matrix representing the matches between two spike trains. - - Given two spike trains, this function finds matching spikes based on a temporal proximity criterion - defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `spike_frames_train1` and `spike_frames_train2`. + Internal function used by `make_match_count_matrix()`. + This function is for one segment only. + The loop over segment is done in `make_match_count_matrix()` Parameters ---------- @@ -164,31 +158,9 @@ def compute_matching_matrix( A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. - - Notes - ----- - This algorithm identifies matching spikes between two ordered spike trains. - By iterating through each spike in the first train, it compares them against spikes in the second train, - determining matches based on the two spikes frames being within `delta_frames` of each other. - - To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, - which signifies the minimal index in the second spike train that might match the upcoming spike - in the first train. - - The logic can be summarized as follows: - 1. Iterate through each spike in the first train - 2. For each spike, find the first match in the second train. - 3. Save the index of the first match as the new `second_train_search_start ` - 3. For each match, find as many matches as possible from the first match onwards. - - An important condition here is that the same spike is not matched twice. This is managed by keeping track - of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` - - For more details on the rationale behind this approach, refer to the documentation of this module and/or - the metrics section in SpikeForest documentation. 
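The notes above describe the two-pointer search that the Numba kernel implements (an expanded version of this description now lives in the docstring of `make_match_count_matrix()` below). As a reading aid, here is a minimal pure-Python/NumPy sketch of that logic, written from the description and the hunks in this diff; it mirrors the new index-based double-counting guard but is not the actual JIT-compiled implementation, and the example frames at the end are made up.

```python
import numpy as np

def naive_match_count(frames1, units1, frames2, units2, num_units1, num_units2, delta_frames):
    """Count, per unit pair, spikes of train 1 with a match in train 2 within +/- delta_frames.
    Both spike vectors must be sorted by frame. Plain-Python stand-in for the Numba kernel."""
    matching_matrix = np.zeros((num_units1, num_units2), dtype=np.uint64)
    # last matched *indices* per unit pair, so the same spike is never counted twice
    last_match1 = -np.ones((num_units1, num_units2), dtype=np.int64)
    last_match2 = -np.ones((num_units1, num_units2), dtype=np.int64)

    second_train_search_start = 0
    for index1, frame1 in enumerate(frames1):
        for index2 in range(second_train_search_start, len(frames2)):
            frame2 = frames2[index2]
            if frame2 < frame1 - delta_frames:
                # too early for this and every later spike of train 1: advance the search start
                second_train_search_start = index2 + 1
                continue
            if frame2 > frame1 + delta_frames:
                break  # later spikes of train 2 are even further away
            unit1, unit2 = units1[index1], units2[index2]
            if index1 != last_match1[unit1, unit2] and index2 != last_match2[unit1, unit2]:
                last_match1[unit1, unit2] = index1
                last_match2[unit1, unit2] = index2
                matching_matrix[unit1, unit2] += 1
    return matching_matrix

# single-unit example: 100<->101 and 300<->305 match, 200 has no partner
counts = naive_match_count([100, 200, 300], [0, 0, 0], [101, 305], [0, 0], 1, 1, delta_frames=10)
print(counts)  # [[2]]
```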
""" - matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint64) # Used to avoid the same spike matching twice last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64) @@ -216,11 +188,11 @@ def compute_matching_matrix( unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] if ( - frame1 != last_match_frame1[unit_index1, unit_index2] - and frame2 != last_match_frame2[unit_index1, unit_index2] + index1 != last_match_frame1[unit_index1, unit_index2] + and index2 != last_match_frame2[unit_index1, unit_index2] ): - last_match_frame1[unit_index1, unit_index2] = frame1 - last_match_frame2[unit_index1, unit_index2] = frame2 + last_match_frame1[unit_index1, unit_index2] = index1 + last_match_frame2[unit_index1, unit_index2] = index2 matching_matrix[unit_index1, unit_index2] += 1 @@ -232,10 +204,65 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames): +def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False): + """ + Computes a matrix representing the matches between two Sorting objects. + + Given two spike trains, this function finds matching spikes based on a temporal proximity criterion + defined by `delta_frames`. The resulting matrix indicates the number of matches between units + in `spike_frames_train1` and `spike_frames_train2` for each pair of units. + + Note that this algo is not symmetric and is biased with `sorting1` representing ground truth for the comparison + + Parameters + ---------- + sorting1 : Sorting + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + sorting2 : Sorting + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at + `spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching. + ensure_symmetry: bool, default False + If ensure_symmetry=True, then the algo is run two times by switching sorting1 and sorting2. + And the minimum of the two results is taken. + Returns + ------- + matching_matrix : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + + Notes + ----- + This algorithm identifies matching spikes between two ordered spike trains. + By iterating through each spike in the first train, it compares them against spikes in the second train, + determining matches based on the two spikes frames being within `delta_frames` of each other. + + To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, + which signifies the minimal index in the second spike train that might match the upcoming spike + in the first train. + + The logic can be summarized as follows: + 1. Iterate through each spike in the first train + 2. For each spike, find the first match in the second train. + 3. Save the index of the first match as the new `second_train_search_start ` + 3. For each match, find as many matches as possible from the first match onwards. 
+ + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` + There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity + (below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes, + we apply a final clip. + + + For more details on the rationale behind this approach, refer to the documentation of this module and/or + the metrics section in SpikeForest documentation. + """ + num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() - matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint64) spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) @@ -257,7 +284,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): unit_indices1_sorted = spike_vector1["unit_index"] unit_indices2_sorted = spike_vector2["unit_index"] - matching_matrix += get_optimized_compute_matching_matrix()( + matching_matrix_seg = get_optimized_compute_matching_matrix()( sample_frames1_sorted, sample_frames2_sorted, unit_indices1_sorted, @@ -267,6 +294,26 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): delta_frames, ) + if ensure_symmetry: + matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( + sample_frames2_sorted, + sample_frames1_sorted, + unit_indices2_sorted, + unit_indices1_sorted, + num_units_sorting2, + num_units_sorting1, + delta_frames, + ) + matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T) + + matching_matrix += matching_matrix_seg + + # ensure the number of match do not exceed the number of spike in train 2 + # this is a simple way to handle corner cases for bursting in sorting1 + spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) + spike_count2 = spike_count2[np.newaxis, :] + matching_matrix = np.clip(matching_matrix, None, spike_count2) + # Build a data frame from the matching matrix import pandas as pd @@ -277,12 +324,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames): +def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True): """ Make the agreement matrix. No threshold (min_score) is applied at this step. - Note : this computation is symmetric. + Note : this computation is symmetric by default. Inverting sorting1 and sorting2 give the transposed matrix. Parameters @@ -293,7 +340,9 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - + ensure_symmetry: bool, default: True + If ensure_symmetry is True, then the algo is run two times by switching sorting1 and sorting2. + And the minimum of the two results is taken. 
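Both `make_match_count_matrix()` above and `make_agreement_scores()` here gain an `ensure_symmetry` flag. The hunk above runs the kernel a second time with the two sortings switched, combines the two passes with `np.maximum` of the transposed switched result, and finally clips every column by the spike count of the corresponding unit of `sorting2`. A toy sketch of that combination step, with made-up count matrices and spike counts:

```python
import numpy as np

# Toy per-segment counts for 2 units x 2 units (made-up values for illustration)
match_12 = np.array([[3, 0], [1, 2]], dtype=np.uint64)   # pass with (sorting1, sorting2)
match_21 = np.array([[2, 1], [0, 2]], dtype=np.uint64)   # pass with the sortings switched

# ensure_symmetry: combine the direct pass with the transposed switched pass,
# as in the hunk above (np.maximum of match_12 and match_21.T)
combined = np.maximum(match_12, match_21.T)

# final protection: a pair can never have more matches than spikes in sorting2's unit
spike_count2 = np.array([2, 1])                           # made-up spikes per unit of sorting2
clipped = np.clip(combined, None, spike_count2[np.newaxis, :])
print(combined)
print(clipped)
```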
Returns ------- agreement_scores: array (float) @@ -309,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=ensure_symmetry) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7f21aa657f..02e74b7053 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,6 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, + ensure_symmetry=False, n_jobs=1, verbose=False, ): @@ -55,6 +56,8 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() + self.ensure_symmetry = ensure_symmetry + self._do_agreement() self._do_matching() @@ -84,7 +87,9 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) + self.match_event_count = make_match_count_matrix( + self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry + ) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( @@ -151,6 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + ensure_symmetry=True, n_jobs=n_jobs, verbose=verbose, ) @@ -283,6 +289,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + ensure_symmetry=False, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index ab24678a1e..31adee8ca4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,56 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): + # More challenging condition, this was failing with the previous approach that used np.where and np.diff + # This actual implementation should fail but the "clip protection" by number of spike make the solution. + # This is cheating but acceptable for really corner cases (burst in the ground truth). 
+ frames_spike_train1 = [100, 105, 110] + frames_spike_train2 = [ + 100, + 105, + ] + unit_indices1 = [0, 0, 0] + unit_indices2 = [ + 0, + 0, + ] + delta_frames = 20 # long enough, so all frames in both sortings are within each other reach + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + # this is easy because it is sorting2 centric + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, ensure_symmetry=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + + # this work only because we protect by clipping + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + + +def test_make_match_count_matrix_ensure_symmetry(): + frames_spike_train1 = [ + 100, + 102, + 105, + 120, + 1000, + ] + unit_indices1 = [0, 2, 1, 0, 0] + frames_spike_train2 = [101, 150, 1000] + unit_indices2 = [0, 1, 0] + delta_frames = 100 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, ensure_symmetry=True) + + assert_array_equal(result.T, result_T) + + def test_make_match_count_matrix_test_proper_search_in_the_second_train(): "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" frames_spike_train1 = [500, 600, 800] @@ -174,7 +224,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) - # test if symetric + # test if symmetric agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -437,15 +487,17 @@ def test_do_count_score_and_perf(): test_make_match_count_matrix_with_mismatched_sortings() test_make_match_count_matrix_no_double_matching() test_make_match_count_matrix_repeated_matching_but_no_double_counting() + test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() test_make_match_count_matrix_test_proper_search_in_the_second_train() + test_make_match_count_matrix_ensure_symmetry() - # test_make_agreement_scores() + test_make_agreement_scores() - # test_make_possible_match() - # test_make_best_match() - # test_make_hungarian_match() + test_make_possible_match() + test_make_best_match() + test_make_hungarian_match() - # test_do_score_labels() - # test_compare_spike_trains() - # test_do_confusion_matrix() - # test_do_count_score_and_perf() + test_do_score_labels() + test_compare_spike_trains() + test_do_confusion_matrix() + test_do_count_score_and_perf() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1a8674697a..7cde209a8d 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -316,35 +316,69 @@ def to_dict( recursive: bool = False, ) -> dict: """ - Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize - an extractor using load_extractor_from_dict(dump_dict) + Construct a nested dictionary representation of the extractor. + + This method facilitates the serialization of the extractor instance by converting it + to a dictionary. 
The resulting dictionary can be used to re-initialize the extractor + through the `load_extractor_from_dict` function. + + Examples + -------- + >>> dump_dict = original_extractor.to_dict() + >>> reloaded_extractor = load_extractor_from_dict(dump_dict) Parameters ---------- - include_annotations: bool, default: False - If True, all annotations are added to the dict - include_properties: bool, default: False - If True, all properties are added to the dict - relative_to: str, Path, or None, default: None - If not None, files and folders are serialized relative to this path - Used in waveform extractor to maintain relative paths to binary files even if the - containing folder / diretory is moved - folder_metadata: str, Path, or None - Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties. - recursive: bool, default: False - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well + include_annotations : bool, default: False + Whether to include all annotations in the dictionary + include_properties : bool, default: False + Whether to include all properties in the dictionary, by default False. + relative_to : Union[str, Path, None], default: None + If provided, file and folder paths will be made relative to this path, + enabling portability in folder formats such as the waveform extractor, + by default None. + folder_metadata : Union[str, Path, None], default: None + Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) + in numpy `npy` format, by default None. + recursive : bool, default: False + If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False. + + Raises + ------ + ValueError + If `relative_to` is specified while `recursive` is False. Returns ------- - dump_dict: dict - A dictionary representation of the extractor. + dict + A dictionary representation of the extractor, with the following structure: + { + "class": , + "module": , (e.g. 'spikeinterface'), + "kwargs": , + "version": , + "relative_paths": , + "annotations": , + "properties": , + "folder_metadata": + } + + Notes + ----- + - The `relative_to` argument only has an effect if `recursive` is set to True. + - The `folder_metadata` argument will be made relative to `relative_to` if both are specified. + - The `version` field in the resulting dictionary reflects the version of the module + from which the extractor class originates. + - The full class attribute above is the full import of the class, e.g. + 'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor' + - The module is usually 'spikeinterface', but can be different for custom extractors such as those of + SpikeForest or any other project that inherits the Extractor class from spikeinterface. """ - kwargs = self._kwargs - if relative_to and not recursive: raise ValueError("`relative_to` is only possible when `recursive=True`") + kwargs = self._kwargs if recursive: to_dict_kwargs = dict( include_annotations=include_annotations, @@ -366,27 +400,24 @@ def to_dict( new_kwargs[name] = transform_extractors_to_dict(value) kwargs = new_kwargs - class_name = str(type(self)).replace("", "") + + module_import_path = self.__class__.__module__ + class_name_no_path = self.__class__.__name__ + class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 
'spikeinterface.core.generate.AClass' module = class_name.split(".")[0] - imported_module = importlib.import_module(module) - try: - version = imported_module.__version__ - except AttributeError: - version = "unknown" + imported_module = importlib.import_module(module) + module_version = getattr(imported_module, "__version__", "unknown") dump_dict = { "class": class_name, "module": module, "kwargs": kwargs, - "version": version, + "version": module_version, "relative_paths": (relative_to is not None), } - try: - dump_dict["version"] = imported_module.__version__ - except AttributeError: - dump_dict["version"] = "unknown" + dump_dict["version"] = module_version # Can be spikeinterface, spikefores, etc. if include_annotations: dump_dict["annotations"] = self._annotations @@ -805,7 +836,7 @@ def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True, * explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")` * generated: `extractor.save()` - The second option saves to subfolder "extarctor_name" in + The second option saves to subfolder "extractor_name" in "get_global_tmp_folder()". You can set the global tmp folder with: "set_global_tmp_folder("path-to-global-folder")" diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 94b08d8cc3..3c976c3de3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment): self._sorting_segments.append(sorting_segment) sorting_segment.set_parent_extractor(self) - def get_sampling_frequency(self): + def get_sampling_frequency(self) -> float: return self._sampling_frequency - def get_num_segments(self): + def get_num_segments(self) -> int: return len(self._sorting_segments) - def get_num_samples(self, segment_index=None): + def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. Parameters @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None): ), "This methods requires an associated recording. Call self.register_recording() first." return self._recording.get_num_samples(segment_index=segment_index) - def get_total_samples(self): + def get_total_samples(self) -> int: """Returns the total number of samples of the associated recording. Returns @@ -299,9 +299,11 @@ def count_num_spikes_per_unit(self) -> dict: return num_spikes - def count_total_num_spikes(self): + def count_total_num_spikes(self) -> int: """ - Get total number of spikes summed across segment and units. + Get total number of spikes in the sorting. + + This is the sum of all spikes in all segments across all units. Returns ------- @@ -310,9 +312,10 @@ def count_total_num_spikes(self): """ return self.to_spike_vector().size - def select_units(self, unit_ids, renamed_unit_ids=None): + def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: """ - Selects a subset of units + Returns a new sorting object which contains only a selected subset of units. + Parameters ---------- @@ -331,9 +334,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None): sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids) return sub_sorting - def remove_units(self, remove_unit_ids): + def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: + """ + Returns a new sorting object with renamed units. 
+ + + Parameters + ---------- + new_unit_ids : numpy.array or list + List of new names for unit ids. + They should map positionally to the existing unit ids. + + Returns + ------- + BaseSorting + Sorting object with renamed units + """ + from spikeinterface import UnitsSelectionSorting + + sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids) + return sub_sorting + + def remove_units(self, remove_unit_ids) -> BaseSorting: """ - Removes a subset of units + Returns a new sorting object with contains only a selected subset of units. Parameters ---------- @@ -343,7 +367,7 @@ def remove_units(self, remove_unit_ids): Returns ------- BaseSorting - Sorting object without removed units + Sorting without the removed units """ from spikeinterface import UnitsSelectionSorting @@ -353,7 +377,8 @@ def remove_units(self, remove_unit_ids): def remove_empty_units(self): """ - Removes units with empty spike trains + Returns a new sorting object which contains only units with at least one spike. + For multi-segments, a unit is considered empty if it contains no spikes in all segments. Returns ------- @@ -364,16 +389,12 @@ def remove_empty_units(self): return self.select_units(non_empty_units) def get_non_empty_unit_ids(self): - non_empty_units = [] - for segment_index in range(self.get_num_segments()): - for unit in self.get_unit_ids(): - if len(self.get_unit_spike_train(unit, segment_index=segment_index)) > 0: - non_empty_units.append(unit) - non_empty_units = np.unique(non_empty_units) - return non_empty_units + num_spikes_per_unit = self.count_num_spikes_per_unit() + + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0]) def get_empty_unit_ids(self): - unit_ids = self.get_unit_ids() + unit_ids = self.unit_ids empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())] return empty_units @@ -389,7 +410,7 @@ def get_all_spike_trains(self, outputs="unit_id"): """ Return all spike trains concatenated. - This is deprecated use sorting.to_spike_vector() instead + This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead """ warnings.warn( @@ -429,7 +450,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac Construct a unique structured numpy vector concatenating all spikes with several fields: sample_index, unit_index, segment_index. 
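The new `rename_units()` and the reworked empty-unit helpers above all go through `UnitsSelectionSorting` and the cached `count_num_spikes_per_unit()`. A short usage sketch, mirroring the `test_rename_units_method` test added later in this diff (the generated sorting and the new ids are only illustrative):

```python
import numpy as np
from spikeinterface.core.generate import generate_sorting

sorting = generate_sorting(num_units=2, durations=[1.0, 1.0])

# positional renaming: the new ids map onto the existing unit ids in order
renamed = sorting.rename_units(new_unit_ids=["a", "b"])
assert np.array_equal(renamed.get_unit_ids(), ["a", "b"])

# empty-unit detection now goes through the cached per-unit spike counts
counts = renamed.count_num_spikes_per_unit()
non_empty = [unit_id for unit_id in renamed.unit_ids if counts[unit_id] != 0]
print(non_empty)
```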
- See also `get_all_spike_trains()` Parameters ---------- diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 69e043b640..1c8661d12d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,15 +1336,60 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - for dim in (0, 1): - lim0 = np.min(channel_locations[:, dim]) - margin_um - lim1 = np.max(channel_locations[:, dim]) + margin_um - units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units) + + minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um + minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + + units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) + units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) + if minimum_distance is not None: + solution_found = False + renew_inds = None + for i in range(max_iteration): + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + inds0, inds1 = np.nonzero(distances < minimum_distance) + mask = inds0 != inds1 + inds0 = inds0[mask] + inds1 = inds1[mask] + + if inds0.size > 0: + if renew_inds is None: + renew_inds = np.unique(inds0) + else: + # random only bad ones in the previous set + renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] + + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) + else: + solution_found = True + break + + if not solution_found: + if distance_strict: + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) + else: + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + return units_locations @@ -1369,7 +1414,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=dict(), dtype="float32", seed=None, diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 07a57f7807..3b8b6025ca 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -154,11 +154,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd or a single sparsified waveform (template) with shape (num_samples, 
num_active_channels). """ - assert_msg = ( - "Waveforms must be dense to sparsify them. " - f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" - ) - assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + return waveforms non_zero_indices = self.unit_id_to_channel_indices[unit_id] sparsified_waveforms = waveforms[..., non_zero_indices] @@ -189,16 +186,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda """ non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) - assert_msg = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." - ) - assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + error_message = ( + "Waveforms do not seem to be in the sparsity shape for this unit_id. The number of active channels is " + f"{num_active_channels}, but the waveform has non-zero values outsies of those active channels: \n" + f"{waveforms[..., num_active_channels:]}" + ) + raise ValueError(error_message) densified_shape = waveforms.shape[:-1] + (self.num_channels,) - densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) - densified_waveforms[..., non_zero_indices] = waveforms + densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) + # Maps the active channels to their original indices + densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels] return densified_waveforms @@ -208,7 +209,21 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) - return waveforms.shape[-1] == num_active_channels + + # If any channel is non-zero outside of the active channels, then the waveforms are not sparse + excess_zeros = waveforms[..., num_active_channels:].sum() + + return int(excess_zeros) == 0 + + def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + max_num_active_channels = self.max_num_active_channels + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): + template = templates_array[unit_index, ...] + sparse_templates[unit_index, ...] = self.sparsify_waveforms(waveforms=template, unit_id=unit_id) + + return sparse_templates @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py new file mode 100644 index 0000000000..e6372c7082 --- /dev/null +++ b/src/spikeinterface/core/template.py @@ -0,0 +1,196 @@ +import numpy as np +import json +from dataclasses import dataclass, field, astuple +from .sparsity import ChannelSparsity + + +@dataclass +class Templates: + """ + A class to represent spike templates, which can be either dense or sparse. + + Parameters + ---------- + templates_array : np.ndarray + Array containing the templates data. 
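The reworked `sparsify_waveforms()` / `densify_waveforms()` above now detect sparsity from trailing zeros and zero-pad on the way back instead of asserting on shapes. A small round-trip sketch using `ChannelSparsity` directly; the mask and waveform values are made up for illustration:

```python
import numpy as np
from spikeinterface.core.sparsity import ChannelSparsity

num_samples, num_channels = 5, 4
mask = np.array([[True, False, True, False],
                 [False, True, False, True]])          # 2 units, 2 active channels each
sparsity = ChannelSparsity(mask=mask, unit_ids=np.arange(2), channel_ids=np.arange(num_channels))

dense = np.arange(num_samples * num_channels, dtype="float32").reshape(num_samples, num_channels) + 1.0

# keep only the active channels of unit 0 (channels 0 and 2)
sparse = sparsity.sparsify_waveforms(waveforms=dense, unit_id=0)
assert sparse.shape == (num_samples, 2)

# densify maps the active channels back to their original positions and zero-pads the rest
dense_again = sparsity.densify_waveforms(waveforms=sparse, unit_id=0)
assert dense_again.shape == (num_samples, num_channels)
assert np.array_equal(dense_again[:, [0, 2]], dense[:, [0, 2]])
assert np.all(dense_again[:, [1, 3]] == 0)
```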
+ sampling_frequency : float + Sampling frequency of the templates. + nbefore : int + Number of samples before the spike peak. + sparsity_mask : np.ndarray or None, default: None + Boolean array indicating the sparsity pattern of the templates. + If `None`, the templates are considered dense. + channel_ids : np.ndarray, optional default: None + Array of channel IDs. If `None`, defaults to an array of increasing integers. + unit_ids : np.ndarray, optional default: None + Array of unit IDs. If `None`, defaults to an array of increasing integers. + check_for_consistent_sparsity : bool, optional default: None + When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the + structure fo the sparsity_masl. + + The following attributes are available after construction: + + Attributes + ---------- + num_units : int + Number of units in the templates. Automatically determined from `templates_array`. + num_samples : int + Number of samples per template. Automatically determined from `templates_array`. + num_channels : int + Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`. + nafter : int + Number of samples after the spike peak. Calculated as `num_samples - nbefore - 1`. + ms_before : float + Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`. + ms_after : float + Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`. + sparsity : ChannelSparsity, optional + Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`. + If `None`, the templates are considered dense. + """ + + templates_array: np.ndarray + sampling_frequency: float + nbefore: int + + sparsity_mask: np.ndarray = None + channel_ids: np.ndarray = None + unit_ids: np.ndarray = None + + check_for_consistent_sparsity: bool = True + + num_units: int = field(init=False) + num_samples: int = field(init=False) + num_channels: int = field(init=False) + + nafter: int = field(init=False) + ms_before: float = field(init=False) + ms_after: float = field(init=False) + sparsity: ChannelSparsity = field(init=False, default=None) + + def __post_init__(self): + self.num_units, self.num_samples = self.templates_array.shape[:2] + if self.sparsity_mask is None: + self.num_channels = self.templates_array.shape[2] + else: + self.num_channels = self.sparsity_mask.shape[1] + + # Time and frames domain information + self.nafter = self.num_samples - self.nbefore + self.ms_before = self.nbefore / self.sampling_frequency * 1000 + self.ms_after = self.nafter / self.sampling_frequency * 1000 + + # Initialize sparsity object + if self.channel_ids is None: + self.channel_ids = np.arange(self.num_channels) + if self.unit_ids is None: + self.unit_ids = np.arange(self.num_units) + if self.sparsity_mask is not None: + self.sparsity = ChannelSparsity( + mask=self.sparsity_mask, + unit_ids=self.unit_ids, + channel_ids=self.channel_ids, + ) + + # Test that the templates are sparse if a sparsity mask is passed + if self.check_for_consistent_sparsity: + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") + + def get_dense_templates(self) -> np.ndarray: + # Assumes and object without a sparsity mask already has dense templates + if self.sparsity is None: + return self.templates_array + + densified_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(shape=densified_shape, 
dtype=self.templates_array.dtype) + + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] + dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return dense_waveforms + + def are_templates_sparse(self) -> bool: + return self.sparsity is not None + + def _are_passed_templates_sparse(self) -> bool: + """ + Tests if the templates passed to the init constructor are sparse + """ + are_templates_sparse = True + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] + are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not are_templates_sparse: + return False + + return are_templates_sparse + + def to_dict(self): + return { + "templates_array": self.templates_array, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask, + "channel_ids": self.channel_ids, + "unit_ids": self.unit_ids, + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, + } + + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.asarray(data["templates_array"]), + sparsity_mask=None if data["sparsity_mask"] is None else np.asarray(data["sparsity_mask"]), + channel_ids=np.asarray(data["channel_ids"]), + unit_ids=np.asarray(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], + ) + + def to_json(self): + from spikeinterface.core.core_tools import SIJsonEncoder + + return json.dumps(self.to_dict(), cls=SIJsonEncoder) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + def __eq__(self, other): + """ + Necessary to compare templates because they naturally compare objects by equality of their fields + which is not possible for numpy arrays. Therefore, we override the __eq__ method to compare each numpy arrays + using np.array_equal instead + """ + if not isinstance(other, Templates): + return False + + # Convert the instances to tuples + self_tuple = astuple(self) + other_tuple = astuple(other) + + # Compare each field + for s_field, o_field in zip(self_tuple, other_tuple): + if isinstance(s_field, np.ndarray): + if not np.array_equal(s_field, o_field): + return False + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. 
+ # Maybe ChannelSparsity should have its own __eq__ method + elif isinstance(s_field, ChannelSparsity): + if not isinstance(o_field, ChannelSparsity): + return False + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids + if not np.array_equal(s_field.mask, o_field.mask): + return False + if not np.array_equal(s_field.unit_ids, o_field.unit_ids): + return False + if not np.array_equal(s_field.channel_ids, o_field.channel_ids): + return False + else: + if s_field != o_field: + return False + + return True diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 0bdd9aecdd..a35898b420 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -22,6 +22,7 @@ ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal +from spikeinterface.core.generate import generate_sorting if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -169,6 +170,18 @@ def test_npy_sorting(): assert_raises(Exception, sorting.register_recording, rec) +def test_rename_units_method(): + num_units = 2 + durations = [1.0, 1.0] + + sorting = generate_sorting(num_units=num_units, durations=durations) + + new_unit_ids = ["a", "b"] + new_sorting = sorting.rename_units(new_unit_ids=new_unit_ids) + + assert np.array_equal(new_sorting.get_unit_ids(), new_unit_ids) + + def test_empty_sorting(): sorting = NumpySorting.from_unit_dict({}, 30000) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9a9c61766f..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -4,6 +4,8 @@ import numpy as np from spikeinterface.core import load_extractor, extract_waveforms + +from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( generate_recording, generate_sorting, @@ -289,6 +291,40 @@ def test_generate_single_fake_waveform(): # plt.show() +def test_generate_unit_locations(): + seed = 0 + + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + channel_locations = probe.contact_positions + + num_units = 100 + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + dist_flat = np.triu(distances, k=1).flatten() + dist_flat = dist_flat[dist_flat > 0] + assert np.all(dist_flat > minimum_distance) + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(dist_flat, bins = np.arange(0, 400, 10)) + # fig, ax = plt.subplots() + # from probeinterface.plotting import plot_probe + # plot_probe(probe, ax=ax) + # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20) + # plt.show() + + def test_generate_templates(): seed = 0 @@ -297,7 +333,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed) + unit_locations = generate_unit_locations(num_units, channel_locations, 
margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -436,7 +472,8 @@ def test_generate_ground_truth_recording(): # test_noise_generator_consistency_after_dump(strategy, None) # test_generate_recording() # test_generate_single_fake_waveform() + test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - test_generate_sorting_with_spikes_on_borders() + # test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py new file mode 100644 index 0000000000..40bb3f2b34 --- /dev/null +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -0,0 +1,86 @@ +import pytest +import numpy as np +import pickle +from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity + + +def generate_test_template(template_type): + num_units = 2 + num_samples = 5 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + sampling_frequency = 30_000 + nbefore = 2 + + if template_type == "dense": + return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) + elif template_type == "sparse": # sparse with sparse templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + sparsity = ChannelSparsity( + mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) + ) + + # Create sparse templates + sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) + for unit_index in range(num_units): + template = templates_array[unit_index, ...] 
+ sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index) + sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template + + return Templates( + templates_array=sparse_templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + elif template_type == "sparse_with_dense_templates": # sparse with dense templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + + return Templates( + templates_array=templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_pickle_serialization(template_type, tmp_path): + template = generate_test_template(template_type) + + # Dump to pickle + pkl_path = tmp_path / "templates.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(template, f) + + # Load from pickle + with open(pkl_path, "rb") as f: + template_reloaded = pickle.load(f) + + assert template == template_reloaded + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_json_serialization(template_type): + template = generate_test_template(template_type) + + json_str = template.to_json() + template_reloaded_from_json = Templates.from_json(json_str) + + assert template == template_reloaded_from_json + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_get_dense_templates(template_type): + template = generate_test_template(template_type) + dense_templates = template.get_dense_templates() + assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) + + +def test_initialization_fail_with_dense_templates(): + with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): + template = generate_test_template(template_type="sparse_with_dense_templates") diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index c97a727340..a81d36139d 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - def get_extension_class(self, extension_name): + def get_extension_class(self, extension_name: str): """ Get extension class from name and check if registered. @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name): ext_class = extensions_dict[extension_name] return ext_class - def is_extension(self, extension_name) -> bool: + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory or in the folder. @@ -556,7 +556,15 @@ def is_extension(self, extension_name) -> bool: and "params" in self._waveforms_root[extension_name].attrs.keys() ) - def load_extension(self, extension_name): + def is_extension(self, extension_name) -> bool: + warn( + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.has_extension(extension_name) + + def load_extension(self, extension_name: str): """ Load an extension from its name. The module of the extension must be loaded and registered. 
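`is_extension()` is kept as a deprecated alias of the new `has_extension()` until version 0.102.0, and the exporters further down are switched to the new name. A usage sketch, where `we` stands in for any existing `WaveformExtractor` with computed spike amplitudes (an assumption of this example):

```python
import warnings

# preferred spelling
if we.has_extension("spike_amplitudes"):
    amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit")

# the old spelling still works but now emits a DeprecationWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    we.is_extension("spike_amplitudes")
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```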
@@ -572,7 +580,7 @@ def load_extension(self, extension_name): The loaded instance of the extension """ if self.folder is not None and extension_name not in self._loaded_extensions: - if self.is_extension(extension_name): + if self.has_extension(extension_name): ext_class = self.get_extension_class(extension_name) ext = ext_class.load(self.folder, self) if extension_name not in self._loaded_extensions: @@ -588,7 +596,7 @@ def delete_extension(self, extension_name) -> None: extension_name: str The extension name. """ - assert self.is_extension(extension_name), f"The extension {extension_name} is not available" + assert self.has_extension(extension_name), f"The extension {extension_name} is not available" del self._loaded_extensions[extension_name] if self.folder is not None and (self.folder / extension_name).is_dir(): shutil.rmtree(self.folder / extension_name) @@ -610,7 +618,7 @@ def get_available_extension_names(self): """ extension_names_in_folder = [] for extension_class in self.extensions: - if self.is_extension(extension_class.extension_name): + if self.has_extension(extension_class.extension_name): extension_names_in_folder.append(extension_class.extension_name) return extension_names_in_folder diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 57a5ab0166..8b14930859 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -51,7 +51,7 @@ def export_report( unit_ids = sorting.unit_ids # load or compute spike_amplitudes - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) @@ -62,7 +62,7 @@ def export_report( ) # load or compute quality_metrics - if we.is_extension("quality_metrics"): + if we.has_extension("quality_metrics"): metrics = we.load_extension("quality_metrics").get_data() elif force_computation: metrics = compute_quality_metrics(we) @@ -73,7 +73,7 @@ def export_report( ) # load or compute correlograms - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): correlograms, bins = we.load_extension("correlograms").get_data() elif force_computation: correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) @@ -84,7 +84,7 @@ def export_report( ) # pre-compute unit locations if not done - if not we.is_extension("unit_locations"): + if not we.has_extension("unit_locations"): unit_locations = compute_unit_locations(we) output_folder = Path(output_folder).absolute() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ecc5b316ec..607aa3e846 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -196,7 +196,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.is_extension("similarity"): + if waveform_extractor.has_extension("similarity"): tmc = waveform_extractor.load_extension("similarity") template_similarity = tmc.get_data() else: @@ -219,7 +219,7 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): sac = 
waveform_extractor.load_extension("spike_amplitudes") amplitudes = sac.get_data(outputs="concatenated") else: @@ -231,7 +231,7 @@ def export_to_phy( np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.is_extension("principal_components"): + if waveform_extractor.has_extension("principal_components"): pc = waveform_extractor.load_extension("principal_components") else: pc = compute_principal_components( @@ -264,7 +264,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.is_extension("quality_metrics"): + if waveform_extractor.has_extension("quality_metrics"): qm = waveform_extractor.load_extension("quality_metrics") qm_data = qm.get_data() for column_name in qm_data.columns: diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 235dd705dc..f8198c3d18 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -13,6 +13,7 @@ ZarrRecordingExtractor, read_binary, read_zarr, + read_npz_sorting, ) # sorting/recording/event from neo diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index f7b445cdb9..010b22975c 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -566,12 +566,12 @@ def get_unit_spike_train( start_frame = 0 if end_frame is None: end_frame = np.inf - times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] + spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] if self._timestamps is not None: - frames = np.searchsorted(times, self.timestamps).astype("int64") + frames = np.searchsorted(spike_times, self.timestamps).astype("int64") else: - frames = np.round(times * self._sampling_frequency).astype("int64") + frames = np.round(spike_times * self._sampling_frequency).astype("int64") return frames[(frames >= start_frame) & (frames < end_frame)] diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 64c6499767..14f94eb20b 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -278,7 +278,6 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] -@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 71a19f30d3..253ca2e4ce 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -1,9 +1,11 @@ from pathlib import Path +import pickle import pytest import numpy as np import h5py +from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor if hasattr(pytest, "global_test_folder"): @@ -15,7 +17,7 @@ @pytest.mark.ros3_test @pytest.mark.streaming_extractors 
diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
index 71a19f30d3..253ca2e4ce 100644
--- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
+++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
@@ -1,9 +1,11 @@
 from pathlib import Path
+import pickle

 import pytest
 import numpy as np
 import h5py

+from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
 from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor

 if hasattr(pytest, "global_test_folder"):
@@ -15,7 +17,7 @@
 @pytest.mark.ros3_test
 @pytest.mark.streaming_extractors
 @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_recording_s3_nwb_ros3():
+def test_recording_s3_nwb_ros3(tmp_path):
     file_path = (
         "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
     )
@@ -40,9 +42,18 @@ def test_recording_s3_nwb_ros3():
         trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2)
         assert trace_scaled.dtype == "float32"

+    tmp_file = tmp_path / "test_ros3_recording.pkl"
+    with open(tmp_file, "wb") as f:
+        pickle.dump(rec, f)
+
+    with open(tmp_file, "rb") as f:
+        reloaded_recording = pickle.load(f)
+
+    check_recordings_equal(rec, reloaded_recording)
+

 @pytest.mark.streaming_extractors
-def test_recording_s3_nwb_fsspec():
+def test_recording_s3_nwb_fsspec(tmp_path):
     file_path = (
         "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
     )
@@ -67,11 +78,20 @@ def test_recording_s3_nwb_fsspec():
         trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2)
         assert trace_scaled.dtype == "float32"

+    tmp_file = tmp_path / "test_fsspec_recording.pkl"
+    with open(tmp_file, "wb") as f:
+        pickle.dump(rec, f)
+
+    with open(tmp_file, "rb") as f:
+        reloaded_recording = pickle.load(f)
+
+    check_recordings_equal(rec, reloaded_recording)
+

 @pytest.mark.ros3_test
 @pytest.mark.streaming_extractors
 @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_sorting_s3_nwb_ros3():
+def test_sorting_s3_nwb_ros3(tmp_path):
     file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
     # we provide the 'sampling_frequency' because the NWB file does not the electrical series
     sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3")
@@ -90,9 +110,18 @@ def test_sorting_s3_nwb_ros3():
         assert spike_train.dtype == "int64"
         assert np.all(spike_train >= 0)

+    tmp_file = tmp_path / "test_ros3_sorting.pkl"
+    with open(tmp_file, "wb") as f:
+        pickle.dump(sort, f)
+
+    with open(tmp_file, "rb") as f:
+        reloaded_sorting = pickle.load(f)
+
+    check_sortings_equal(reloaded_sorting, sort)
+

 @pytest.mark.streaming_extractors
-def test_sorting_s3_nwb_fsspec():
+def test_sorting_s3_nwb_fsspec(tmp_path):
     file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
     # we provide the 'sampling_frequency' because the NWB file does not the electrical series
     sort = NwbSortingExtractor(
@@ -113,6 +142,15 @@ def test_sorting_s3_nwb_fsspec():
         assert spike_train.dtype == "int64"
         assert np.all(spike_train >= 0)

+    tmp_file = tmp_path / "test_fsspec_sorting.pkl"
+    with open(tmp_file, "wb") as f:
+        pickle.dump(sort, f)
+
+    with open(tmp_file, "rb") as f:
+        reloaded_sorting = pickle.load(f)
+
+    check_sortings_equal(reloaded_sorting, sort)
+

 if __name__ == "__main__":
     test_recording_s3_nwb_ros3()
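The new assertions above check that streaming NWB extractors survive pickling. A rough equivalent outside the test suite might look like the sketch below; it needs network access and the fsspec dependencies, and the constructor arguments are simplified relative to the test:

import pickle

from spikeinterface.core.testing import check_recordings_equal
from spikeinterface.extractors import NwbRecordingExtractor

file_path = "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
rec = NwbRecordingExtractor(file_path, stream_mode="fsspec")

# Round-trip through pickle and compare the two objects field by field.
reloaded = pickle.loads(pickle.dumps(rec))
check_recordings_equal(rec, reloaded)
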
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index cf32e79b25..effd87007f 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -750,7 +750,7 @@ def compute_principal_components(
     >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy")

     """
-    if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name):
+    if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name):
         pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name)
     else:
         pc = WaveformPrincipalComponent.create(waveform_extractor)
diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 858af3ee08..f68081dbda 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -60,7 +60,7 @@ def _set_params(
                 metric_names += get_multi_channel_template_metric_names()
         metrics_kwargs = metrics_kwargs or dict()
         params = dict(
-            metric_names=[str(name) for name in metric_names],
+            metric_names=[str(name) for name in np.unique(metric_names)],
             sparsity=sparsity,
             peak_sign=peak_sign,
             upsampling_factor=int(upsampling_factor),
diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index b539bbd5d4..2bef246bc2 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False):

         # reload as an extension from we
         assert self.extension_class.extension_name in we.get_available_extension_names()
-        assert we.is_extension(self.extension_class.extension_name)
+        assert we.has_extension(self.extension_class.extension_name)
         ext = we.load_extension(self.extension_class.extension_name)
         assert isinstance(ext, self.extension_class)
         for ext_name in self.extension_data_names:
diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py
index 49591d9b89..f5e315b18f 100644
--- a/src/spikeinterface/postprocessing/tests/test_principal_component.py
+++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py
@@ -135,7 +135,7 @@ def test_project_new(self):
         from sklearn.decomposition import IncrementalPCA

         we = self.we1
-        if we.is_extension("principal_components"):
+        if we.has_extension("principal_components"):
             we.delete_extension("principal_components")

         we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp")
diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py
index 1d6947be79..172c666d62 100644
--- a/src/spikeinterface/preprocessing/filter.py
+++ b/src/spikeinterface/preprocessing/filter.py
@@ -153,6 +153,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
             filtered_traces = filtered_traces[left_margin:-right_margin, :]
         else:
             filtered_traces = filtered_traces[left_margin:, :]
+
+        if np.issubdtype(self.dtype, np.integer):
+            filtered_traces = filtered_traces.round()
+
         return filtered_traces.astype(self.dtype)
diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py
index 79b5ba5bc3..325ce82074 100644
--- a/src/spikeinterface/preprocessing/filter_gaussian.py
+++ b/src/spikeinterface/preprocessing/filter_gaussian.py
@@ -74,6 +74,9 @@ def get_traces(
         filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None]
         filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0))

+        if np.issubdtype(dtype, np.integer):
+            filtered_traces = filtered_traces.round()
+
         if right_margin > 0:
             return filtered_traces[left_margin:-right_margin, :].astype(dtype)
         else:
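The round() calls added to the preprocessors above make float-to-integer casts round to the nearest integer instead of truncating toward zero. A toy illustration of the difference:

import numpy as np

filtered = np.array([1.6, -1.6, 2.4999, 3.5001])  # toy float traces

print(filtered.astype("int16"))          # [ 1 -1  2  3]  truncation toward zero
print(filtered.round().astype("int16"))  # [ 2 -2  2  4]  round to nearest integer
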
diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py
index 03afada380..f24aff6e79 100644
--- a/src/spikeinterface/preprocessing/normalize_scale.py
+++ b/src/spikeinterface/preprocessing/normalize_scale.py
@@ -20,6 +20,10 @@ def __init__(self, parent_recording_segment, gain, offset, dtype):
     def get_traces(self, start_frame, end_frame, channel_indices):
         traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
         scaled_traces = traces * self.gain[:, channel_indices] + self.offset[:, channel_indices]
+
+        if np.issubdtype(self._dtype, np.integer):
+            scaled_traces = scaled_traces.round()
+
         return scaled_traces.astype(self._dtype)
diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py
index 570ce48a5d..0734dad784 100644
--- a/src/spikeinterface/preprocessing/phase_shift.py
+++ b/src/spikeinterface/preprocessing/phase_shift.py
@@ -103,6 +103,8 @@ def get_traces(self, start_frame, end_frame, channel_indices):
             traces_shift = traces_shift[left_margin:-right_margin, :]

         if self.tmp_dtype is not None:
+            if np.issubdtype(self.dtype, np.integer):
+                traces_shift = traces_shift.round()
             traces_shift = traces_shift.astype(self.dtype)

         return traces_shift
diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 5c734b9100..9dab06124b 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -201,7 +201,7 @@ def compute_snrs(
     snrs : dict
         Computed signal to noise ratio for each unit.
     """
-    if waveform_extractor.is_extension("noise_levels"):
+    if waveform_extractor.has_extension("noise_levels"):
        noise_levels = waveform_extractor.load_extension("noise_levels").get_data()
     else:
        if random_chunk_kwargs_dict is None:
@@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics(
     if unit_ids is None:
         unit_ids = sorting.unit_ids

-    if waveform_extractor.is_extension(amplitude_extension):
+    if waveform_extractor.has_extension(amplitude_extension):
         sac = waveform_extractor.load_extension(amplitude_extension)
         amps = sac.get_data(outputs="concatenated")
         if amplitude_extension == "spike_amplitudes":
@@ -803,7 +803,7 @@ def compute_amplitude_cutoffs(
     spike_amplitudes = None
     invert_amplitudes = False
-    if waveform_extractor.is_extension("spike_amplitudes"):
+    if waveform_extractor.has_extension("spike_amplitudes"):
         amp_calculator = waveform_extractor.load_extension("spike_amplitudes")
         spike_amplitudes = amp_calculator.get_data(outputs="by_unit")
         if amp_calculator._params["peak_sign"] == "pos":
@@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None
     extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign)

     spike_amplitudes = None
-    if waveform_extractor.is_extension("spike_amplitudes"):
+    if waveform_extractor.has_extension("spike_amplitudes"):
         amp_calculator = waveform_extractor.load_extension("spike_amplitudes")
         spike_amplitudes = amp_calculator.get_data(outputs="by_unit")
@@ -974,7 +974,7 @@ def compute_drift_metrics(
     if unit_ids is None:
         unit_ids = sorting.unit_ids

-    if waveform_extractor.is_extension("spike_locations"):
+    if waveform_extractor.has_extension("spike_locations"):
         locs_calculator = waveform_extractor.load_extension("spike_locations")
         spike_locations = locs_calculator.get_data(outputs="concatenated")
         spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit")
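As with the exporters, the quality metrics above now reuse cached extensions through has_extension()/load_extension(). A usage sketch, assuming `we` is an already-extracted WaveformExtractor:

from spikeinterface.postprocessing import compute_noise_levels
from spikeinterface.qualitymetrics.misc_metrics import compute_snrs

noise_levels = compute_noise_levels(we)  # stored as the "noise_levels" extension
snrs = compute_snrs(we)                  # reuses the cached noise levels instead of re-estimating them
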
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 53309db282..54b1027305 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -42,14 +42,14 @@ def _set_params(
         if metric_names is None:
             metric_names = list(_misc_metric_name_to_func.keys())
             # if PC is available, PC metrics are automatically added to the list
-            if self.waveform_extractor.is_extension("principal_components"):
+            if self.waveform_extractor.has_extension("principal_components"):
                 # by default 'nearest_neightbor' is removed because too slow
                 pc_metrics = _possible_pc_metric_names.copy()
                 pc_metrics.remove("nn_isolation")
                 pc_metrics.remove("nn_noise_overlap")
                 metric_names += pc_metrics
             # if spike_locations are not available, drift is removed from the list
-            if not self.waveform_extractor.is_extension("spike_locations"):
+            if not self.waveform_extractor.has_extension("spike_locations"):
                 if "drift" in metric_names:
                     metric_names.remove("drift")

@@ -61,7 +61,7 @@ def _set_params(
                 qm_params_[k]["peak_sign"] = peak_sign

         params = dict(
-            metric_names=[str(name) for name in metric_names],
+            metric_names=[str(name) for name in np.unique(metric_names)],
             sparsity=sparsity,
             peak_sign=peak_sign,
             seed=seed,
@@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs):
         # metrics based on PCs
         pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names]
         if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]:
-            if not self.waveform_extractor.is_extension("principal_components"):
+            if not self.waveform_extractor.has_extension("principal_components"):
                 raise ValueError("waveform_principal_component must be provied")
             pc_extension = self.waveform_extractor.load_extension("principal_components")
             pc_metrics = calculate_pc_metrics(
@@ -216,7 +216,7 @@ def compute_quality_metrics(
     metrics: pandas.DataFrame
         Data frame with the computed metrics
     """
-    if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name):
+    if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name):
        qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name)
     else:
        qmc = QualityMetricCalculator(waveform_extractor)
diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
index eb8317e4df..b601e5d6d8 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
@@ -88,7 +88,7 @@ def test_metrics(self):
         we = self.we_long

         # avoid NaNs
-        if we.is_extension("spike_amplitudes"):
+        if we.has_extension("spike_amplitudes"):
             we.delete_extension("spike_amplitudes")

         # without PC
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index a16b642dd5..fd283a8224 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -33,6 +33,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "job_kwargs": {"n_jobs": -1},
     }

+    handle_multi_segment = True
+
     @classmethod
     def get_sorter_version(cls):
         return "2.0"
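The np.unique() change above (also applied to template_metrics.py earlier in this diff) de-duplicates the metric names stored in the params; note that np.unique returns the names sorted as well. A toy example:

import numpy as np

metric_names = ["snr", "isi_violation", "snr", "amplitude_cutoff"]  # toy list with a duplicate
print([str(name) for name in np.unique(metric_names)])
# ['amplitude_cutoff', 'isi_violation', 'snr']
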
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index e256915fa6..eb2ddc922d 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -50,6 +50,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
         "save_array": True,
     }

+    handle_multi_segment = True
+
     @classmethod
     def get_sorter_version(cls):
         return "2.0"
diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py
index a5d3cb2429..6ff837065b 100644
--- a/src/spikeinterface/widgets/base.py
+++ b/src/spikeinterface/widgets/base.py
@@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions):
     error_msg = ""
     raise_error = False
     for extension in extensions:
-        if not waveform_extractor.is_extension(extension):
+        if not waveform_extractor.has_extension(extension):
            raise_error = True
            error_msg += (
                f"The {extension} waveform extension is required for this widget. "
diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py
index 35fde07326..aa280ad658 100644
--- a/src/spikeinterface/widgets/unit_summary.py
+++ b/src/spikeinterface/widgets/unit_summary.py
@@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         fig = self.figure
         nrows = 2
         ncols = 3
-        if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"):
+        if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"):
             ncols += 1
-        if we.is_extension("spike_amplitudes"):
+        if we.has_extension("spike_amplitudes"):
             nrows += 1
         gs = fig.add_gridspec(nrows, ncols)

-        if we.is_extension("unit_locations"):
+        if we.has_extension("unit_locations"):
             ax1 = fig.add_subplot(gs[:2, 0])
             # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
             w = UnitLocationsWidget(
@@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             )
             ax3.set_ylabel(None)

-        if we.is_extension("correlograms"):
+        if we.has_extension("correlograms"):
             ax4 = fig.add_subplot(gs[:2, 3])
             AutoCorrelogramsWidget(
                 we,
@@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             ax4.set_title(None)
             ax4.set_yticks([])

-        if we.is_extension("spike_amplitudes"):
+        if we.has_extension("spike_amplitudes"):
             ax5 = fig.add_subplot(gs[2, :3])
             ax6 = fig.add_subplot(gs[2, 3])
             axes = np.array([ax5, ax6])
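The widget guard above only raises when a required extension is missing, and plot_matplotlib() simply skips the optional panels otherwise. A usage sketch that pre-computes the optional extensions so the summary figure is complete, assuming `we` is an existing WaveformExtractor:

from spikeinterface.postprocessing import (
    compute_correlograms,
    compute_spike_amplitudes,
    compute_unit_locations,
)
from spikeinterface.widgets import plot_unit_summary

# Compute the optional extensions first so all panels of the summary are drawn.
compute_unit_locations(we)
compute_correlograms(we)
compute_spike_amplitudes(we)

plot_unit_summary(we, unit_id=we.unit_ids[0], backend="matplotlib")
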