-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add methods to sparsify and densify waveforms to ChannelSparsity
#1985
Changes from 15 commits
79490e2
75bf8e1
aa52484
99fb18d
f7e83e1
3f1a043
7085e77
de14163
dff698f
b100cff
5ffbb7f
f9640ee
1ec93b5
54d92a2
a7dc63a
646679f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,9 @@ | |
|
||
class ChannelSparsity: | ||
""" | ||
Handle channel sparsity for a set of units. | ||
Handle channel sparsity for a set of units. That is, for every unit, | ||
it indicates which channels are used to represent the waveform and the rest | ||
of the non-represented channels are assumed to be zero. | ||
|
||
Internally, sparsity is stored as a boolean mask. | ||
|
||
|
@@ -92,13 +94,17 @@ def __init__(self, mask, unit_ids, channel_ids): | |
assert self.mask.shape[0] == self.unit_ids.shape[0] | ||
assert self.mask.shape[1] == self.channel_ids.shape[0] | ||
|
||
# some precomputed dict | ||
# Those are computed at first call | ||
self._unit_id_to_channel_ids = None | ||
self._unit_id_to_channel_indices = None | ||
|
||
self.num_channels = self.channel_ids.size | ||
self.num_units = self.unit_ids.size | ||
self.max_num_active_channels = self.mask.sum(axis=1).max() | ||
|
||
def __repr__(self): | ||
ratio = np.mean(self.mask) | ||
txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}" | ||
density = np.mean(self.mask) | ||
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" | ||
return txt | ||
|
||
@property | ||
|
@@ -119,6 +125,84 @@ def unit_id_to_channel_indices(self): | |
self._unit_id_to_channel_indices[unit_id] = channel_inds | ||
return self._unit_id_to_channel_indices | ||
|
||
def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: "str | int") -> np.ndarray:
    """
    Sparsify dense waveforms according to the channel sparsity of one unit.

    Given a unit_id, this method keeps only the channels that are active for
    that unit and removes the rest.

    Parameters
    ----------
    waveforms : np.array
        Dense waveforms with shape (num_units, num_samples, num_channels).
        Only the last axis is inspected and indexed; it must have size
        ``self.num_channels``. Leading dimensions are passed through as-is.
    unit_id : str or int
        The unit_id for which to sparsify the waveforms. It is used as a key
        into ``self.unit_id_to_channel_indices``.

    Returns
    -------
    sparsified_waveforms : np.array
        Sparse waveforms with shape (num_units, num_samples, num_active_channels),
        where num_active_channels is the number of channels that are active
        for this unit, i.e. the number of non-zero elements in the mask for
        this unit.

    Raises
    ------
    AssertionError
        If the last dimension of ``waveforms`` differs from ``self.num_channels``.
    KeyError
        If ``unit_id`` is not a key of ``self.unit_id_to_channel_indices``.
    """
    assert_msg = (
        "Waveforms must be dense to sparsify them. "
        f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
    )
    assert self.are_waveforms_dense(waveforms=waveforms), assert_msg

    # Indexing the last axis with the unit's active channel indices keeps
    # every leading dimension untouched.
    non_zero_indices = self.unit_id_to_channel_indices[unit_id]
    sparsified_waveforms = waveforms[..., non_zero_indices]

    return sparsified_waveforms
|
||
def densify_waveforms(self, waveforms: np.ndarray, unit_id: "str | int") -> np.ndarray:
    """
    Densify sparse waveforms that were sparsified according to a unit's channel sparsity.

    Given a unit_id and its sparsified waveforms, this method places the
    waveforms back into their original positions within a dense array. All
    channels that are not active for the unit are filled with zeros.

    Parameters
    ----------
    waveforms : np.array
        The sparsified waveforms array of shape (num_units, num_samples, num_active_channels).
        Only the last axis is inspected, so any leading dimensions are kept;
        a single 2D waveform such as a template (a "mean" waveform) is
        handled the same way.
    unit_id : str or int
        The unit_id that was used to sparsify the waveforms. It is used as a
        key into ``self.unit_id_to_channel_indices``.

    Returns
    -------
    densified_waveforms : np.array
        The densified waveforms array of shape (num_units, num_samples, num_channels),
        with the same dtype as ``waveforms``.

    Raises
    ------
    AssertionError
        If the last dimension of ``waveforms`` differs from the number of
        active channels of ``unit_id``.
    KeyError
        If ``unit_id`` is not a key of ``self.unit_id_to_channel_indices``.
    """
    non_zero_indices = self.unit_id_to_channel_indices[unit_id]

    # The message is built before the check on purpose: it reads better and
    # the formatting cost is negligible.
    # NOTE: ``assert`` statements are stripped under ``python -O``, so this
    # check is a safety net rather than hard input validation.
    assert_msg = (
        "Waveforms do not seem to be in the sparsity shape of this unit_id. The number of active channels is "
        f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
    )
    assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg

    # Start from zeros so that non-represented channels stay at zero, then
    # scatter the sparse values back into their channel positions.
    densified_shape = waveforms.shape[:-1] + (self.num_channels,)
    densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
    densified_waveforms[..., non_zero_indices] = waveforms

    return densified_waveforms
|
||
def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
    """
    Check whether waveforms have a dense channel axis.

    Only the size of the last axis is compared with ``self.num_channels``;
    the values of ``waveforms`` are not inspected.

    Parameters
    ----------
    waveforms : np.array
        Waveforms whose last axis is the channel axis.

    Returns
    -------
    is_dense : bool
        True if the last dimension equals ``self.num_channels``.
    """
    return waveforms.shape[-1] == self.num_channels
|
||
def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: "str | int") -> bool:
    """
    Check whether waveforms have the sparse channel axis of a given unit.

    Only the size of the last axis is compared with the number of active
    channels of ``unit_id``; the values of ``waveforms`` are not inspected.

    Parameters
    ----------
    waveforms : np.array
        Waveforms whose last axis is the channel axis.
    unit_id : str or int
        The unit whose sparsity is used. It is used as a key into
        ``self.unit_id_to_channel_indices``.

    Returns
    -------
    is_sparse : bool
        True if the last dimension equals the number of active channels of
        ``unit_id``.
    """
    non_zero_indices = self.unit_id_to_channel_indices[unit_id]
    num_active_channels = len(non_zero_indices)
    return waveforms.shape[-1] == num_active_channels
|
||
@classmethod | ||
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): | ||
""" | ||
|
@@ -144,16 +228,16 @@ def to_dict(self): | |
) | ||
|
||
@classmethod | ||
def from_dict(cls, d): | ||
def from_dict(cls, dictionary: dict): | ||
unit_id_to_channel_ids_corrected = {} | ||
for unit_id in d["unit_ids"]: | ||
if unit_id in d["unit_id_to_channel_ids"]: | ||
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id] | ||
for unit_id in dictionary["unit_ids"]: | ||
if unit_id in dictionary["unit_id_to_channel_ids"]: | ||
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id] | ||
else: | ||
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)] | ||
d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected | ||
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)] | ||
dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected | ||
|
||
return cls.from_unit_id_to_channel_ids(**d) | ||
return cls.from_unit_id_to_channel_ids(**dictionary) | ||
|
||
## Some convenient functions to compute sparsity from several strategies | ||
@classmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unit_id is not always a str; it can also be an int.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could also be useful to have a unit_index entry sometimes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I forgot to do this. In fact, mea culpa, I kind of knew it when I did it but I ... wished that unit_ids were only strings and intentionally forgot : P
I will push a small fix for this.
I would not like to have a method with two different types of input, because then 1/4 of the code is used to validate the input with if cases.