Skip to content

Commit

Permalink
Fix t_starts not propagated to save memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 1, 2024
1 parent a68b68d commit 984aaa5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,11 +566,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
if kwargs.get("sharedmem", True):
from .numpyextractors import SharedMemoryRecording

cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
else:
from spikeinterface.core import NumpyRecording

cached = NumpyRecording.from_recording(self, **job_kwargs)
cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)

elif format == "zarr":
from .zarrextractors import ZarrRecordingExtractor
Expand Down
9 changes: 5 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
}

@staticmethod
def from_recording(source_recording, **job_kwargs):
def from_recording(source_recording, t_starts=None, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
if shms[0] is not None:
# if the computation was done in parallel then traces_list is shared array
Expand All @@ -99,9 +99,10 @@ def from_recording(source_recording, **job_kwargs):
recording = NumpyRecording(
traces_list,
source_recording.get_sampling_frequency(),
t_starts=None,
t_starts=t_starts,
channel_ids=source_recording.channel_ids,
)
return recording


class NumpyRecordingSegment(BaseRecordingSegment):
Expand Down Expand Up @@ -211,7 +212,7 @@ def __del__(self):
shm.unlink()

@staticmethod
def from_recording(source_recording, **job_kwargs):
def from_recording(source_recording, t_starts=None, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)

# TODO later : propagte t_starts ?
Expand All @@ -222,7 +223,7 @@ def from_recording(source_recording, **job_kwargs):
dtype=source_recording.dtype,
sampling_frequency=source_recording.sampling_frequency,
channel_ids=source_recording.channel_ids,
t_starts=None,
t_starts=t_starts,
main_shm_owner=True,
)

Expand Down

0 comments on commit 984aaa5

Please sign in to comment.