diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,26 +88,32 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - layer_keys = list(recordings.keys()) + if rec0.has_channel_location(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None - if segment_index is None: - if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) if channel_ids is None: channel_ids = rec0.channel_ids - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None + layer_keys = list(recordings.keys()) - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None + if segment_index is None: + if rec0.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +130,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -138,9 +144,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +156,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +166,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -201,7 +208,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -336,6 +342,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -398,17 +405,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -523,7 +530,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -550,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -260,6 +259,18 @@ def on_selector_changed(self, change=None): self.value = channel_ids + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): value = traitlets.Float()