Correct handling of time in plot_traces #3393

Open · wants to merge 5 commits into main
4 changes: 3 additions & 1 deletion src/spikeinterface/core/frameslicerecording.py
@@ -34,7 +34,9 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None):
         if start_frame is None:
             start_frame = 0
         else:
-            assert 0 <= start_frame < parent_size
+            assert (
+                0 <= start_frame < parent_size
+            ), f"`start_frame` must be less than the number of samples in parent: {parent_size}"
 
         if end_frame is None:
             end_frame = parent_size
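
For context, a minimal sketch of when the new assertion fires (not part of the diff; generate_recording and frame_slice are existing spikeinterface APIs, and the numbers are illustrative):

from spikeinterface.core import generate_recording

# 1 s at 30 kHz -> 30_000 samples in the single segment
recording = generate_recording(num_channels=4, sampling_frequency=30_000.0, durations=[1.0])

recording.frame_slice(start_frame=10_000, end_frame=20_000)  # OK: within [0, num_samples)

try:
    recording.frame_slice(start_frame=50_000, end_frame=60_000)  # start_frame beyond the parent
except AssertionError as err:
    print(err)  # now reports the parent size instead of a bare AssertionError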
46 changes: 28 additions & 18 deletions src/spikeinterface/widgets/traces.py
@@ -125,8 +125,10 @@ def __init__(

         if not rec0.has_time_vector(segment_index=segment_index):
             times = None
-            t_start = 0
-            t_end = rec0.get_duration(segment_index=segment_index)
+            t_start = rec0.sample_index_to_time(0, segment_index=segment_index)
Review thread on the added `t_start` line:

Collaborator: I think this first conditional can be removed entirely now, since get_times() always returns a time array that accounts for whether a time vector, a t_start attribute, or no time information is set on the recording.

If times is None, the times are generated on the fly in _get_trace_list, but that uses essentially the same code that is now implemented centrally in get_times().

I think the same applies here. If the second branch of the conditional can also be used and the first removed, the times=None option could be dropped from _get_trace_list as well.

Basically, this would centralise the time-related computations, moving them out of the widget and into the recording. It may be necessary to check that the frame computations are handled the same way, since some clipping is applied in one of the conditionals I linked, but I think they should work as before because the times are already converted to frames by the new methods here.

Member (Author): This was done to avoid loading the time vector when it is not needed. In _get_trace_list, only a local time vector is generated on the fly.

Collaborator: Nice, that makes sense.
+            t_end = rec0.sample_index_to_time(
+                rec0.get_num_samples(segment_index=segment_index), segment_index=segment_index
+            )
         else:
             times = rec0.get_times(segment_index=segment_index)
             t_start = times[0]
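
To illustrate the equivalence discussed in the review thread above, a minimal sketch (not part of the diff; it assumes spikeinterface's generate_recording helper and the sample_index_to_time method used in this PR, with illustrative numbers):

import numpy as np
from spikeinterface.core import generate_recording

# illustrative recording: 2 s at 30 kHz, with a time vector starting at 100 s
recording = generate_recording(num_channels=4, sampling_frequency=30_000.0, durations=[2.0])
num_samples = recording.get_num_samples(segment_index=0)
recording.set_times(100.0 + np.arange(num_samples) / 30_000.0, segment_index=0)

# get_times() returns the full array and already accounts for any time vector
# or t_start attached to the recording ...
t_start = recording.get_times(segment_index=0)[0]  # 100.0

# ... while sample_index_to_time() gives the same answer for a single sample,
# so the widget does not need to hold the whole vector locally
assert t_start == recording.sample_index_to_time(0, segment_index=0)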
@@ -149,6 +151,11 @@
             )
             time_range[1] = t_end
 
+        if time_range[0] < t_start or time_range[1] < t_start:
+            raise ValueError(f"All `time_range` values must be greater than {t_start}")
+        if time_range[1] <= time_range[0]:
+            raise ValueError("`time_range[1]` must be greater than `time_range[0]`")
+
         assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
         if mode == "auto":
             if len(channel_ids) <= 64:
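
As a usage sketch of the new checks (illustrative values; plot_traces is the public entry point that drives this widget), out-of-range or reversed bounds now raise instead of silently producing an empty or misleading plot:

from spikeinterface.core import generate_recording
from spikeinterface.widgets import plot_traces

recording = generate_recording(num_channels=4, sampling_frequency=30_000.0, durations=[2.0])

plot_traces(recording, time_range=(0.5, 1.0))  # OK: inside the recording's time span

plot_traces(recording, time_range=(1.0, 0.5))  # raises ValueError: time_range[1] must be greater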
@@ -158,7 +165,7 @@
         mode = mode
         cmap = cmap
 
-        times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
+        times_in_range, list_traces, sample_range, channel_ids = _get_trace_list(
             recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled, times=times
         )
 
@@ -250,7 +257,7 @@ def __init__(
             channel_ids=channel_ids,
             channel_locations=channel_locations,
             time_range=time_range,
-            frame_range=frame_range,
+            sample_range=sample_range,
             times_in_range=times_in_range,
             layer_keys=layer_keys,
             list_traces=list_traces,
@@ -536,7 +543,7 @@ def _retrieve_traces(self, change=None):
         time_range = np.array([times[start_frame], times[end_frame]])
 
         self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()}
-        times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
+        times_in_range, list_traces, sample_range, channel_ids = _get_trace_list(
             self._selected_recordings,
             channel_ids,
             time_range,
@@ -549,7 +556,7 @@
         self._list_traces = list_traces
         self._times_in_range = times_in_range
         self._time_range = time_range
-        self._frame_range = (start_frame, end_frame)
+        self._sample_range = (start_frame, end_frame)
         self._segment_index = segment_index
 
         self._update_plot()
@@ -562,7 +569,7 @@ def _update_plot(self, change=None):
         layer_keys = self._get_layers()
 
         data_plot["mode"] = mode
-        data_plot["frame_range"] = self._frame_range
+        data_plot["sample_range"] = self._sample_range
         data_plot["time_range"] = self._time_range
         if self.colorbar.value:
             data_plot["with_colorbar"] = True
@@ -673,26 +680,29 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s
     assert all(
         rec.has_scaleable_traces() for rec in recordings.values()
     ), "Some recording layers do not have scaled traces. Use `return_scaled=False`"
+    sample_range = np.array(
+        [
+            rec0.time_to_sample_index(time_range[0], segment_index=segment_index),
+            rec0.time_to_sample_index(time_range[1], segment_index=segment_index),
+        ]
+    )
     if times is not None:
-        frame_range = np.searchsorted(times, time_range)
-        times = times[frame_range[0] : frame_range[1]]
+        times = times[sample_range[0] : sample_range[1]]
     else:
-        frame_range = (time_range * fs).astype("int64", copy=False)
-        a_max = rec0.get_num_frames(segment_index=segment_index)
-        frame_range = np.clip(frame_range, 0, a_max)
-        time_range = frame_range / fs
-        times = np.arange(frame_range[0], frame_range[1]) / fs
+        num_samples = rec0.get_num_samples(segment_index=segment_index)
+        sample_range = np.clip(sample_range, 0, num_samples)
+        times = np.arange(sample_range[0], sample_range[1]) / fs
 
     list_traces = []
-    for rec_name, rec in recordings.items():
+    for _, rec in recordings.items():
         traces = rec.get_traces(
             segment_index=segment_index,
             channel_ids=channel_ids,
-            start_frame=frame_range[0],
-            end_frame=frame_range[1],
+            start_frame=sample_range[0],
+            end_frame=sample_range[1],
             return_scaled=return_scaled,
         )
 
         list_traces.append(traces)
 
-    return times, list_traces, frame_range, channel_ids
+    return times, list_traces, sample_range, channel_ids
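
The core of the change in _get_trace_list is that sample indices are now derived from the recording's own time base instead of multiplying times by the sampling frequency. A rough sketch of why that matters for a recording whose clock does not start at zero (illustrative numbers; set_times and time_to_sample_index as used in this PR):

import numpy as np
from spikeinterface.core import generate_recording

fs = 30_000.0
recording = generate_recording(num_channels=4, sampling_frequency=fs, durations=[2.0])
num_samples = recording.get_num_samples(segment_index=0)  # 60_000
recording.set_times(100.0 + np.arange(num_samples) / fs, segment_index=0)

time_range = np.array([100.5, 101.0])

# old logic: (time_range * fs).astype("int64") -> [3_015_000, 3_030_000],
# far beyond the 60_000 samples that actually exist
# new logic: ask the recording which sample corresponds to each time
sample_range = np.array(
    [recording.time_to_sample_index(t, segment_index=0) for t in time_range]
)
print(sample_range)  # approximately [15_000, 30_000]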