Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods to sparsify and densify waveforms to ChannelSparsity #1985

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 95 additions & 11 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

class ChannelSparsity:
"""
Handle channel sparsity for a set of units.
Handle channel sparsity for a set of units. That is, for every unit,
it indicates which channels are used to represent the waveform and the rest
of the non-represented channels are assumed to be zero.

Internally, sparsity is stored as a boolean mask.

Expand Down Expand Up @@ -92,13 +94,17 @@ def __init__(self, mask, unit_ids, channel_ids):
assert self.mask.shape[0] == self.unit_ids.shape[0]
assert self.mask.shape[1] == self.channel_ids.shape[0]

# some precomputed dict
# Those are computed at first call
self._unit_id_to_channel_ids = None
self._unit_id_to_channel_indices = None

self.num_channels = self.channel_ids.size
self.num_units = self.unit_ids.size
self.max_num_active_channels = self.mask.sum(axis=1).max()

def __repr__(self):
ratio = np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}"
density = np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
return txt

@property
Expand All @@ -119,6 +125,84 @@ def unit_id_to_channel_indices(self):
self._unit_id_to_channel_indices[unit_id] = channel_inds
return self._unit_id_to_channel_indices

def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unit_id is not always a str can be int

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be usefull also to have unit_index entry somtimes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I forgot to do this. In fact, mea culpa, I kind of new it when I did it but I ... wished that unit_ids were only string and intentionally forgot : P

I will push a small fix for this.

I would not like to do have a method with two different type of inputs, then 1/4 of the code is used to validate the input with if cases.

"""
Sparsify the waveforms according to a unit_id corresponding sparsity.


Given a unit_id, this method selects only the active channels for
that unit and removes the rest.

Parameters
----------
waveforms : np.array
Dense waveforms with shape (num_units, num_samples, num_channels).
unit_id : str
The unit_id for which to sparsify the waveform.

Returns
-------
sparsified_waveforms : np.array
Sparse waveforms with shape (num_units, num_samples, num_active_channels).

Where num_active_channels is the number of channels that are active for this unit and should be
equal to the number of non-zero elements in the mask for this unit.
"""

assert_msg = (
"Waveforms must be dense to sparsify them. "
f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
)
assert self.are_waveforms_dense(waveforms=waveforms), assert_msg

non_zero_indices = self.unit_id_to_channel_indices[unit_id]
sparsified_waveforms = waveforms[..., non_zero_indices]

return sparsified_waveforms

def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

"""
Densify sparse waveforms that were sparisified according to a unit's channel sparsity.

Given a unit_id its sparsified waveform, this method places the waveform back
into its original form within a dense array.

Parameters
----------
waveforms : np.array
The sparsified waveforms array of shape (num_units, num_samples, num_active_channels).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is waveforms_or_template, right? It can be 3d (waveforms) or 2d (templates) (as in the tests). Maybe we should change the docstring and variable names accordingly. What do you think @h-mayorquin ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good suggesiton. I think on templates as central / canonical / representative forms of waveforms so I would like to keep the name of the variable waveforms as it is the general type. But I think that the docstring should clarify this and make it explicit.

Copy link
Collaborator Author

@h-mayorquin h-mayorquin Sep 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made some changes to the docstring. Let me know what you think.

If you think that templates should be a variable name it would still prefer to have instead a method called sparsify_template and densify_template that take care of the 2 dimensional case and separate the functions. Maybe that's even better ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok for now! I agree that templates are "mean" waveforms. Thanks for clarifying the docstring!

unit_id : str
The unit_id that was used to sparsify the waveform.

Returns
-------
densified_waveforms : np.array
The densified waveforms array of shape (num_units, num_samples, num_channels).

"""

non_zero_indices = self.unit_id_to_channel_indices[unit_id]

assert_msg = (
"Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is "
f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
)
assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg
Comment on lines +187 to +191
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very little detail as a genral partice: here you format the message even before you made the assert.
But if the assert pass there is no need to make the text formatting. And so putting the message after the assert ..., is better no ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do feel that waste of computation but I prefer this for readability. This should not have any really impact on performance anywhere.

As an aside to your aside, it is usually considered a bad practice to use asserts the way that we do here (but most people, doing this is very popular!) because you can dissable assertions in optimize code:

https://discuss.python.org/t/mismatch-between-asserts-semantics-and-how-its-used-o-oo-disable/29282?page=2

The recommended way to do this would be to do and if statement and then raise the appropiate exception but I think that most people does not like this? What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a science vs industry thing. Every piece of science code I've ever read uses assert instead of if-exception. I think with the assumption that scientists aren't going to shut off assertions to increase speed. And using the asserts does empower the user such that if they don't want the protections designed by the developer they can shut them off to get the speed boost. But just my two cents.


densified_shape = waveforms.shape[:-1] + (self.num_channels,)
densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
densified_waveforms[..., non_zero_indices] = waveforms

return densified_waveforms

def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
return waveforms.shape[-1] == self.num_channels

def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) -> bool:
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
return waveforms.shape[-1] == num_active_channels

@classmethod
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
"""
Expand All @@ -144,16 +228,16 @@ def to_dict(self):
)

@classmethod
def from_dict(cls, d):
def from_dict(cls, dictionary: dict):
unit_id_to_channel_ids_corrected = {}
for unit_id in d["unit_ids"]:
if unit_id in d["unit_id_to_channel_ids"]:
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id]
for unit_id in dictionary["unit_ids"]:
if unit_id in dictionary["unit_id_to_channel_ids"]:
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id]
else:
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)]
d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)]
dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected

return cls.from_unit_id_to_channel_ids(**d)
return cls.from_unit_id_to_channel_ids(**dictionary)

## Some convinient function to compute sparsity from several strategy
@classmethod
Expand Down
88 changes: 88 additions & 0 deletions src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,93 @@ def test_ChannelSparsity():
assert np.array_equal(sparsity.mask, sparsity4.mask)


def test_sparsify_waveforms():
seed = 0
rng = np.random.default_rng(seed=seed)

num_units = 3
num_samples = 5
num_channels = 4

is_mask_valid = False
while not is_mask_valid:
sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool")
is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0)

unit_ids = np.arange(num_units)
channel_ids = np.arange(num_channels)
sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids)

for unit_id in unit_ids:
waveforms_dense = rng.random(size=(num_units, num_samples, num_channels))

# Test are_waveforms_dense
assert sparsity.are_waveforms_dense(waveforms_dense)

# Test sparsify
waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id)
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels)

# Test round-trip (note that this is loosy)
unit_id = unit_ids[unit_id]
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id)
assert np.array_equal(waveforms_dense[..., non_zero_indices], waveforms_dense2[..., non_zero_indices])

# Test sparsify with one waveform (template)
template_dense = waveforms_dense.mean(axis=0)
template_sparse = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id)
assert template_sparse.shape == (num_samples, num_active_channels)

# Test round trip with template
template_dense2 = sparsity.densify_waveforms(template_sparse, unit_id=unit_id)
assert np.array_equal(template_dense[..., non_zero_indices], template_dense2[:, non_zero_indices])


def test_densify_waveforms():
seed = 0
rng = np.random.default_rng(seed=seed)

num_units = 3
num_samples = 5
num_channels = 4

is_mask_valid = False
while not is_mask_valid:
sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool")
is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0)

unit_ids = np.arange(num_units)
channel_ids = np.arange(num_channels)
sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids)

for unit_id in unit_ids:
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
waveforms_sparse = rng.random(size=(num_units, num_samples, num_active_channels))

# Test are waveforms sparse
assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_id=unit_id)

# Test densify
waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id)
assert waveforms_dense.shape == (num_units, num_samples, num_channels)

# Test round-trip
waveforms_sparse2 = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id)
assert np.array_equal(waveforms_sparse, waveforms_sparse2)

# Test densify with one waveform (template)
template_sparse = waveforms_sparse.mean(axis=0)
template_dense = sparsity.densify_waveforms(template_sparse, unit_id=unit_id)
assert template_dense.shape == (num_samples, num_channels)

# Test round trip with template
template_sparse2 = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id)
assert np.array_equal(template_sparse, template_sparse2)


if __name__ == "__main__":
test_ChannelSparsity()