Skip to content

Commit

Permalink
Merge pull request #2079 from samuelgarcia/fix_plot_traces
Browse files Browse the repository at this point in the history
Fix plot_traces with ipywidgets
  • Loading branch information
yger authored Oct 6, 2023
2 parents 07623c5 + 2907934 commit a2d27ff
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
64 changes: 33 additions & 31 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,32 @@ def __init__(
else:
raise ValueError("plot_traces recording must be recording or dict or list")

layer_keys = list(recordings.keys())
if rec0.has_channel_location():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None

if segment_index is None:
if rec0.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
segment_index = 0
if order_channel_by_depth and channel_locations is not None:
from ..preprocessing import depth_order

rec0 = depth_order(rec0)
recordings = {k: depth_order(rec) for k, rec in recordings.items()}

if channel_ids is not None:
# ensure that channel_ids are in the good order
channel_ids_ = list(rec0.channel_ids)
order = np.argsort([channel_ids_.index(c) for c in channel_ids])
channel_ids = list(np.array(channel_ids)[order])

if channel_ids is None:
channel_ids = rec0.channel_ids

if "location" in rec0.get_property_keys():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None
layer_keys = list(recordings.keys())

if order_channel_by_depth:
if channel_locations is not None:
order, _ = order_channels_by_depth(rec0, channel_ids)
else:
order = None
if segment_index is None:
if rec0.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
segment_index = 0

fs = rec0.get_sampling_frequency()
if time_range is None:
Expand All @@ -124,7 +130,7 @@ def __init__(
cmap = cmap

times, list_traces, frame_range, channel_ids = _get_trace_list(
recordings, channel_ids, time_range, segment_index, order, return_scaled
recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled
)

# stat for auto scaling done on the first layer
Expand All @@ -138,9 +144,10 @@ def __init__(

# colors is a nested dict by layer and channels
# lets first create black for all channels and layer
# all color are generated for ipywidgets
colors = {}
for k in layer_keys:
colors[k] = {chan_id: "k" for chan_id in channel_ids}
colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids}

if color_groups:
channel_groups = rec0.get_channel_groups(channel_ids=channel_ids)
Expand All @@ -149,7 +156,7 @@ def __init__(
group_colors = get_some_colors(groups, color_engine="auto")

channel_colors = {}
for i, chan_id in enumerate(channel_ids):
for i, chan_id in enumerate(rec0.channel_ids):
group = channel_groups[i]
channel_colors[chan_id] = group_colors[group]

Expand All @@ -159,12 +166,12 @@ def __init__(
elif color is not None:
# old behavior one color for all channel
# if multi layer then black for all
colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids}
colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids}
elif color is None and len(recordings) > 1:
# several layer
layer_colors = get_some_colors(layer_keys)
for k in layer_keys:
colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids}
colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids}
else:
# color is None unique layer : all channels black
pass
Expand Down Expand Up @@ -201,7 +208,6 @@ def __init__(
show_channel_ids=show_channel_ids,
add_legend=add_legend,
order_channel_by_depth=order_channel_by_depth,
order=order,
tile_size=tile_size,
num_timepoints_per_row=int(seconds_per_row * fs),
return_scaled=return_scaled,
Expand Down Expand Up @@ -336,6 +342,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
)
self.scaler = ScaleWidget()
self.channel_selector = ChannelSelector(self.rec0.channel_ids)
self.channel_selector.value = list(data_plot["channel_ids"])

left_sidebar = W.VBox(
children=[
Expand Down Expand Up @@ -398,17 +405,17 @@ def _mode_changed(self, change=None):
def _retrieve_traces(self, change=None):
channel_ids = np.array(self.channel_selector.value)

if self.data_plot["order_channel_by_depth"]:
order, _ = order_channels_by_depth(self.rec0, channel_ids)
else:
order = None
# if self.data_plot["order_channel_by_depth"]:
# order, _ = order_channels_by_depth(self.rec0, channel_ids)
# else:
# order = None

start_frame, end_frame, segment_index = self.time_slider.value
time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency

self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()}
times, list_traces, frame_range, channel_ids = _get_trace_list(
self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled
self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled
)

self._channel_ids = channel_ids
Expand Down Expand Up @@ -523,7 +530,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs):
app.exec()


def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False):
def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False):
# function also used in ipywidgets plotter
k0 = list(recordings.keys())[0]
rec0 = recordings[k0]
Expand All @@ -550,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No
return_scaled=return_scaled,
)

if order is not None:
traces = traces[:, order]
list_traces.append(traces)

if order is not None:
channel_ids = np.array(channel_ids)[order]

return times, list_traces, frame_range, channel_ids
15 changes: 13 additions & 2 deletions src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs):
self.slider.observe(self.on_slider_changed, names=["value"], type="change")
self.selector.observe(self.on_selector_changed, names=["value"], type="change")

# TODO external value change
# self.observe(self.value_changed, names=['value'], type="change")
self.observe(self.value_changed, names=["value"], type="change")

def on_slider_changed(self, change=None):
i0, i1 = self.slider.value
Expand All @@ -260,6 +259,18 @@ def on_selector_changed(self, change=None):

self.value = channel_ids

def value_changed(self, change=None):
self.selector.unobserve(self.on_selector_changed, names=["value"], type="change")
self.selector.value = change["new"]
self.selector.observe(self.on_selector_changed, names=["value"], type="change")

channel_ids = self.selector.value
self.slider.unobserve(self.on_slider_changed, names=["value"], type="change")
i0 = self.channel_ids.index(channel_ids[0])
i1 = self.channel_ids.index(channel_ids[-1]) + 1
self.slider.value = (i0, i1)
self.slider.observe(self.on_slider_changed, names=["value"], type="change")


class ScaleWidget(W.VBox):
value = traitlets.Float()
Expand Down

0 comments on commit a2d27ff

Please sign in to comment.