diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index ffb1c57dafd..a2d83bd607e 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -412,7 +412,8 @@ def _update_data(self): def _get_epoch_num_from_time(self, time): epoch_nums = self.mne.inst.selection - return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)] + epoch_ix = np.searchsorted(self.mne.boundary_times[1:-1], time) + return epoch_nums[epoch_ix] def _redraw(self, update_data=True, annotations=False): """Redraws backend if necessary.""" diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2e552bd4012..46d5f20d687 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -58,6 +58,7 @@ from ..fixes import _close_event from ..utils import Bunch, _click_ch_name, check_version, logger from ._figure import BrowserBase +from .ui_events import ChannelBrowse, TimeBrowse, TimeChange, publish, subscribe from .utils import ( DraggableLine, _events_off, @@ -566,6 +567,15 @@ def __init__(self, inst, figsize, ica=None, xlabel="Time (s)", **kwargs): vline_text=vline_text, ) + # Start listening to incoming TimeChange UI events + subscribe(self, "time_change", self._on_time_change_event) + + # Start listening to incoming TimeBrowse UI events + subscribe(self, "time_browse", self._on_time_browse_event) + + # Start listening to incoming ChannelBrowse UI events + subscribe(self, "channel_browse", self._on_channel_browse_event) + def _get_size(self): return self.get_size_inches() @@ -641,7 +651,7 @@ def _keypress(self, event): key = event.key n_channels = self.mne.n_channels if self.mne.is_epochs: - last_time = self.mne.n_times / self.mne.info["sfreq"] + last_time = self.mne.boundary_times[-2] else: last_time = self.mne.inst.times[-1] # scroll up/down @@ -674,24 +684,21 @@ def _keypress(self, event): else: ceiling = len(self.mne.ch_order) - n_channels ch_start = self.mne.ch_start + direction * n_channels - self.mne.ch_start = np.clip(ch_start, 0, ceiling) - self._update_picks() - self._update_vscroll() - self._redraw() + ch_start = np.clip(ch_start, 0, ceiling) + channels = self.mne.ch_names[ + self.mne.ch_order[ch_start : ch_start + n_channels] + ] + publish(self, ChannelBrowse(channels=channels)) # scroll left/right elif key in ("right", "left", "shift+right", "shift+left"): - old_t_start = self.mne.t_start direction = 1 if key.endswith("right") else -1 if self.mne.is_epochs: denom = 1 if key.startswith("shift") else self.mne.n_epochs else: denom = 1 if key.startswith("shift") else 4 - t_max = last_time - self.mne.duration t_start = self.mne.t_start + direction * self.mne.duration / denom - self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max) - if self.mne.t_start != old_t_start: - self._update_hscroll() - self._redraw(annotations=True) + t_start = np.clip(t_start, 0, last_time - self.mne.duration) + self._publish_time_browse_event(t_start) # scale traces elif key in ("=", "+", "-"): scaler = 1 / 1.1 if key == "-" else 1.1 @@ -704,41 +711,31 @@ def _keypress(self, event): and not self.mne.butterfly ): new_n_ch = n_channels + (1 if key == "pageup" else -1) - self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order)) + n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order)) + ch_start = self.mne.ch_start # add new chs from above if we're at the bottom of the scrollbar - ch_end = self.mne.ch_start + self.mne.n_channels - if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0: - self.mne.ch_start -= 1 - self._update_vscroll() - # redraw only if changed - if self.mne.n_channels != n_channels: - self._update_picks() - self._update_trace_offsets() - self._redraw(annotations=True) + ch_end = ch_start + n_channels + if ch_end > len(self.mne.ch_order) and ch_start > 0: + ch_start -= 1 + channels = self.mne.ch_names[ + self.mne.ch_order[ch_start : ch_start + n_channels] + ] + publish(self, ChannelBrowse(channels=channels)) + # change duration elif key in ("home", "end"): - old_dur = self.mne.duration dur_delta = 1 if key == "end" else -1 if self.mne.is_epochs: - # prevent from showing zero epochs, or more epochs than we have - self.mne.n_epochs = np.clip( - self.mne.n_epochs + dur_delta, 1, len(self.mne.inst) - ) # use the length of one epoch as duration change min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"] new_dur = self.mne.duration + dur_delta * min_dur else: - # never show fewer than 3 samples - min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] # use multiplicative dur_delta dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5 new_dur = self.mne.duration * dur_delta - self.mne.duration = np.clip(new_dur, min_dur, last_time) - if self.mne.duration != old_dur: - if self.mne.t_start + self.mne.duration > last_time: - self.mne.t_start = last_time - self.mne.duration - self._update_hscroll() - self._redraw(annotations=True) + self._publish_time_browse_event( + self.mne.t_start, self.mne.t_start + new_dur + ) elif key == "?": # help window self._toggle_help_fig(event) elif key == "a": # annotation mode @@ -797,7 +794,13 @@ def _buttonpress(self, event): idx = self.mne.traces.index(line) self._toggle_bad_channel(idx) return - self._show_vline(event.xdata) # butterfly / not on data trace + time = event.xdata + if self.mne.is_epochs: + width = self.mne.boundary_times[1] - self.mne.boundary_times[0] + time = (time % width) + self.mne.inst.tmin + else: + time += self.mne.inst.first_time + publish(self, TimeChange(time=time)) self._redraw(update_data=False, annotations=False) return # click in vertical scrollbar @@ -1752,8 +1755,7 @@ def _check_update_hscroll_clicked(self, event): ix = np.searchsorted(self.mne.boundary_times[1:], time) time = self.mne.boundary_times[ix] if self.mne.t_start != time: - self.mne.t_start = time - self._update_hscroll() + self._publish_time_browse_event(time) return True return False @@ -1765,9 +1767,10 @@ def _check_update_vscroll_clicked(self, event): len(self.mne.ch_order) - self.mne.n_channels, ) if self.mne.ch_start != new_ch_start: - self.mne.ch_start = new_ch_start - self._update_picks() - self._update_vscroll() + channels = self.mne.ch_names[ + self.mne.ch_order[new_ch_start : new_ch_start + self.mne.n_channels] + ] + publish(self, ChannelBrowse(channels=channels)) return True return False @@ -2318,6 +2321,122 @@ def _get_scale_bar_texts(self): return texts + def _publish_time_browse_event(self, t_start=None, t_end=None): + """Publish a TimeBrowse event with meaningful time_start and time_end values.""" + # Figure out proper t_start and t_end that doesn't exceed the data boundaries. + if t_start is None: + t_start = self.mne.t_start + else: + if self.mne.is_epochs: + last_time = self.mne.n_times / self.mne.info["sfreq"] + else: + last_time = self.mne.inst.times[-1] + t_max = last_time - self.mne.duration + t_start = np.clip(t_start, self.mne.first_time, t_max) + + if t_end is None: + t_end = t_start + self.mne.duration + else: + t_end = min(t_end, self.mne.n_times / self.mne.info["sfreq"]) + + # Don't publish an event if nothing changed. + if ( + self.mne.t_start == t_start + and self.mne.t_start + self.mne.duration == t_end + ): + return + + if self.mne.is_epochs: + # Translate the time-coordinate in the browser window to the actual + # start/end times of the epochs in the raw file. + epoch_num_start = self._get_epoch_num_from_time( + t_start + self.mne.sampling_period + ) + epoch_num_end = self._get_epoch_num_from_time( + t_end + self.mne.sampling_period + ) + onsets = self.mne.inst.events[:, 0] / self.mne.info["sfreq"] + t_start = onsets[epoch_num_start] + self.mne.inst.tmin + t_end = onsets[epoch_num_end - 1] + self.mne.inst.tmax + else: + # For raw data, we need to take `first_time` into account. + t_start += self.mne.inst.first_time + t_end += self.mne.inst.first_time + + publish(self, TimeBrowse(time_start=t_start, time_end=t_end)) + + def _on_time_browse_event(self, event): + """Respond to the TimeBrowse UI event, update horizontal scrolling.""" + time_start = event.time_start + time_end = event.time_end + + if self.mne.is_epochs: + # Translate the start/end times from the original raw to the indices of the + # epochs being shown, and then to the appropriate start/end times in the + # browser. + events = self.mne.inst.events[self.mne.inst.selection] + onsets = events[:, 0] / self.mne.info["sfreq"] + # Subtract/add one sample to make sure we end on the right side. This is + # needed because of small floating point errors. + epoch_ix_start = np.searchsorted( + onsets, time_start - self.mne.inst.tmin - self.mne.sampling_period + ) + epoch_ix_end = np.searchsorted( + onsets, time_end - self.mne.inst.tmax + self.mne.sampling_period + ) + # Always show at least one epoch. + epoch_ix_start = min(epoch_ix_start, len(self.mne.inst) - 1) + epoch_ix_end = max(epoch_ix_start + 1, epoch_ix_end) + self.mne.n_epochs = epoch_ix_end - epoch_ix_start + + # Compute the browser time period to match the selected epochs. + time_start = self.mne.boundary_times[epoch_ix_start] + width = self.mne.boundary_times[1] - self.mne.boundary_times[0] + time_end = self.mne.boundary_times[epoch_ix_end - 1] + width + else: + # For raw data, we need to take `first_time` into account. + time_start -= self.mne.inst.first_time + time_end -= self.mne.inst.first_time + + # Never show fewer than 3 samples. + min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] + time_end = np.clip(time_end, time_start + min_dur, self.mne.inst.times[-1]) + + # Update browser window. + self.mne.t_start = time_start + self.mne.duration = time_end - time_start + self._update_hscroll() + self._redraw(annotations=True) + + def _on_channel_browse_event(self, event): + """Respond to the ChannelBrowse UI event.""" + old_n_channels = self.mne.n_channels + picks = np.flatnonzero( + np.isin(self.mne.ch_names[self.mne.ch_order], event.channels) + ) + if len(picks) == 0: + return # can't handle the event + if picks.min() == self.mne.ch_start and len(picks) == self.mne.n_channels: + return # no change + + self.mne.ch_start = picks.min() + self.mne.n_channels = len(picks) + self._update_vscroll() + self._update_picks() + if self.mne.n_channels != old_n_channels: + self._update_trace_offsets() + self._redraw(annotations=True) + + def _on_time_change_event(self, event): + """Respond to the TimeChange UI event.""" + if self.mne.is_epochs: + time = np.clip(event.time, self.mne.inst.tmin, self.mne.inst.tmax) + time -= self.mne.inst.tmin + else: + time = event.time - self.mne.inst.first_time + time = np.clip(time, self.mne.inst.times[0], self.mne.inst.times[-1]) + self._show_vline(time) + class MNELineFigure(MNEFigure): """Interactive figure for non-scrolling line plots.""" diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..c3d0636cca8 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -42,176 +42,6 @@ _camel_to_snake = re.compile(r"(?