From 884be4fc553cb3e56cee1e3c8742aab8834a20bd Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 May 2024 16:27:29 -0400 Subject: [PATCH 1/2] add errors to `ensure` functions to help end user --- src/spikeinterface/widgets/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7440c240ce..b94167d2b7 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -109,17 +109,17 @@ def do_plot(self): @classmethod def ensure_sorting_analyzer(cls, input): - # internal help to accept both SortingAnalyzer or MockWaveformExtractor for a ploter + # internal help to accept both SortingAnalyzer or MockWaveformExtractor for a plotter if isinstance(input, SortingAnalyzer): return input elif isinstance(input, MockWaveformExtractor): return input.sorting_analyzer else: - return input + raise TypeError("input must be a SortingAnalyzer or MockWaveformExtractor") @classmethod def ensure_sorting(cls, input): - # internal help to accept both Sorting or SortingAnalyzer or MockWaveformExtractor for a ploter + # internal help to accept both Sorting or SortingAnalyzer or MockWaveformExtractor for a plotter if isinstance(input, BaseSorting): return input elif isinstance(input, SortingAnalyzer): @@ -127,7 +127,7 @@ def ensure_sorting(cls, input): elif isinstance(input, MockWaveformExtractor): return input.sorting_analyzer.sorting else: - return input + raise TypeError("input must be a SortingAnalyzer, MockWaveformExtractor, or of type BaseSorting") @staticmethod def check_extensions(sorting_analyzer, extensions): From 0c96661505cf354b520eeb3934ad3605b862dafc Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 May 2024 16:50:07 -0400 Subject: [PATCH 2/2] allow for basesorting --- src/spikeinterface/widgets/crosscorrelograms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 6eb565d56a..e70a5775e6 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -46,7 +46,9 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_analyzer_or_sorting = self.ensure_sorting_analyzer(sorting_analyzer_or_sorting) + + if not isinstance(sorting_analyzer_or_sorting, BaseSorting): + sorting_analyzer_or_sorting = self.ensure_sorting_analyzer(sorting_analyzer_or_sorting) if min_similarity_for_correlograms is None: min_similarity_for_correlograms = 0