diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 8317d7bec4..b695c7d627 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -202,7 +202,7 @@ jobs: shell: bash if: env.RUN_WIDGETS_TESTS == 'true' run: | - pip install -e .[full] + pip install -e .[full,widgets] ./.github/run_tests.sh widgets --no-virtual-env - name: Test exporters diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index e06c79ad2f..debcd52085 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -539,6 +539,21 @@ def test_plot_sorting_summary(self): backend=backend, **self.backend_kwargs[backend], ) + # add unit_properties + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + unit_table_properties=["firing_rate", "snr"], + backend=backend, + **self.backend_kwargs[backend], + ) + # adding a missing property should raise a warning + with self.assertWarns(UserWarning): + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + unit_table_properties=["missing_property"], + backend=backend, + **self.backend_kwargs[backend], + ) def test_plot_agreement_matrix(self): possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 269193b341..7a9dc47826 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -2,6 +2,7 @@ import numpy as np +from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json from warnings import warn @@ -46,26 +47,42 @@ def handle_display_and_url(widget, view, **backend_kwargs): return url -def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=None): +def generate_unit_table_view( + sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, + unit_properties: list[str] | None = None, + similarity_scores: npndarray | None = None, +): import sortingview.views as vv - sorting = analyzer.sorting + if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): + analyzer = sorting_or_sorting_analyzer + sorting = analyzer.sorting + else: + sorting = sorting_or_sorting_analyzer + analyzer = None # Find available unit properties from all sources sorting_props = list(sorting.get_property_keys()) - if analyzer.get_extension("quality_metrics") is not None: - qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) - qm_data = analyzer.get_extension("quality_metrics").get_data() + if analyzer is not None: + if analyzer.get_extension("quality_metrics") is not None: + qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) + qm_data = analyzer.get_extension("quality_metrics").get_data() + else: + qm_props = [] + if analyzer.get_extension("template_metrics") is not None: + tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) + tm_data = analyzer.get_extension("template_metrics").get_data() + else: + tm_props = [] + # Check for any overlaps and warn user if any + all_props = sorting_props + qm_props + tm_props else: + all_props = sorting_props qm_props = [] - if analyzer.get_extension("template_metrics") is not None: - tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) - tm_data = analyzer.get_extension("template_metrics").get_data() - else: tm_props = [] + qm_data = None + tm_data = None - # Check for any overlaps and warn user if any - all_props = sorting_props + qm_props + tm_props overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] if len(overlap_props) > 0: warn( @@ -93,7 +110,8 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N elif prop_name in tm_props: property_values = tm_data[prop_name].values else: - raise ValueError(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") + warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") + continue # make dtype available val0 = np.array(property_values[0]) @@ -106,7 +124,7 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N elif val0.dtype.kind == "b": dtype = "bool" else: - print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") + warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") continue ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) valid_unit_properties.append(prop_name) @@ -122,10 +140,6 @@ def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=N property_values = qm_data[prop_name].values elif prop_name in tm_props: property_values = tm_data[prop_name].values - else: - raise ValueError( - f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics" - ) # Check for NaN values val0 = np.array(property_values[0])