From 12fd197859a3bb91099e9f5fb73fc5f74f923847 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 12:56:55 +0200 Subject: [PATCH 1/6] Use sparsity mask and handle right border correctly --- .../postprocessing/amplitude_scalings.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..4dab68fdf8 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -90,10 +90,7 @@ def _run(self, **job_kwargs): if self._params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) - sparsity_inds = sparsity.unit_id_to_channel_indices - - # easier to use in chunk function as spikes use unit_index instead o id - unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} + sparsity_mask = sparsity.mask all_templates = we.get_all_templates() # precompute segment slice @@ -113,7 +110,7 @@ def _run(self, **job_kwargs): self.spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -262,7 +259,7 @@ def _init_worker_amplitude_scalings( spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -282,7 +279,7 @@ def _init_worker_amplitude_scalings( worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after worker_ctx["return_scaled"] = return_scaled - worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["handle_collisions"] = handle_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples @@ -306,7 +303,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) recording = worker_ctx["recording"] all_templates = worker_ctx["all_templates"] segment_slices = worker_ctx["segment_slices"] - unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"] + sparsity_mask = worker_ctx["sparsity_mask"] nbefore = worker_ctx["nbefore"] cut_out_before = worker_ctx["cut_out_before"] cut_out_after = worker_ctx["cut_out_after"] @@ -339,7 +336,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask ) else: collisions_local = {} @@ -354,7 +351,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = unit_inds_to_channel_indices[unit_index] + sparse_indices = sparsity_mask[unit_index] template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -393,7 +390,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, 
cut_out_before, cut_out_after, ) @@ -410,14 +407,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ### Collision handling ### -def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): +def _are_unit_indices_overlapping(sparsity_mask, i, j): """ Returns True if the unit indices i and j are overlapping, False otherwise Parameters ---------- - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask i: int The first unit index j: int @@ -428,13 +425,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: return True else: return False -def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask): """ Finds the collisions between spikes. @@ -446,8 +443,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ An array of spikes within the added margin delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask Returns ------- @@ -480,7 +477,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the overlapping spikes in space as well for possible_overlapping_spike_index in possible_overlapping_spike_indices: if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, + sparsity_mask, spike["unit_index"], spikes_w_margin[possible_overlapping_spike_index]["unit_index"], ): @@ -501,7 +498,7 @@ def fit_collision( right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ): @@ -528,8 +525,8 @@ def fit_collision( The number of samples before the spike to consider for the fit. all_templates: np.ndarray A numpy array of shape (n_units, n_samples, n_channels) containing the templates. - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices. + sparsity_mask: boolean mask + The sparsity mask cut_out_before: int The number of samples to cut out before the spike. 
cut_out_after: int @@ -547,14 +544,15 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.array([], dtype="int") + sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + sparse_indices_i = sparsity_mask[spike["unit_index"]] + sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + num_samples_local_waveform = local_waveform.shape[0] y = local_waveform.T.flatten() X = np.zeros((len(y), len(collision))) @@ -567,8 +565,10 @@ def fit_collision( # deal with borders if sample_centered - cut_out_before < 0: full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] - elif sample_centered + cut_out_after > end_frame + right: - full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + elif sample_centered + cut_out_after > num_samples_local_waveform: + full_template[sample_centered - cut_out_before :] = template_cut[ + : -(cut_out_after + sample_centered - num_samples_local_waveform) + ] else: full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() From c46a7cba4b1e937d40050d0061017256ab5dade3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 10:31:05 +0200 Subject: [PATCH 2/6] Allow to restrict sparsity --- .../postprocessing/amplitude_scalings.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 4dab68fdf8..3eac333781 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -68,7 +68,6 @@ def _run(self, **job_kwargs): delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) return_scaled = we._params["return_scaled"] - unit_ids = we.unit_ids if ms_before is not None: assert ( @@ -82,9 +81,16 @@ def _run(self, **job_kwargs): cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter - if we.is_sparse(): + if we.is_sparse() and self._params["sparsity"] is None: sparsity = we.sparsity - elif self._params["sparsity"] is not None: + elif we.is_sparse() and self._params["sparsity"] is not None: + sparsity = self._params["sparsity"] + # assert provided sparsity is sparser than the one in the waveform extractor + waveform_sparsity = we.sparsity + assert np.all( + np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 + ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" 
+ elif not we.is_sparse() and self._params["sparsity"] is not None: sparsity = self._params["sparsity"] else: if self._params["max_dense_channels"] is not None: @@ -362,7 +368,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame)] + template = template[: -(sample_index + cut_out_after - end_frame - right)] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape From 4e31329d9aed376ecc41c4238a2f4836f94054ea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 11:37:18 +0200 Subject: [PATCH 3/6] Add spikes on border when generating sorting, PCA sparse return fixes --- src/spikeinterface/core/generate.py | 28 +++++++++++++++++ .../core/tests/test_generate.py | 30 +++++++++++++++++-- .../postprocessing/amplitude_scalings.py | 12 ++++---- .../postprocessing/principal_component.py | 15 ++++++++-- .../tests/common_extension_tests.py | 26 ++++++++++++++-- .../tests/test_principal_component.py | 12 ++++---- 6 files changed, 104 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..741dd20000 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -123,6 +123,9 @@ def generate_sorting( firing_rates=3.0, empty_units=None, refractory_period_ms=3.0, # in ms + add_spikes_on_borders=False, + num_spikes_per_border=3, + border_size_samples=20, seed=None, ): """ @@ -142,6 +145,12 @@ def generate_sorting( List of units that will have no spikes. (used for testing mainly). refractory_period_ms : float, default: 3.0 The refractory period in ms + add_spikes_on_borders : bool, default: False + If True, spikes will be added close to the borders of the segments. + num_spikes_per_border : int, default: 3 + The number of spikes to add close to the borders of the segments. + border_size_samples : int, default: 20 + The size of the border in samples to add border spikes. 
seed : int, default: None The random seed @@ -151,11 +160,13 @@ def generate_sorting( The sorting object """ seed = _ensure_seed(seed) + rng = np.random.default_rng(seed) num_segments = len(durations) unit_ids = np.arange(num_units) spikes = [] for segment_index in range(num_segments): + num_samples = int(sampling_frequency * durations[segment_index]) times, labels = synthesize_random_firings( num_units=num_units, sampling_frequency=sampling_frequency, @@ -175,7 +186,23 @@ def generate_sorting( spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index spikes.append(spikes_in_seg) + + if add_spikes_on_borders: + spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) + spikes_on_borders["segment_index"] = segment_index + spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) + # at start + spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( + 0, border_size_samples, num_spikes_per_border + ) + # at end + spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( + num_samples - border_size_samples, num_samples, num_spikes_per_border + ) + spikes.append(spikes_on_borders) + spikes = np.concatenate(spikes) + spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -596,6 +623,7 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") + assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9ba5de42d6..3844e421ac 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -26,15 +26,38 @@ def test_generate_recording(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass def test_generate_sorting(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass +def test_generate_sorting_with_spikes_on_borders(): + num_spikes_on_borders = 10 + border_size_samples = 10 + segment_duration = 10 + for nseg in [1, 2, 3]: + sorting = generate_sorting( + durations=[segment_duration] * nseg, + sampling_frequency=30000, + num_units=10, + add_spikes_on_borders=True, + num_spikes_per_border=num_spikes_on_borders, + border_size_samples=border_size_samples, + ) + spikes = sorting.to_spike_vector(concatenated=False) + # at least num_border spikes at borders for all segments + for i, spikes_in_segment in enumerate(spikes): + num_samples = int(segment_duration * 30000) + assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders + assert ( + np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders + ) + + def measure_memory_allocation(measure_in_process: bool = True) -> float: """ A local utility to measure memory allocation at a specific point in time. 
@@ -399,7 +422,7 @@ def test_generate_ground_truth_recording(): if __name__ == "__main__": strategy = "tile_pregenerated" # strategy = "on_the_fly" - test_noise_generator_memory() + # test_noise_generator_memory() # test_noise_generator_under_giga() # test_noise_generator_correct_shape(strategy) # test_noise_generator_consistency_across_calls(strategy, 0, 5) @@ -410,3 +433,4 @@ def test_generate_ground_truth_recording(): # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3eac333781..c86337a30d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -16,6 +16,7 @@ class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ extension_name = "amplitude_scalings" + handle_sparsity = True def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) @@ -357,7 +358,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = sparsity_mask[unit_index] + (sparse_indices,) = np.nonzero(sparsity_mask[unit_index]) template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -368,7 +369,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame - right)] + template = template[: -(sample_index + cut_out_after - (end_frame + right))] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape @@ -550,10 +551,11 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") + common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = sparsity_mask[spike["unit_index"]] - sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) + mask_i = sparsity_mask[spike["unit_index"]] + common_sparse_mask = np.logical_or(common_sparse_mask, mask_i) + (sparse_indices,) = np.nonzero(common_sparse_mask) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..1214b84ac4 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -84,9 +84,16 @@ def get_projections(self, unit_id): Returns ------- proj: np.array - The PCA projections (num_waveforms, num_components, num_channels) + The PCA projections (num_waveforms, num_components, num_channels). + In case sparsity is used, only the projections on sparse channels are returned. 
""" - return self._extension_data[f"pca_{unit_id}"] + projections = self._extension_data[f"pca_{unit_id}"] + mode = self._params["mode"] + if mode in ("by_channel_local", "by_channel_global"): + sparsity = self.get_sparsity() + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] + return projections def get_pca_model(self): """ @@ -211,6 +218,10 @@ def project_new(self, new_waveforms, unit_id=None): wfs_flat = new_waveforms.reshape(new_waveforms.shape[0], -1) projections = pca_model.transform(wfs_flat) + # take care of sparsity (not in case of concatenated) + if mode in ("by_channel_local", "by_channel_global"): + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections def get_sparsity(self): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..8657d1dced 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,7 +5,7 @@ from pathlib import Path from spikeinterface import extract_waveforms, load_extractor, compute_sparsity -from spikeinterface.extractors import toy_example +from spikeinterface.core.generate import generate_ground_truth_recording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "postprocessing" @@ -26,7 +26,18 @@ def setUp(self): self.cache_folder = cache_folder # 1-segment - recording, sorting = toy_example(num_segments=1, num_units=10, num_channels=12) + recording, sorting = generate_ground_truth_recording( + durations=[10], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) + + # add gains and offsets and save gain = 0.1 recording.set_channel_gains(gain) recording.set_channel_offsets(0) @@ -53,7 +64,16 @@ def setUp(self): self.sparsity1 = compute_sparsity(we1, method="radius", radius_um=50) # 2-segments - recording, sorting = toy_example(num_segments=2, num_units=10) + recording, sorting = generate_ground_truth_recording( + durations=[10, 5], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) recording.set_channel_gains(gain) recording.set_channel_offsets(0) if (cache_folder / "toy_rec_2seg").is_dir(): diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 5d64525b52..04ce42b70e 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -87,13 +87,13 @@ def test_sparse(self): pc.run() for i, unit_id in enumerate(unit_ids): proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, 4) + assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, 4) + assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) if DEBUG: import matplotlib.pyplot as 
plt @@ -197,8 +197,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - test.test_extension() - test.test_shapes() - test.test_compute_for_all_spikes() + # test.test_extension() + # test.test_shapes() + # test.test_compute_for_all_spikes() test.test_sparse() - test.test_project_new() + # test.test_project_new() From 73ceaacefecc4426d994ebca4ca006d667dada42 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:06:15 +0200 Subject: [PATCH 4/6] Extend PCA to be able to return sparse projections and fix tests --- .../postprocessing/principal_component.py | 16 ++++++++++------ .../tests/test_principal_component.py | 12 ++++++++---- .../tests/test_quality_metric_calculator.py | 7 ++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 5d62216c20..8383dcbb43 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -72,7 +72,7 @@ def _select_extension_data(self, unit_ids): new_extension_data[k] = v return new_extension_data - def get_projections(self, unit_id): + def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -80,16 +80,18 @@ def get_projections(self, unit_id): ---------- unit_id : int or str The unit id to return PCA projections for + sparse: bool, default False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - proj: np.array + projections: np.array The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. """ projections = self._extension_data[f"pca_{unit_id}"] mode = self._params["mode"] - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] @@ -141,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id) + proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] @@ -158,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): return all_labels, all_projections - def project_new(self, new_waveforms, unit_id=None): + def project_new(self, new_waveforms, unit_id=None, sparse=False): """ Projects new waveforms or traces snippets on the PC components. @@ -168,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None): Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + sparse: bool, default: False + If True, and sparsity is not None, only projections on sparse channels are returned. 
Returns ------- @@ -219,7 +223,7 @@ def project_new(self, new_waveforms, unit_id=None): projections = pca_model.transform(wfs_flat) # take care of sparsity (not in case of concatenated) - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 04ce42b70e..49591d9b89 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -86,14 +86,18 @@ def test_sparse(self): pc.set_params(n_components=5, mode=mode, sparsity=sparsity) pc.run() for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_sparse = pc.get_projections(unit_id, sparse=True) + assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_dense = pc.get_projections(unit_id, sparse=False) + assert proj_dense.shape[1:] == (5, num_channels) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) if DEBUG: import matplotlib.pyplot as plt diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4fa65993d1..977beca210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -261,7 +261,8 @@ def test_nn_metrics(self): we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 ) for metric_name in metrics.columns: - assert np.allclose(metrics[metric_name], metrics_par[metric_name]) + # NaNs are skipped + assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) def test_recordingless(self): we = self.we_long @@ -305,7 +306,7 @@ def test_empty_units(self): test.setUp() # test.test_drift_metrics() # test.test_extension() - # test.test_nn_metrics() + test.test_nn_metrics() # test.test_peak_sign() # test.test_empty_units() - test.test_recordingless() + # test.test_recordingless() From b9b6c15b42a64d877ea9fad9fca84424e2c97edf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:12:21 +0200 Subject: [PATCH 5/6] Add test to check correct order of spikes with borders --- src/spikeinterface/core/tests/test_generate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 3844e421ac..9a9c61766f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -48,9 +48,15 @@ def test_generate_sorting_with_spikes_on_borders(): 
num_spikes_per_border=num_spikes_on_borders, border_size_samples=border_size_samples, ) + # check that segments are correctly sorted + all_spikes = sorting.to_spike_vector() + np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) + spikes = sorting.to_spike_vector(concatenated=False) # at least num_border spikes at borders for all segments - for i, spikes_in_segment in enumerate(spikes): + for spikes_in_segment in spikes: + # check that sample indices are correctly sorted within segments + np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) num_samples = int(segment_duration * 30000) assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( From 8e4b43a4f67a92a1497eda5d53f2be2e04f7779f Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 27 Sep 2023 11:37:12 +0200 Subject: [PATCH 6/6] Update src/spikeinterface/postprocessing/amplitude_scalings.py --- src/spikeinterface/postprocessing/amplitude_scalings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8823fd6257..7e6c95a875 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -431,7 +431,7 @@ def _are_unit_indices_overlapping(sparsity_mask, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: + if np.any(sparsity_mask[i] & sparsity_mask[j]): return True else: return False
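
The series converges on a boolean sparsity mask because overlap tests, channel unions, and index recovery all become single vectorized calls. A minimal numpy-only sketch of the three operations used above; the 3 x 5 mask is invented for illustration and is not SpikeInterface API:

    import numpy as np

    # Hypothetical mask: sparsity_mask[u, c] is True when unit u is active on channel c.
    sparsity_mask = np.array(
        [
            [True,  True,  False, False, False],   # unit 0 -> channels {0, 1}
            [False, True,  True,  False, False],   # unit 1 -> channels {1, 2}
            [False, False, False, True,  True],    # unit 2 -> channels {3, 4}
        ]
    )

    # Overlap test, as in _are_unit_indices_overlapping after PATCH 6/6:
    assert np.any(sparsity_mask[0] & sparsity_mask[1])      # units 0 and 1 share channel 1
    assert not np.any(sparsity_mask[0] & sparsity_mask[2])  # units 0 and 2 are disjoint

    # Union of colliding units' channels, as in fit_collision:
    common_mask = np.logical_or(sparsity_mask[0], sparsity_mask[1])

    # Recover channel indices for slicing traces/templates, as in PATCH 3/6:
    (sparse_indices,) = np.nonzero(common_mask)
    assert list(sparse_indices) == [0, 1, 2]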
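
The right-border fix in PATCH 1/6 swaps the global bound end_frame + right for the local waveform length, so the template tail is trimmed by exactly the number of samples that overhang the chunk. A worked example with invented toy numbers:

    # Spike near the right edge of the local waveform:
    num_samples_local_waveform = 100      # rows available in local_waveform
    sample_centered = 95                  # spike position inside the chunk
    cut_out_before, cut_out_after = 20, 10

    # 95 + 10 > 100, so the cut-out overhangs the chunk by 5 samples:
    overhang = cut_out_after + sample_centered - num_samples_local_waveform
    assert overhang == 5

    # template_cut[:-overhang] keeps 20 + 10 - 5 = 25 samples, which exactly
    # fills full_template[sample_centered - cut_out_before :] (rows 75..99
    # of a 100-row full_template).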
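
The border-spike options from PATCH 3/6 can be exercised as below; the parameter values mirror the added test, while the import path is an assumption rather than something shown in the patches:

    import numpy as np
    from spikeinterface.core.generate import generate_sorting  # import path assumed

    sorting = generate_sorting(
        durations=[10.0, 5.0],          # two segments
        sampling_frequency=30000,
        num_units=10,
        add_spikes_on_borders=True,     # new option
        num_spikes_per_border=3,
        border_size_samples=20,
        seed=0,
    )

    # The lexsort added in PATCH 3/6 keeps the spike vector ordered by
    # (segment_index, sample_index) despite the injected border spikes,
    # which is what the test added in PATCH 5/6 checks:
    spikes = sorting.to_spike_vector()
    assert np.array_equal(spikes["segment_index"], np.sort(spikes["segment_index"]))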
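
After PATCH 4/6, dense versus sparse projection layouts are an explicit caller choice. A sketch of the resulting shapes, assuming a waveform extractor we, a unit_id, waveforms new_wfs, and a ChannelSparsity sparsity as in the tests, and assuming compute_principal_components forwards n_components, mode, and sparsity the way the tests' set_params call does:

    from spikeinterface.postprocessing import compute_principal_components  # assumed entry point

    pc = compute_principal_components(we, n_components=5, mode="by_channel_local", sparsity=sparsity)

    proj_dense = pc.get_projections(unit_id, sparse=False)
    # proj_dense.shape == (num_waveforms, 5, num_channels)

    proj_sparse = pc.get_projections(unit_id, sparse=True)
    # proj_sparse.shape == (num_waveforms, 5, num_sparse_channels_for_this_unit)

    new_proj = pc.project_new(new_wfs, unit_id=unit_id, sparse=True)
    # same trailing shape as proj_sparse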