Commit

Merge branch 'main' into fix-syn-rdm-firing
zm711 authored Dec 14, 2023
2 parents 815f114 + cd317f8 commit c46ae85
Showing 4 changed files with 145 additions and 32 deletions.
42 changes: 32 additions & 10 deletions src/spikeinterface/widgets/crosscorrelograms.py
@@ -17,10 +17,15 @@ class CrossCorrelogramsWidget(BaseWidget):
The object to compute/get crosscorrelograms from
unit_ids : list or None, default: None
List of unit ids
min_similarity_for_correlograms : float, default: 0.2
For sortingview backend. Threshold for computing pair-wise cross-correlograms.
If template similarity between two units is below this threshold, the cross-correlogram is not displayed
window_ms : float, default: 100.0
Window for CCGs in ms
Window for CCGs in ms. If correlograms are already computed (e.g. with WaveformExtractor),
this argument is ignored
bin_ms : float, default: 1.0
Bin size in ms
Bin size in ms. If correlograms are already computed (e.g. with WaveformExtractor),
this argument is ignored
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
unit_colors: dict or None, default: None
@@ -31,18 +36,25 @@ def __init__(
self,
waveform_or_sorting_extractor: Union[WaveformExtractor, BaseSorting],
unit_ids=None,
min_similarity_for_correlograms=0.2,
window_ms=100.0,
bin_ms=1.0,
hide_unit_selector=False,
unit_colors=None,
backend=None,
**backend_kwargs,
):
if min_similarity_for_correlograms is None:
min_similarity_for_correlograms = 0
similarity = None
if isinstance(waveform_or_sorting_extractor, WaveformExtractor):
sorting = waveform_or_sorting_extractor.sorting
self.check_extensions(waveform_or_sorting_extractor, "correlograms")
ccc = waveform_or_sorting_extractor.load_extension("correlograms")
ccgs, bins = ccc.get_data()
if min_similarity_for_correlograms > 0:
self.check_extensions(waveform_or_sorting_extractor, "similarity")
similarity = waveform_or_sorting_extractor.load_extension("similarity").get_data()
else:
sorting = waveform_or_sorting_extractor
ccgs, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms)
@@ -53,10 +65,14 @@ def __init__(
else:
unit_indices = sorting.ids_to_indices(unit_ids)
correlograms = ccgs[unit_indices][:, unit_indices]
if similarity is not None:
similarity = similarity[unit_indices][:, unit_indices]

plot_data = dict(
correlograms=correlograms,
bins=bins,
similarity=similarity,
min_similarity_for_correlograms=min_similarity_for_correlograms,
unit_ids=unit_ids,
hide_unit_selector=hide_unit_selector,
unit_colors=unit_colors,
@@ -100,23 +116,29 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
from .utils_sortingview import make_serializable, handle_display_and_url

dp = to_attr(data_plot)

unit_ids = make_serializable(dp.unit_ids)

if dp.similarity is not None:
similarity = dp.similarity
else:
similarity = np.ones((len(unit_ids), len(unit_ids)))

cc_items = []
for i in range(len(unit_ids)):
for j in range(i, len(unit_ids)):
cc_items.append(
vv.CrossCorrelogramItem(
unit_id1=unit_ids[i],
unit_id2=unit_ids[j],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
if similarity[i, j] >= dp.min_similarity_for_correlograms:
cc_items.append(
vv.CrossCorrelogramItem(
unit_id1=unit_ids[i],
unit_id2=unit_ids[j],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
)
)
)

self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector)

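
As context for the change above (not part of the commit), here is a minimal usage sketch of the new min_similarity_for_correlograms argument. The waveform folder name is a placeholder, and the extension computations simply mirror the check_extensions calls in the widget:

    import spikeinterface as si
    import spikeinterface.postprocessing as spost
    import spikeinterface.widgets as sw

    # load a previously saved WaveformExtractor (placeholder folder name)
    we = si.load_waveforms("waveforms_folder")

    # the widget expects the "correlograms" extension, plus "similarity"
    # whenever min_similarity_for_correlograms > 0
    spost.compute_correlograms(we)
    spost.compute_template_similarity(we)

    # only unit pairs with template similarity >= 0.4 get a cross-correlogram panel
    sw.plot_crosscorrelograms(we, min_similarity_for_correlograms=0.4, backend="sortingview")
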
17 changes: 16 additions & 1 deletion src/spikeinterface/widgets/sorting_summary.py
@@ -28,13 +28,19 @@ class SortingSummaryWidget(BaseWidget):
max_amplitudes_per_unit : int or None, default: None
Maximum number of spikes per unit for plotting amplitudes.
If None, all spikes are plotted
min_similarity_for_correlograms : float, default: 0.2
Threshold for computing pair-wise cross-correlograms. If template similarity between two units
is below this threshold, the cross-correlogram is not computed
(sortingview backend)
curation : bool, default: False
If True, manual curation is enabled
(sortingview backend)
unit_table_properties : list or None, default: None
List of properties to be added to the unit table
(sortingview backend)
label_choices : list or None, default: None
List of labels to be added to the curation table
(sortingview backend)
unit_table_properties : list or None, default: None
List of properties to be added to the unit table
(sortingview backend)
@@ -46,6 +52,7 @@ def __init__(
unit_ids=None,
sparsity=None,
max_amplitudes_per_unit=None,
min_similarity_for_correlograms=0.2,
curation=False,
unit_table_properties=None,
label_choices=None,
@@ -63,6 +70,7 @@ def __init__(
waveform_extractor=waveform_extractor,
unit_ids=unit_ids,
sparsity=sparsity,
min_similarity_for_correlograms=min_similarity_for_correlograms,
unit_table_properties=unit_table_properties,
curation=curation,
label_choices=label_choices,
@@ -79,6 +87,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
we = dp.waveform_extractor
unit_ids = dp.unit_ids
sparsity = dp.sparsity
min_similarity_for_correlograms = dp.min_similarity_for_correlograms

unit_ids = make_serializable(dp.unit_ids)

@@ -101,7 +110,13 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
backend="sortingview",
).view
v_cross_correlograms = CrossCorrelogramsWidget(
we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview"
we,
unit_ids=unit_ids,
min_similarity_for_correlograms=min_similarity_for_correlograms,
hide_unit_selector=True,
generate_url=False,
display=False,
backend="sortingview",
).view

v_unit_locations = UnitLocationsWidget(
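
A hedged sketch of the summary-widget side (not part of the commit): the threshold is forwarded unchanged to CrossCorrelogramsWidget, so a call might look like the following, assuming a saved WaveformExtractor with the extensions the sortingview backend requires (placeholder folder name):

    import spikeinterface as si
    import spikeinterface.widgets as sw

    we = si.load_waveforms("waveforms_folder")

    # the threshold is passed through to the cross-correlogram panel of the summary view
    sw.plot_sorting_summary(
        we,
        min_similarity_for_correlograms=0.4,
        curation=True,
        backend="sortingview",
    )
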
84 changes: 70 additions & 14 deletions src/spikeinterface/widgets/tests/test_widgets.py
@@ -108,6 +108,8 @@ def setUpClass(cls):

# make sparse waveforms
cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50)
cls.sparsity_strict = compute_sparsity(cls.we_dense, method="radius", radius_um=20)
cls.sparsity_large = compute_sparsity(cls.we_dense, method="radius", radius_um=80)
cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5)
if (cache_folder / "we_sparse").is_dir():
cls.we_sparse = load_waveforms(cache_folder / "we_sparse")
@@ -194,71 +196,105 @@ def test_plot_unit_waveforms(self):
sw.plot_unit_waveforms(
self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
)
# extra sparsity
sw.plot_unit_waveforms(
self.we_sparse,
sparsity=self.sparsity_strict,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
# test "larger" sparsity
with self.assertRaises(AssertionError):
sw.plot_unit_waveforms(
self.we_sparse,
sparsity=self.sparsity_large,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)

def test_plot_unit_templates(self):
possible_backends = list(sw.UnitTemplatesWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
print(f"Testing backend {backend}")
sw.plot_unit_templates(
self.we_dense, backend=backend, templates_percentile_shading=None, **self.backend_kwargs[backend]
)
print("Dense")
sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend])
unit_ids = self.sorting.unit_ids[:6]
print("Dense + radius")
sw.plot_unit_templates(
self.we_dense,
sparsity=self.sparsity_radius,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
print("Dense + best")
sw.plot_unit_templates(
self.we_dense,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse")
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=None,
backend=backend,
**self.backend_kwargs[backend],
)
print("Sparse2")
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=None,
# templates_percentile_shading=None,
scale=10,
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse3")
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
backend=backend,
templates_percentile_shading=None,
shade_templates=False,
**self.backend_kwargs[backend],
)
print("Sparse4")
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=5,
templates_percentile_shading=0.1,
backend=backend,
**self.backend_kwargs[backend],
)
print("Extra sparsity")
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
sparsity=self.sparsity_strict,
unit_ids=unit_ids,
templates_percentile_shading=[10, 90],
templates_percentile_shading=[1, 10, 90, 99],
backend=backend,
**self.backend_kwargs[backend],
)
# test "larger" sparsity
with self.assertRaises(AssertionError):
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_large,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
if backend != "sortingview":
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=[1, 5, 25, 75, 95, 99],
backend=backend,
@@ -269,7 +305,6 @@ def test_plot_unit_templates(self):
with self.assertRaises(AssertionError):
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=[1, 5, 25, 75, 95, 99],
backend=backend,
@@ -331,6 +366,13 @@ def test_plot_crosscorrelogram(self):
possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_crosscorrelograms(
self.sorting,
window_ms=500.0,
bin_ms=20.0,
backend=backend,
**self.backend_kwargs[backend],
)
unit_ids = self.sorting.unit_ids[:4]
sw.plot_crosscorrelograms(
self.sorting,
@@ -340,6 +382,17 @@ def test_plot_isi_distribution(self):
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_crosscorrelograms(
self.we_sparse,
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_crosscorrelograms(
self.we_sparse,
min_similarity_for_correlograms=0.6,
backend=backend,
**self.backend_kwargs[backend],
)

def test_plot_isi_distribution(self):
possible_backends = list(sw.ISIDistributionWidget.get_possible_backends())
@@ -456,6 +509,9 @@ def test_plot_sorting_summary(self):
if backend not in self.skip_backends:
sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend])
sw.plot_sorting_summary(
self.we_sparse, sparsity=self.sparsity_strict, backend=backend, **self.backend_kwargs[backend]
)

def test_plot_agreement_matrix(self):
possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends())
@@ -529,7 +585,7 @@ def test_plot_multicomparison(self):
# mytest.test_plot_traces()
# mytest.test_plot_unit_waveforms()
# mytest.test_plot_unit_templates()
mytest.test_plot_unit_templates()
mytest.test_plot_unit_waveforms()
# mytest.test_plot_unit_depths()
# mytest.test_plot_unit_templates()
# mytest.test_plot_unit_summary()
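
The new tests lean on the fact that a radius-based sparsity computed with a smaller radius selects a subset of the channels selected with a larger radius, which is why sparsity_strict is accepted while sparsity_large triggers the AssertionError. A small sketch of that containment check, using a toy dataset in place of the test fixture (not part of the commit):

    import numpy as np
    import spikeinterface as si
    import spikeinterface.extractors as se

    # small synthetic dataset standing in for the fixture's we_dense
    recording, sorting = se.toy_example(duration=10, num_segments=1)
    we_dense = si.extract_waveforms(recording, sorting, mode="memory", sparse=False)

    sparsity_strict = si.compute_sparsity(we_dense, method="radius", radius_um=20)
    sparsity_radius = si.compute_sparsity(we_dense, method="radius", radius_um=50)
    sparsity_large = si.compute_sparsity(we_dense, method="radius", radius_um=80)

    def is_subset(sub, sup):
        # every channel kept by `sub` must also be kept by `sup`
        return bool(np.all(~sub.mask | sup.mask))

    print(is_subset(sparsity_strict, sparsity_radius))  # True: the 20 um channels are contained in the 50 um set
    print(is_subset(sparsity_large, sparsity_radius))   # False: 80 um adds channels, which the widgets reject
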
34 changes: 27 additions & 7 deletions src/spikeinterface/widgets/unit_waveforms.py
@@ -106,18 +106,26 @@ def __init__(
sorting: BaseSorting = we.sorting

if unit_ids is None:
unit_ids = sorting.get_unit_ids()
unit_ids = unit_ids
unit_ids = sorting.unit_ids
if channel_ids is None:
channel_ids = we.channel_ids

if unit_colors is None:
unit_colors = get_unit_colors(sorting)

channel_locations = we.get_channel_locations()[we.channel_ids_to_indices(channel_ids)]

extra_sparsity = False
if waveform_extractor.is_sparse():
sparsity = waveform_extractor.sparsity
if sparsity is None:
sparsity = waveform_extractor.sparsity
else:
# assert provided sparsity is a subset of waveform sparsity
combined_mask = np.logical_or(we.sparsity.mask, sparsity.mask)
assert np.all(np.sum(combined_mask, 1) - np.sum(we.sparsity.mask, 1) == 0), (
"The provided 'sparsity' needs to include only the sparse channels "
"used to extract waveforms (for example, by using a smaller 'radius_um')."
)
extra_sparsity = True
else:
if sparsity is None:
# in this case, we construct a dense sparsity
@@ -139,10 +147,22 @@ def __init__(
wfs_by_ids = {}
if plot_waveforms:
for unit_id in unit_ids:
if waveform_extractor.is_sparse():
wfs = we.get_waveforms(unit_id)
if not extra_sparsity:
if waveform_extractor.is_sparse():
wfs = we.get_waveforms(unit_id)
else:
wfs = we.get_waveforms(unit_id, sparsity=sparsity)
else:
wfs = we.get_waveforms(unit_id, sparsity=sparsity)
# in this case we have to slice the waveform sparsity based on the extra sparsity
unit_index = list(sorting.unit_ids).index(unit_id)
# first get the sparse waveforms
wfs = we.get_waveforms(unit_id)
# find additional slice to apply to sparse waveforms
(wfs_sparse_indices,) = np.nonzero(waveform_extractor.sparsity.mask[unit_index])
(extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index])
(extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))
# apply extra sparsity
wfs = wfs[:, :, extra_slice]
wfs_by_ids[unit_id] = wfs

plot_data = dict(
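
The extra-sparsity branch above has to re-index waveforms that are already stored on a reduced channel set. A self-contained numpy sketch of that bookkeeping for a single unit, with made-up masks (not part of the commit):

    import numpy as np

    # toy masks over 8 channels (invented values)
    waveform_mask = np.array([0, 1, 1, 1, 1, 0, 0, 0], dtype=bool)  # sparsity used when extracting waveforms
    extra_mask = np.array([0, 0, 1, 1, 0, 0, 0, 0], dtype=bool)     # stricter sparsity passed to the widget

    # the widget's assertion: the extra sparsity must not add channels
    combined_mask = np.logical_or(waveform_mask, extra_mask)
    assert np.sum(combined_mask) == np.sum(waveform_mask)

    # waveforms only exist on the channels in waveform_mask, so translate
    # extra_mask into positions within that reduced channel axis
    (wfs_sparse_indices,) = np.nonzero(waveform_mask)  # -> [1, 2, 3, 4]
    (extra_sparse_indices,) = np.nonzero(extra_mask)   # -> [2, 3]
    (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))  # -> [1, 2]

    # shape: (num_spikes, num_samples, num_sparse_channels)
    wfs = np.zeros((100, 210, wfs_sparse_indices.size))
    wfs = wfs[:, :, extra_slice]  # keep only the stricter channel set
    print(wfs.shape)              # (100, 210, 2)

The slice is recomputed per unit in the widget because each unit can have a different sparse channel set.
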
