Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose t_start in BaseRecording #3117

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,26 @@ def has_time_vector(self, segment_index=None):
return d["time_vector"] is not None

def set_times(self, times, segment_index=None, with_warning=True):
"""Set times for a recording segment.
"""Set times for a recording segment. Any existing times
will be overwritten.

Times can be manually set on the recording segment. If times are
not set, the sample index and sampling frequency are used to
calculate time. Otherwise, `t_start` or `time_vector` can be
provided:

`t_start` - the start time for the segment. The times for
this recording segment will be calculated as
t_start + sample_index * (1 / sampling_frequency)

`time_vector` - A vector of length segment.get_num_samples()
that holds the exact time for each sample in the recording.

Parameters
----------
times : 1d np.array
The time vector
times : int | float | 1d np.array
Copy link
Collaborator

@h-mayorquin h-mayorquin Jul 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally would prefer to not overload this function and create a new one instead. set_t_start.

What are the advantages of overloading this? How are you thinking about it?

But I think what kind of API we should have will become clear once I understood how are you envisioning this to be used.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I couldn't decide on this vs. separate functions and @samuelgarcia suggested this approach. On balance I think I prefer it as t_start and times_vector are mutually exclusive ways of setting custom times, so it makes sense to change in one place. It would be slightly strange to call set_t_start() and this removes the time_vector attribute. But I'm not sure either way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Thanks for sharing, I will add this as something to discuss at the meeting.

Copy link
Collaborator

@h-mayorquin h-mayorquin Jul 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that my take is the following:

We use set_times because the sampling recording is slightly irregular when we want more precision. This is a good name for what the function does, it has a clear purpose and semantics.

Why we would need to set t_start?

The use case that I can think off is that we would like to shift all the recording to the right or the left on time. But if that is the use case I would be better to have a method that shift the recording times and works independently of whether times are handled internally with t_start and sampling frequency or a time vector.
In the first case, you shift t_start (in most of the cases from 0) and in the second you shift the time vector.

If it is not for shifting I can't think on other use of setting t_start

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the use case would be if you had separate sessions in a single day, for example a session, 10 minute break to change some equiptment, and another recording session. If these sessions are held as different segments on a recording (or, as separate recordings) the researcher may want to hold the true recording times (e.g. session 1 started at 1 pm, session two at 1.30 pm).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, change my answer below what I think this type of case is not well served with the current implementation.

If `int` or `float`, this is the `t_start` for the segment,
otherwise, it is the time vector.
segment_index : int or None, default: None
The segment index (required for multi-segment)
with_warning : bool, default: True
Expand All @@ -472,11 +486,18 @@ def set_times(self, times, segment_index=None, with_warning=True):
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]

assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
if isinstance(times, float) or isinstance(times, int):
rs.t_start = times
rs.time_vector = None
elif isinstance(times, np.ndarray):

rs.t_start = None
rs.time_vector = times.astype("float64", copy=False)
assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"

rs.t_start = None
rs.time_vector = times.astype("float64", copy=False)
else:
raise TypeError("`times` must be an integer / float (`t_start`) or " "numpy array (`time_vector`).")

if with_warning:
warnings.warn(
Expand Down
319 changes: 319 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,328 @@
import copy

import pytest
import numpy as np

from spikeinterface.core import generate_recording, generate_sorting
import spikeinterface.full as si


class TestTimeHandling:

# Fixtures #####
@pytest.fixture(scope="session")
def raw_recording(self):
"""
A three-segment raw recording without times added.
"""
durations = [10, 15, 20]
recording = generate_recording(num_channels=4, durations=durations)
return recording

@pytest.fixture(scope="session")
def time_vector_recording(self, raw_recording):
"""
Add time vectors to the recording, returning the
raw recording, recording with time vectors added to
segments, and list a the time vectors added to the recording.
"""
return self._get_time_vector_recording(raw_recording)

@pytest.fixture(scope="session")
def t_start_recording(self, raw_recording):
"""
Add a t_starts to the recording, returning the
raw recording, recording with t_starts added to segments,
and a list of the time vectors generated from adding the
t_start to the recording times.
"""
return self._get_t_start_recording(raw_recording)

def _get_time_vector_recording(self, raw_recording):
"""
Loop through all recording segments, adding a different time
vector to each segment. The time vector is the original times with
a t_start and irregularly spaced offsets to mimic irregularly
spaced timeseries data. Return the original recording,
recoridng with time vectors added and list including the added time vectors.
"""
times_recording = copy.deepcopy(raw_recording)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have clone for this as an extractor method but if you really require this, why make the raw recording fixture per session?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The benefit of the raw_recording fixture is that durations only needs to be defined once, then copied as set_times() is in place. But I agree it is a lot of indirection and it is probably more readable to incorporate into the individual fixtures, possibly with DURATIONS=[...] set at the top of the script?

all_time_vectors = []
for segment_index in range(raw_recording.get_num_segments()):

t_start = segment_index + 1 * 100
offsets = np.arange(times_recording.get_num_samples(segment_index)) * (
1 / times_recording.get_sampling_frequency()
)
time_vector = t_start + times_recording.get_times(segment_index) + offsets

all_time_vectors.append(time_vector)
times_recording.set_times(times=time_vector, segment_index=segment_index)

assert np.array_equal(
times_recording._recording_segments[segment_index].time_vector,
time_vector,
), "time_vector was not properly set during test setup"

return (raw_recording, times_recording, all_time_vectors)

def _get_t_start_recording(self, raw_recording):
"""
For each segment in the recording, add a different `t_start`.
Return a list of time vectors generating from the recording times
+ the t_starts.
"""
t_start_recording = copy.deepcopy(raw_recording)

all_t_starts = []
for segment_index in range(raw_recording.get_num_segments()):

t_start = (segment_index + 1) * 100

all_t_starts.append(t_start + t_start_recording.get_times(segment_index))
t_start_recording.set_times(times=t_start, segment_index=segment_index)

assert np.array_equal(
t_start_recording._recording_segments[segment_index].t_start,
t_start,
), "t_start was not properly set during test setup"

return (raw_recording, t_start_recording, all_t_starts)

def _get_fixture_data(self, request, fixture_name):
"""
A convenience function to get the data from a fixture
based on the name. This is used to allow parameterising
tests across fixtures.
"""
time_recording_fixture = request.getfixturevalue(fixture_name)
raw_recording, times_recording, all_times = time_recording_fixture
return (raw_recording, times_recording, all_times)

# Tests #####
def test_has_time_vector(self, time_vector_recording):
"""
Test the `has_time_vector` function returns `False` before
a time vector is added and `True` afterwards.
"""
raw_recording, times_recording, _ = time_vector_recording

for segment_idx in range(raw_recording.get_num_segments()):

assert raw_recording.has_time_vector(segment_idx) is False
assert times_recording.has_time_vector(segment_idx) is True

def test_get_durations(self, time_vector_recording, t_start_recording):
"""
Test the `get_durations` functions that return the total duration
for a segment. Test that it is correct after adding both `t_start`
or `time_vector` to the recording.
"""
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
_, tstart_recording, all_t_starts = t_start_recording

ts = 1 / raw_recording.get_sampling_frequency()

all_raw_durations = []
all_vector_durations = []
for segment_index in range(raw_recording.get_num_segments()):

# Test before `t_start` and `t_start` (`t_start` is just an offset,
# should not affect duration).
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts

assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)

# Test the duration from the time vector.
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts

assert tvector_recording.get_duration(segment_index) == vector_duration

all_raw_durations.append(raw_duration)
all_vector_durations.append(vector_duration)

# Finally test the total recording duration
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)

@pytest.mark.parametrize("mode", ["binary", "zarr"])
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path):
"""
Test `t_start` or `time_vector` is propagated to a saved recording,
by saving, reloading, and checking times are correct.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

folder_name = "recording"
recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)

if mode == "zarr":
folder_name += ".zarr"
recording_load = si.load_extractor(tmp_path / folder_name)

self._check_times_match(recording_cache, all_times)
self._check_times_match(recording_load, all_times)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("sharedmem", [True, False])
def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
"""
Test t_start and time_vector are propagated to recording saved into memory.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

recording_load = times_recording.save(format="memory", sharedmem=sharedmem)

self._check_times_match(recording_load, all_times)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_time_propagated_to_select_segments(self, request, fixture_name):
"""
Test that when `recording.select_segments()` is used, the times
are propagated to the new recoridng object.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

for segment_index in range(times_recording.get_num_segments()):
segment = times_recording.select_segments(segment_index)
assert np.array_equal(segment.get_times(), all_times[segment_index])

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_times_propagated_to_sorting(self, request, fixture_name):
"""
Check that when attached to a sorting object, the times are propagated
to the object. This means that all spike times should respect the
`t_start` or `time_vector` added.
"""
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
sorting = self._get_sorting_with_recording_attached(
recording_for_durations=raw_recording, recording_to_attach=times_recording
)
for segment_index in range(raw_recording.get_num_segments()):

if fixture_name == "time_vector_recording":
assert sorting.has_time_vector(segment_index=segment_index)

self._check_spike_times_are_correct(sorting, times_recording, segment_index)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_time_sample_converters(self, request, fixture_name):
"""
Test the `recording.sample_time_to_index` and
`recording.time_to_sample_index` convenience functions.
"""
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
with pytest.raises(ValueError) as e:
times_recording.sample_index_to_time(0)
assert "Provide 'segment_index'" in str(e)

for segment_index in range(times_recording.get_num_segments()):

sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)

assert time_ == all_times[segment_index][sample_index]

new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)

assert new_sample_index == sample_index

@pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
@pytest.mark.parametrize("bounds", ["start", "middle", "end"])
def test_slice_recording(self, time_type, bounds):
"""
Test after `frame_slice` and `time_slice` a recording or
sorting (for `frame_slice`), the recording times are
correct with respect to the set `t_start` or `time_vector`.
"""
raw_recording = generate_recording(num_channels=4, durations=[10])

if time_type == "time_vector":
raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
else:
raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)

sorting = self._get_sorting_with_recording_attached(
recording_for_durations=raw_recording, recording_to_attach=times_recording
)

# Take some different times, including min and max bounds of
# the recording, and some arbitaray times in the middle (20% and 80%).
if bounds == "start":
start_frame = 0
end_frame = int(times_recording.get_num_samples(0) * 0.8)
elif bounds == "end":
start_frame = int(times_recording.get_num_samples(0) * 0.2)
end_frame = times_recording.get_num_samples(0) - 1
elif bounds == "middle":
start_frame = int(times_recording.get_num_samples(0) * 0.2)
end_frame = int(times_recording.get_num_samples(0) * 0.8)

# Slice the recording and get the new times are correct
rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)

assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)

# Test `time_slice`
start_time = times_recording.sample_index_to_time(start_frame)
end_time = times_recording.sample_index_to_time(end_frame)

rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)

assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

# Helpers ####
def _check_times_match(self, recording, all_times):
"""
For every segment in a recording, check the `get_times()`
match the expected times in the list of time vectors, `all_times`.
"""
for segment_index in range(recording.get_num_segments()):
assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])

def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
"""
For every unit in the `sorting`, for a particular segment, check that
the unit times match the times of the original recording as
retrieved with `get_times()`.
"""
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
rec_times = times_recording.get_times(segment_index=segment_index)

assert np.array_equal(
spike_times,
rec_times[spike_indexes],
)

def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
"""
Convenience function to create a sorting object with
a recording attached. Typically use the raw recordings
for the durations of which to make the sorter, as
the generate_sorter is not setup to handle the
(strange) edge case of the irregularly spaced
test time vectors.
"""
durations = [
recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
]

sorting = generate_sorting(num_units=10, durations=durations)

sorting.register_recording(recording_to_attach)
assert sorting.has_recording()

return sorting


# TODO: deprecate original implementations ###
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: this was messing up the diff so left for the end.

def test_time_handling(create_cache_folder):
cache_folder = create_cache_folder
durations = [[10], [10, 5]]
Expand Down
Loading