Skip to content

Commit

Permalink
Extend PCA to be able to return sparse projections and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 22, 2023
1 parent 8aa7425 commit 73ceaac
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
16 changes: 10 additions & 6 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,26 @@ def _select_extension_data(self, unit_ids):
new_extension_data[k] = v
return new_extension_data

def get_projections(self, unit_id):
def get_projections(self, unit_id, sparse=False):
"""
Returns the computed projections for the sampled waveforms of a unit id.
Parameters
----------
unit_id : int or str
The unit id to return PCA projections for
sparse: bool, default False
If True, and sparsity is not None, only projections on sparse channels are returned.
Returns
-------
proj: np.array
projections: np.array
The PCA projections (num_waveforms, num_components, num_channels).
In case sparsity is used, only the projections on sparse channels are returned.
"""
projections = self._extension_data[f"pca_{unit_id}"]
mode = self._params["mode"]
if mode in ("by_channel_local", "by_channel_global"):
if mode in ("by_channel_local", "by_channel_global") and sparse:
sparsity = self.get_sparsity()
if sparsity is not None:
projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]]
Expand Down Expand Up @@ -141,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"):
all_labels = [] #  can be unit_id or unit_index
all_projections = []
for unit_index, unit_id in enumerate(unit_ids):
proj = self.get_projections(unit_id)
proj = self.get_projections(unit_id, sparse=False)
if channel_ids is not None:
chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids)
proj = proj[:, :, chan_inds]
Expand All @@ -158,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"):

return all_labels, all_projections

def project_new(self, new_waveforms, unit_id=None):
def project_new(self, new_waveforms, unit_id=None, sparse=False):
"""
Projects new waveforms or traces snippets on the PC components.
Expand All @@ -168,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None):
Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels)
unit_id: int or str
In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms'
sparse: bool, default: False
If True, and sparsity is not None, only projections on sparse channels are returned.
Returns
-------
Expand Down Expand Up @@ -219,7 +223,7 @@ def project_new(self, new_waveforms, unit_id=None):
projections = pca_model.transform(wfs_flat)

# take care of sparsity (not in case of concatenated)
if mode in ("by_channel_local", "by_channel_global"):
if mode in ("by_channel_local", "by_channel_global") and sparse:
if sparsity is not None:
projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]]
return projections
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,18 @@ def test_sparse(self):
pc.set_params(n_components=5, mode=mode, sparsity=sparsity)
pc.run()
for i, unit_id in enumerate(unit_ids):
proj = pc.get_projections(unit_id)
assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id]))
proj_sparse = pc.get_projections(unit_id, sparse=True)
assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id]))
proj_dense = pc.get_projections(unit_id, sparse=False)
assert proj_dense.shape[1:] == (5, num_channels)

# test project_new
unit_id = 3
new_wfs = we.get_waveforms(unit_id)
new_proj = pc.project_new(new_wfs, unit_id=unit_id)
assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id]))
new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True)
assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id]))
new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False)
assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels)

if DEBUG:
import matplotlib.pyplot as plt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def test_nn_metrics(self):
we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2
)
for metric_name in metrics.columns:
assert np.allclose(metrics[metric_name], metrics_par[metric_name])
# NaNs are skipped
assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna())

def test_recordingless(self):
we = self.we_long
Expand Down Expand Up @@ -305,7 +306,7 @@ def test_empty_units(self):
test.setUp()
# test.test_drift_metrics()
# test.test_extension()
# test.test_nn_metrics()
test.test_nn_metrics()
# test.test_peak_sign()
# test.test_empty_units()
test.test_recordingless()
# test.test_recordingless()

0 comments on commit 73ceaac

Please sign in to comment.