From 73ceaacefecc4426d994ebca4ca006d667dada42 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:06:15 +0200 Subject: [PATCH] Extend PCA to be able to return sparse projections and fix tests --- .../postprocessing/principal_component.py | 16 ++++++++++------ .../tests/test_principal_component.py | 12 ++++++++---- .../tests/test_quality_metric_calculator.py | 7 ++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 5d62216c20..8383dcbb43 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -72,7 +72,7 @@ def _select_extension_data(self, unit_ids): new_extension_data[k] = v return new_extension_data - def get_projections(self, unit_id): + def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -80,16 +80,18 @@ def get_projections(self, unit_id): ---------- unit_id : int or str The unit id to return PCA projections for + sparse: bool, default False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - proj: np.array + projections: np.array The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. """ projections = self._extension_data[f"pca_{unit_id}"] mode = self._params["mode"] - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] @@ -141,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id) + proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] @@ -158,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): return all_labels, all_projections - def project_new(self, new_waveforms, unit_id=None): + def project_new(self, new_waveforms, unit_id=None, sparse=False): """ Projects new waveforms or traces snippets on the PC components. @@ -168,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None): Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + sparse: bool, default: False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- @@ -219,7 +223,7 @@ def project_new(self, new_waveforms, unit_id=None): projections = pca_model.transform(wfs_flat) # take care of sparsity (not in case of concatenated) - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 04ce42b70e..49591d9b89 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -86,14 +86,18 @@ def test_sparse(self): pc.set_params(n_components=5, mode=mode, sparsity=sparsity) pc.run() for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_sparse = pc.get_projections(unit_id, sparse=True) + assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_dense = pc.get_projections(unit_id, sparse=False) + assert proj_dense.shape[1:] == (5, num_channels) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) if DEBUG: import matplotlib.pyplot as plt diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4fa65993d1..977beca210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -261,7 +261,8 @@ def test_nn_metrics(self): we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 ) for metric_name in metrics.columns: - assert np.allclose(metrics[metric_name], metrics_par[metric_name]) + # NaNs are skipped + assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) def test_recordingless(self): we = self.we_long @@ -305,7 +306,7 @@ def test_empty_units(self): test.setUp() # test.test_drift_metrics() # test.test_extension() - # test.test_nn_metrics() + test.test_nn_metrics() # test.test_peak_sign() # test.test_empty_units() - test.test_recordingless() + # test.test_recordingless()