From 3448e1ec4b19d5f5091ba6a2792362cf35a9f941 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 08:57:56 +0200 Subject: [PATCH 1/6] Fix plot_traces with ipywidgets when channel_ids is not None --- src/spikeinterface/widgets/traces.py | 10 ++++++---- src/spikeinterface/widgets/utils_ipywidgets.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..2783b6a369 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -138,9 +138,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +150,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +160,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -336,6 +337,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = data_plot["channel_ids"] left_sidebar = W.VBox( children=[ diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..5bbe31302c 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -259,6 +258,19 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): From e51bb75f226c7c2be97c4a6ceeae460a7c610efe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:25:35 +0200 Subject: [PATCH 2/6] Fix order_channel_by_depth in ipywidgets Fix order_channel_by_depth when channel_ids is given. --- src/spikeinterface/widgets/traces.py | 58 +++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 2783b6a369..802f90c62a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,6 +88,26 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") + if "location" in rec0.get_property_keys(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None + + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) + + if channel_ids is None: + channel_ids = rec0.channel_ids + + layer_keys = list(recordings.keys()) if segment_index is None: @@ -95,19 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - if channel_ids is None: - channel_ids = rec0.channel_ids - - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None - - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +131,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -202,7 +209,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -337,7 +343,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) - self.channel_selector.value = data_plot["channel_ids"] + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -400,17 +406,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -525,7 +531,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -552,11 +558,11 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] + # if order is not None: + # traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] + # if order is not None: + # channel_ids = np.array(channel_ids)[order] return times, list_traces, frame_range, channel_ids From 5c5f32fb0df19cb5faf7e24c11758639c1740f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:53:33 +0200 Subject: [PATCH 3/6] yep --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 802f90c62a..d010c96a27 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,7 +88,7 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - if "location" in rec0.get_property_keys(): + if rec0.has_channel_locations(): channel_locations = rec0.get_channel_locations() else: channel_locations = None From 63494f2a44424085d7ad22935313f9cbd2c8b88c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 09:11:43 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/traces.py | 3 +-- src/spikeinterface/widgets/utils_ipywidgets.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d010c96a27..7a4306b284 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -95,6 +95,7 @@ def __init__( if order_channel_by_depth and channel_locations is not None: from ..preprocessing import depth_order + rec0 = depth_order(rec0) recordings = {k: depth_order(rec) for k, rec in recordings.items()} @@ -107,7 +108,6 @@ def __init__( if channel_ids is None: channel_ids = rec0.channel_ids - layer_keys = list(recordings.keys()) if segment_index is None: @@ -115,7 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - fs = rec0.get_sampling_frequency() if time_range is None: time_range = (0, 1.0) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 5bbe31302c..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,7 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -258,7 +258,7 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - + def value_changed(self, change=None): self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = change["new"] @@ -272,7 +272,6 @@ def value_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") - class ScaleWidget(W.VBox): value = traitlets.Float() From c0d4c60095f9704f9b27adfb5fa0f4867adfaf10 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 11:38:15 +0200 Subject: [PATCH 5/6] oups --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d010c96a27..ce34af0bfa 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,7 +88,7 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - if rec0.has_channel_locations(): + if rec0.has_channel_location(): channel_locations = rec0.get_channel_locations() else: channel_locations = None From 2907934928719cf8d0403a2c55628645483187f7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 11:48:37 +0200 Subject: [PATCH 6/6] clean --- src/spikeinterface/widgets/traces.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 5a8212302c..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -557,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s return_scaled=return_scaled, ) - # if order is not None: - # traces = traces[:, order] list_traces.append(traces) - # if order is not None: - # channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids