Skip to content

Commit

Permalink
Merge pull request #3354 from alejoe91/fix-widget-tests
Browse files Browse the repository at this point in the history
Fix widgets tests and add test on unit_table_properties
  • Loading branch information
samuelgarcia authored Sep 3, 2024
2 parents cb3da4e + 3b9922e commit 0f6f21b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/all-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ jobs:
shell: bash
if: env.RUN_WIDGETS_TESTS == 'true'
run: |
pip install -e .[full]
pip install -e .[full,widgets]
./.github/run_tests.sh widgets --no-virtual-env
- name: Test exporters
Expand Down
15 changes: 15 additions & 0 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,21 @@ def test_plot_sorting_summary(self):
backend=backend,
**self.backend_kwargs[backend],
)
# add unit_properties
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
unit_table_properties=["firing_rate", "snr"],
backend=backend,
**self.backend_kwargs[backend],
)
# adding a missing property should raise a warning
with self.assertWarns(UserWarning):
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
unit_table_properties=["missing_property"],
backend=backend,
**self.backend_kwargs[backend],
)

def test_plot_agreement_matrix(self):
possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends())
Expand Down
48 changes: 31 additions & 17 deletions src/spikeinterface/widgets/utils_sortingview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from ..core import SortingAnalyzer, BaseSorting
from ..core.core_tools import check_json
from warnings import warn

Expand Down Expand Up @@ -46,26 +47,42 @@ def handle_display_and_url(widget, view, **backend_kwargs):
return url


def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=None):
def generate_unit_table_view(
sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting,
unit_properties: list[str] | None = None,
similarity_scores: npndarray | None = None,
):
import sortingview.views as vv

sorting = analyzer.sorting
if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer):
analyzer = sorting_or_sorting_analyzer
sorting = analyzer.sorting
else:
sorting = sorting_or_sorting_analyzer
analyzer = None

# Find available unit properties from all sources
sorting_props = list(sorting.get_property_keys())
if analyzer.get_extension("quality_metrics") is not None:
qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns)
qm_data = analyzer.get_extension("quality_metrics").get_data()
if analyzer is not None:
if analyzer.get_extension("quality_metrics") is not None:
qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns)
qm_data = analyzer.get_extension("quality_metrics").get_data()
else:
qm_props = []
if analyzer.get_extension("template_metrics") is not None:
tm_props = list(analyzer.get_extension("template_metrics").get_data().columns)
tm_data = analyzer.get_extension("template_metrics").get_data()
else:
tm_props = []
# Check for any overlaps and warn user if any
all_props = sorting_props + qm_props + tm_props
else:
all_props = sorting_props
qm_props = []
if analyzer.get_extension("template_metrics") is not None:
tm_props = list(analyzer.get_extension("template_metrics").get_data().columns)
tm_data = analyzer.get_extension("template_metrics").get_data()
else:
tm_props = []
qm_data = None
tm_data = None

# Check for any overlaps and warn user if any
all_props = sorting_props + qm_props + tm_props
overlap_props = [prop for prop in all_props if all_props.count(prop) > 1]
if len(overlap_props) > 0:
warn(
Expand Down Expand Up @@ -93,7 +110,8 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N
elif prop_name in tm_props:
property_values = tm_data[prop_name].values
else:
raise ValueError(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics")
warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics")
continue

# make dtype available
val0 = np.array(property_values[0])
Expand All @@ -106,7 +124,7 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N
elif val0.dtype.kind == "b":
dtype = "bool"
else:
print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping")
warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping")
continue
ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype))
valid_unit_properties.append(prop_name)
Expand All @@ -122,10 +140,6 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N
property_values = qm_data[prop_name].values
elif prop_name in tm_props:
property_values = tm_data[prop_name].values
else:
raise ValueError(
f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics"
)

# Check for NaN values
val0 = np.array(property_values[0])
Expand Down

0 comments on commit 0f6f21b

Please sign in to comment.