Skip to content

Commit

Permalink
Merge branch 'main' into add_matlab_documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 21, 2023
2 parents 9ba6fc6 + d78cb07 commit add9f98
Show file tree
Hide file tree
Showing 24 changed files with 793 additions and 925 deletions.
Binary file added doc/images/plot_traces_ephyviewer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
42 changes: 41 additions & 1 deletion doc/modules/widgets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi
* | :code:`sortingview`: web-based and interactive rendering using the `sortingview <https://github.com/magland/sortingview>`_
| and `FIGURL <https://github.com/flatironinstitute/figurl>`_ packages.
Version 0.100.0, also come with this new backend:
* | :code:`ephyviewer`: interactive Qt based using the `ephyviewer <https://ephyviewer.readthedocs.io/en/latest/>`_ package


Installing backends
-------------------
Expand Down Expand Up @@ -85,6 +88,28 @@ Finally, if you wish to set up another cloud provider, follow the instruction fr
`kachery-cloud <https://github.com/flatironinstitute/kachery-cloud>`_ package ("Using your own storage bucket").


ephyviewer
^^^^^^^^^^

This backend is Qt based with PyQt5, PyQt6 or PySide6 support. Qt is sometimes tedious to install.


For a pip-based installation, run:

.. code-block:: bash
pip install PySide6 ephyviewer
Anaconda users will have a better experience with this:

.. code-block:: bash
conda install pyqt=5
pip install ephyviewer
Usage
-----

Expand Down Expand Up @@ -215,6 +240,21 @@ For example, here is how to combine the timeseries and sorting summary generated
print(url)
ephyviewer
^^^^^^^^^^


The :code:`ephyviewer` backend is currently only available for the :py:func:`~spikeinterface.widgets.plot_traces()` function.


.. code-block:: python
plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True)
.. image:: ../images/plot_traces_ephyviewer.png



Available plotting functions
----------------------------
Expand All @@ -229,7 +269,7 @@ Available plotting functions
* :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`)
* :py:func:`~spikeinterface.widgets.plot_template_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_template_similarity` (backends: ::code:`matplotlib`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_timeseries` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`)
* :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`)
* :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`)
Expand Down
109 changes: 98 additions & 11 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

from .recording_tools import get_channel_distances, get_noise_levels
Expand Down Expand Up @@ -33,7 +35,9 @@

class ChannelSparsity:
"""
Handle channel sparsity for a set of units.
Handle channel sparsity for a set of units. That is, for every unit,
it indicates which channels are used to represent the waveform and the rest
of the non-represented channels are assumed to be zero.
Internally, sparsity is stored as a boolean mask.
Expand Down Expand Up @@ -92,13 +96,17 @@ def __init__(self, mask, unit_ids, channel_ids):
assert self.mask.shape[0] == self.unit_ids.shape[0]
assert self.mask.shape[1] == self.channel_ids.shape[0]

# some precomputed dict
# Those are computed at first call
self._unit_id_to_channel_ids = None
self._unit_id_to_channel_indices = None

self.num_channels = self.channel_ids.size
self.num_units = self.unit_ids.size
self.max_num_active_channels = self.mask.sum(axis=1).max()

def __repr__(self):
ratio = np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}"
density = np.mean(self.mask)
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
return txt

@property
Expand All @@ -119,6 +127,85 @@ def unit_id_to_channel_indices(self):
self._unit_id_to_channel_indices[unit_id] = channel_inds
return self._unit_id_to_channel_indices

def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
"""
Sparsify the waveforms according to a unit_id corresponding sparsity.
Given a unit_id, this method selects only the active channels for
that unit and removes the rest.
Parameters
----------
waveforms : np.array
Dense waveforms with shape (num_waveforms, num_samples, num_channels) or a
single dense waveform (template) with shape (num_samples, num_channels).
unit_id : str
The unit_id for which to sparsify the waveform.
Returns
-------
sparsified_waveforms : np.array
Sparse waveforms with shape (num_waveforms, num_samples, num_active_channels)
or a single sparsified waveform (template) with shape (num_samples, num_active_channels).
"""

assert_msg = (
"Waveforms must be dense to sparsify them. "
f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
)
assert self.are_waveforms_dense(waveforms=waveforms), assert_msg

non_zero_indices = self.unit_id_to_channel_indices[unit_id]
sparsified_waveforms = waveforms[..., non_zero_indices]

return sparsified_waveforms

def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
"""
Densify sparse waveforms that were sparisified according to a unit's channel sparsity.
Given a unit_id its sparsified waveform, this method places the waveform back
into its original form within a dense array.
Parameters
----------
waveforms : np.array
The sparsified waveforms array of shape (num_waveforms, num_samples, num_active_channels) or a single
sparse waveform (template) with shape (num_samples, num_active_channels).
unit_id : str
The unit_id that was used to sparsify the waveform.
Returns
-------
densified_waveforms : np.array
The densified waveforms array of shape (num_waveforms, num_samples, num_channels) or a single dense
waveform (template) with shape (num_samples, num_channels).
"""

non_zero_indices = self.unit_id_to_channel_indices[unit_id]

assert_msg = (
"Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is "
f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
)
assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg

densified_shape = waveforms.shape[:-1] + (self.num_channels,)
densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
densified_waveforms[..., non_zero_indices] = waveforms

return densified_waveforms

def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
return waveforms.shape[-1] == self.num_channels

def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool:
non_zero_indices = self.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
return waveforms.shape[-1] == num_active_channels

@classmethod
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
"""
Expand All @@ -144,16 +231,16 @@ def to_dict(self):
)

@classmethod
def from_dict(cls, d):
def from_dict(cls, dictionary: dict):
unit_id_to_channel_ids_corrected = {}
for unit_id in d["unit_ids"]:
if unit_id in d["unit_id_to_channel_ids"]:
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id]
for unit_id in dictionary["unit_ids"]:
if unit_id in dictionary["unit_id_to_channel_ids"]:
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id]
else:
unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)]
d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected
unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)]
dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected

return cls.from_unit_id_to_channel_ids(**d)
return cls.from_unit_id_to_channel_ids(**dictionary)

## Some convinient function to compute sparsity from several strategy
@classmethod
Expand Down
88 changes: 88 additions & 0 deletions src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,93 @@ def test_ChannelSparsity():
assert np.array_equal(sparsity.mask, sparsity4.mask)


def test_sparsify_waveforms():
seed = 0
rng = np.random.default_rng(seed=seed)

num_units = 3
num_samples = 5
num_channels = 4

is_mask_valid = False
while not is_mask_valid:
sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool")
is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0)

unit_ids = np.arange(num_units)
channel_ids = np.arange(num_channels)
sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids)

for unit_id in unit_ids:
waveforms_dense = rng.random(size=(num_units, num_samples, num_channels))

# Test are_waveforms_dense
assert sparsity.are_waveforms_dense(waveforms_dense)

# Test sparsify
waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id)
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels)

# Test round-trip (note that this is loosy)
unit_id = unit_ids[unit_id]
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id)
assert np.array_equal(waveforms_dense[..., non_zero_indices], waveforms_dense2[..., non_zero_indices])

# Test sparsify with one waveform (template)
template_dense = waveforms_dense.mean(axis=0)
template_sparse = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id)
assert template_sparse.shape == (num_samples, num_active_channels)

# Test round trip with template
template_dense2 = sparsity.densify_waveforms(template_sparse, unit_id=unit_id)
assert np.array_equal(template_dense[..., non_zero_indices], template_dense2[:, non_zero_indices])


def test_densify_waveforms():
seed = 0
rng = np.random.default_rng(seed=seed)

num_units = 3
num_samples = 5
num_channels = 4

is_mask_valid = False
while not is_mask_valid:
sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool")
is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0)

unit_ids = np.arange(num_units)
channel_ids = np.arange(num_channels)
sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids)

for unit_id in unit_ids:
non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id]
num_active_channels = len(non_zero_indices)
waveforms_sparse = rng.random(size=(num_units, num_samples, num_active_channels))

# Test are waveforms sparse
assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_id=unit_id)

# Test densify
waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id)
assert waveforms_dense.shape == (num_units, num_samples, num_channels)

# Test round-trip
waveforms_sparse2 = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id)
assert np.array_equal(waveforms_sparse, waveforms_sparse2)

# Test densify with one waveform (template)
template_sparse = waveforms_sparse.mean(axis=0)
template_dense = sparsity.densify_waveforms(template_sparse, unit_id=unit_id)
assert template_dense.shape == (num_samples, num_channels)

# Test round trip with template
template_sparse2 = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id)
assert np.array_equal(template_sparse, template_sparse2)


if __name__ == "__main__":
test_ChannelSparsity()
Loading

0 comments on commit add9f98

Please sign in to comment.