Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh browserfig UI events #12819

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mne/viz/_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,8 @@ def _update_data(self):

def _get_epoch_num_from_time(self, time):
epoch_nums = self.mne.inst.selection
return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)]
epoch_ix = np.searchsorted(self.mne.boundary_times[1:-1], time)
return epoch_nums[epoch_ix]

def _redraw(self, update_data=True, annotations=False):
"""Redraws backend if necessary."""
Expand Down
199 changes: 159 additions & 40 deletions mne/viz/_mpl_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from ..fixes import _close_event
from ..utils import Bunch, _click_ch_name, check_version, logger
from ._figure import BrowserBase
from .ui_events import ChannelBrowse, TimeBrowse, TimeChange, publish, subscribe
from .utils import (
DraggableLine,
_events_off,
Expand Down Expand Up @@ -566,6 +567,15 @@ def __init__(self, inst, figsize, ica=None, xlabel="Time (s)", **kwargs):
vline_text=vline_text,
)

# Start listening to incoming TimeChange UI events
subscribe(self, "time_change", self._on_time_change_event)

# Start listening to incoming TimeBrowse UI events
subscribe(self, "time_browse", self._on_time_browse_event)

# Start listening to incoming ChannelBrowse UI events
subscribe(self, "channel_browse", self._on_channel_browse_event)

def _get_size(self):
return self.get_size_inches()

Expand Down Expand Up @@ -641,7 +651,7 @@ def _keypress(self, event):
key = event.key
n_channels = self.mne.n_channels
if self.mne.is_epochs:
last_time = self.mne.n_times / self.mne.info["sfreq"]
last_time = self.mne.boundary_times[-2]
else:
last_time = self.mne.inst.times[-1]
# scroll up/down
Expand Down Expand Up @@ -674,24 +684,21 @@ def _keypress(self, event):
else:
ceiling = len(self.mne.ch_order) - n_channels
ch_start = self.mne.ch_start + direction * n_channels
self.mne.ch_start = np.clip(ch_start, 0, ceiling)
self._update_picks()
self._update_vscroll()
self._redraw()
ch_start = np.clip(ch_start, 0, ceiling)
channels = self.mne.ch_names[
self.mne.ch_order[ch_start : ch_start + n_channels]
]
publish(self, ChannelBrowse(channels=channels))
# scroll left/right
elif key in ("right", "left", "shift+right", "shift+left"):
old_t_start = self.mne.t_start
direction = 1 if key.endswith("right") else -1
if self.mne.is_epochs:
denom = 1 if key.startswith("shift") else self.mne.n_epochs
else:
denom = 1 if key.startswith("shift") else 4
t_max = last_time - self.mne.duration
t_start = self.mne.t_start + direction * self.mne.duration / denom
self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max)
if self.mne.t_start != old_t_start:
self._update_hscroll()
self._redraw(annotations=True)
t_start = np.clip(t_start, 0, last_time - self.mne.duration)
self._publish_time_browse_event(t_start)
# scale traces
elif key in ("=", "+", "-"):
scaler = 1 / 1.1 if key == "-" else 1.1
Expand All @@ -704,41 +711,31 @@ def _keypress(self, event):
and not self.mne.butterfly
):
new_n_ch = n_channels + (1 if key == "pageup" else -1)
self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order))
n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order))
ch_start = self.mne.ch_start
# add new chs from above if we're at the bottom of the scrollbar
ch_end = self.mne.ch_start + self.mne.n_channels
if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0:
self.mne.ch_start -= 1
self._update_vscroll()
# redraw only if changed
if self.mne.n_channels != n_channels:
self._update_picks()
self._update_trace_offsets()
self._redraw(annotations=True)
ch_end = ch_start + n_channels
if ch_end > len(self.mne.ch_order) and ch_start > 0:
ch_start -= 1
channels = self.mne.ch_names[
self.mne.ch_order[ch_start : ch_start + n_channels]
]
publish(self, ChannelBrowse(channels=channels))

# change duration
elif key in ("home", "end"):
old_dur = self.mne.duration
dur_delta = 1 if key == "end" else -1
if self.mne.is_epochs:
# prevent from showing zero epochs, or more epochs than we have
self.mne.n_epochs = np.clip(
self.mne.n_epochs + dur_delta, 1, len(self.mne.inst)
)
# use the length of one epoch as duration change
min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"]
new_dur = self.mne.duration + dur_delta * min_dur
else:
# never show fewer than 3 samples
min_dur = 3 * np.diff(self.mne.inst.times[:2])[0]
# use multiplicative dur_delta
dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5
new_dur = self.mne.duration * dur_delta
self.mne.duration = np.clip(new_dur, min_dur, last_time)
if self.mne.duration != old_dur:
if self.mne.t_start + self.mne.duration > last_time:
self.mne.t_start = last_time - self.mne.duration
self._update_hscroll()
self._redraw(annotations=True)
self._publish_time_browse_event(
self.mne.t_start, self.mne.t_start + new_dur
)
elif key == "?": # help window
self._toggle_help_fig(event)
elif key == "a": # annotation mode
Expand Down Expand Up @@ -797,7 +794,13 @@ def _buttonpress(self, event):
idx = self.mne.traces.index(line)
self._toggle_bad_channel(idx)
return
self._show_vline(event.xdata) # butterfly / not on data trace
time = event.xdata
if self.mne.is_epochs:
width = self.mne.boundary_times[1] - self.mne.boundary_times[0]
time = (time % width) + self.mne.inst.tmin
else:
time += self.mne.inst.first_time
publish(self, TimeChange(time=time))
self._redraw(update_data=False, annotations=False)
return
# click in vertical scrollbar
Expand Down Expand Up @@ -1752,8 +1755,7 @@ def _check_update_hscroll_clicked(self, event):
ix = np.searchsorted(self.mne.boundary_times[1:], time)
time = self.mne.boundary_times[ix]
if self.mne.t_start != time:
self.mne.t_start = time
self._update_hscroll()
self._publish_time_browse_event(time)
return True
return False

Expand All @@ -1765,9 +1767,10 @@ def _check_update_vscroll_clicked(self, event):
len(self.mne.ch_order) - self.mne.n_channels,
)
if self.mne.ch_start != new_ch_start:
self.mne.ch_start = new_ch_start
self._update_picks()
self._update_vscroll()
channels = self.mne.ch_names[
self.mne.ch_order[new_ch_start : new_ch_start + self.mne.n_channels]
]
publish(self, ChannelBrowse(channels=channels))
return True
return False

Expand Down Expand Up @@ -2318,6 +2321,122 @@ def _get_scale_bar_texts(self):

return texts

def _publish_time_browse_event(self, t_start=None, t_end=None):
"""Publish a TimeBrowse event with meaningful time_start and time_end values."""
# Figure out proper t_start and t_end that doesn't exceed the data boundaries.
if t_start is None:
t_start = self.mne.t_start
else:
if self.mne.is_epochs:
last_time = self.mne.n_times / self.mne.info["sfreq"]
else:
last_time = self.mne.inst.times[-1]
t_max = last_time - self.mne.duration
t_start = np.clip(t_start, self.mne.first_time, t_max)

if t_end is None:
t_end = t_start + self.mne.duration
else:
t_end = min(t_end, self.mne.n_times / self.mne.info["sfreq"])

# Don't publish an event if nothing changed.
if (
self.mne.t_start == t_start
and self.mne.t_start + self.mne.duration == t_end
):
return

if self.mne.is_epochs:
# Translate the time-coordinate in the browser window to the actual
# start/end times of the epochs in the raw file.
epoch_num_start = self._get_epoch_num_from_time(
t_start + self.mne.sampling_period
)
epoch_num_end = self._get_epoch_num_from_time(
t_end + self.mne.sampling_period
)
onsets = self.mne.inst.events[:, 0] / self.mne.info["sfreq"]
t_start = onsets[epoch_num_start] + self.mne.inst.tmin
t_end = onsets[epoch_num_end - 1] + self.mne.inst.tmax
else:
# For raw data, we need to take `first_time` into account.
t_start += self.mne.inst.first_time
t_end += self.mne.inst.first_time

publish(self, TimeBrowse(time_start=t_start, time_end=t_end))

def _on_time_browse_event(self, event):
"""Respond to the TimeBrowse UI event, update horizontal scrolling."""
time_start = event.time_start
time_end = event.time_end

if self.mne.is_epochs:
# Translate the start/end times from the original raw to the indices of the
# epochs being shown, and then to the appropriate start/end times in the
# browser.
events = self.mne.inst.events[self.mne.inst.selection]
onsets = events[:, 0] / self.mne.info["sfreq"]
# Subtract/add one sample to make sure we end on the right side. This is
# needed because of small floating point errors.
epoch_ix_start = np.searchsorted(
onsets, time_start - self.mne.inst.tmin - self.mne.sampling_period
)
epoch_ix_end = np.searchsorted(
onsets, time_end - self.mne.inst.tmax + self.mne.sampling_period
)
# Always show at least one epoch.
epoch_ix_start = min(epoch_ix_start, len(self.mne.inst) - 1)
epoch_ix_end = max(epoch_ix_start + 1, epoch_ix_end)
self.mne.n_epochs = epoch_ix_end - epoch_ix_start

# Compute the browser time period to match the selected epochs.
time_start = self.mne.boundary_times[epoch_ix_start]
width = self.mne.boundary_times[1] - self.mne.boundary_times[0]
time_end = self.mne.boundary_times[epoch_ix_end - 1] + width
else:
# For raw data, we need to take `first_time` into account.
time_start -= self.mne.inst.first_time
time_end -= self.mne.inst.first_time

# Never show fewer than 3 samples.
min_dur = 3 * np.diff(self.mne.inst.times[:2])[0]
time_end = np.clip(time_end, time_start + min_dur, self.mne.inst.times[-1])

# Update browser window.
self.mne.t_start = time_start
self.mne.duration = time_end - time_start
self._update_hscroll()
self._redraw(annotations=True)

def _on_channel_browse_event(self, event):
"""Respond to the ChannelBrowse UI event."""
old_n_channels = self.mne.n_channels
picks = np.flatnonzero(
np.isin(self.mne.ch_names[self.mne.ch_order], event.channels)
)
if len(picks) == 0:
return # can't handle the event
if picks.min() == self.mne.ch_start and len(picks) == self.mne.n_channels:
return # no change

self.mne.ch_start = picks.min()
self.mne.n_channels = len(picks)
self._update_vscroll()
self._update_picks()
if self.mne.n_channels != old_n_channels:
self._update_trace_offsets()
self._redraw(annotations=True)

def _on_time_change_event(self, event):
"""Respond to the TimeChange UI event."""
if self.mne.is_epochs:
time = np.clip(event.time, self.mne.inst.tmin, self.mne.inst.tmax)
time -= self.mne.inst.tmin
else:
time = event.time - self.mne.inst.first_time
time = np.clip(time, self.mne.inst.times[0], self.mne.inst.times[-1])
self._show_vline(time)


class MNELineFigure(MNEFigure):
"""Interactive figure for non-scrolling line plots."""
Expand Down
Loading
Loading