Merge pull request #1781 from h-mayorquin/binary_recording_limit_get_traces_allocation

Make binary recording memmap efficient III (avoiding memory spikes by only reserving memory equal to the requested traces)
samuelgarcia authored Jul 5, 2024
2 parents 0422cfb + c652c05 commit 8a97e10
Showing 3 changed files with 125 additions and 25 deletions.
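In essence, the change stops BinaryRecordingSegment from memory-mapping the whole file up front and instead maps, on each get_traces call, only the byte range covering the requested frames. Below is a minimal standalone sketch of that idea, assuming an illustrative file name, channel count, and frame range; it is not the extractor's actual API.

import mmap
from pathlib import Path

import numpy as np

# Illustrative parameters, not taken from the PR.
num_channels = 4
dtype = np.dtype("float32")
file_path = Path("traces_example.raw")

# Write a small binary file so the sketch is self-contained.
np.arange(1_000 * num_channels, dtype=dtype).tofile(file_path)


def read_window(start_frame: int, end_frame: int, file_offset: int = 0) -> np.ndarray:
    """Map only the bytes covering [start_frame, end_frame)."""
    bytes_per_sample = num_channels * dtype.itemsize
    start_byte = file_offset + start_frame * bytes_per_sample
    end_byte = file_offset + end_frame * bytes_per_sample

    # mmap offsets must be multiples of ALLOCATIONGRANULARITY, so round the
    # start down and remember how many extra bytes that pulls in.
    memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY)
    memmap_offset *= mmap.ALLOCATIONGRANULARITY
    length = (end_byte - start_byte) + start_offset

    with open(file_path, "rb") as f:
        memmap_obj = mmap.mmap(f.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset)

    # The numpy array skips the alignment padding via its `offset` argument.
    return np.ndarray(
        shape=(end_frame - start_frame, num_channels),
        dtype=dtype,
        buffer=memmap_obj,
        offset=start_offset,
    )


window = read_window(start_frame=200, end_frame=500)
print(window.shape)  # (300, 4)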
58 changes: 35 additions & 23 deletions src/spikeinterface/core/binaryrecordingextractor.py
@@ -166,25 +166,17 @@ def get_binary_description(self):


class BinaryRecordingSegment(BaseRecordingSegment):
def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset):
def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset):
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
self.num_channels = num_channels
self.dtype = np.dtype(dtype)
self.file_offset = file_offset
self.time_axis = time_axis
self.datfile = datfile
self.file = open(self.datfile, "r")
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize)
if self.time_axis == 0:
self.shape = (self.num_samples, self.num_channels)
else:
self.shape = (self.num_channels, self.num_samples)

byte_offset = self.file_offset
dtype_size_bytes = self.dtype.itemsize
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels
self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
self.memmap_length = data_size_bytes + self.array_offset
self.file_path = file_path
self.file = open(self.file_path, "rb")
self.bytes_per_sample = self.num_channels * self.dtype.itemsize
self.data_size_in_bytes = Path(file_path).stat().st_size - file_offset
self.num_samples = self.data_size_in_bytes // self.bytes_per_sample

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal block
@@ -200,23 +192,43 @@ def get_traces(
end_frame: int | None = None,
channel_indices: list | None = None,
) -> np.ndarray:
length = self.memmap_length
memmap_offset = self.memmap_offset

# Calculate byte offsets for start and end frames
start_byte = self.file_offset + start_frame * self.bytes_per_sample
end_byte = self.file_offset + end_frame * self.bytes_per_sample

# Calculate the length of the data chunk to load into memory
length = end_byte - start_byte

# The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY
memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY)
memmap_offset *= mmap.ALLOCATIONGRANULARITY

# Adjust the length so it includes the extra data from rounding down
# the memmap offset to a multiple of ALLOCATIONGRANULARITY
length += start_offset

# Create the mmap object
memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset)

array = np.ndarray.__new__(
np.ndarray,
shape=self.shape,
# Create a numpy array using the mmap object as the buffer
# Note that the shape must be recalculated based on the new data chunk
if self.time_axis == 0:
shape = ((end_frame - start_frame), self.num_channels)
else:
shape = (self.num_channels, (end_frame - start_frame))

# Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly
traces = np.ndarray(
shape=shape,
dtype=self.dtype,
buffer=memmap_obj,
order="C",
offset=self.array_offset,
offset=start_offset,
)

if self.time_axis == 1:
array = array.T
traces = traces.T

traces = array[start_frame:end_frame]
if channel_indices is not None:
traces = traces[:, channel_indices]

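To make the offset arithmetic in get_traces concrete, here is a small worked example. mmap.ALLOCATIONGRANULARITY is platform-dependent (typically 4096 on Linux and 65536 on Windows), so the numbers below are illustrative.

# Suppose the requested window starts 100_000 bytes into the file and the
# allocation granularity is 4096 bytes (typical on Linux).
start_byte = 100_000
granularity = 4_096

memmap_offset, start_offset = divmod(start_byte, granularity)
memmap_offset *= granularity

# 100_000 = 24 * 4_096 + 1_696
assert (memmap_offset, start_offset) == (98_304, 1_696)

# The map starts at byte 98_304, the numpy array skips the first 1_696 bytes
# through its `offset` argument, and at most granularity - 1 extra bytes are
# mapped per call.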
25 changes: 25 additions & 0 deletions src/spikeinterface/core/core_tools.py
@@ -659,3 +659,28 @@ def retrieve_importing_provenance(a_class):
}

return info


def measure_memory_allocation(measure_in_process: bool = True) -> float:
"""
A local utility to measure memory allocation at a specific point in time.
Can measure either the process resident memory or system wide memory available
Uses psutil package.
Parameters
----------
measure_in_process : bool, True by default
Mesure memory allocation in the current process only, if false then measures at the system
level.
"""
import psutil

if measure_in_process:
process = psutil.Process()
memory = process.memory_info().rss
else:
mem_info = psutil.virtual_memory()
memory = mem_info.total - mem_info.available

return memory
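A hypothetical way this helper could be exercised against the new windowed reads. The folder layout, file name, and extractor arguments mirror the test below, but the duration and the final comparison are illustrative rather than taken from the PR.

import tempfile
from pathlib import Path

from spikeinterface.core import BinaryRecordingExtractor
from spikeinterface.core.core_tools import measure_memory_allocation
from spikeinterface.core.generate import NoiseGeneratorRecording

# Save a noise recording to binary, as the test fixture below does.
folder = Path(tempfile.mkdtemp()) / "memory_check_recording"
recording = NoiseGeneratorRecording(
    durations=[10.0], sampling_frequency=30_000.0, num_channels=32, dtype="float32"
)
recording.save(folder=folder)

binary_recording = BinaryRecordingExtractor(
    num_chan=32,
    file_paths=[folder / "traces_cached_seg0.raw"],
    sampling_frequency=30_000.0,
    dtype="float32",
)

memory_before = measure_memory_allocation()
traces = binary_recording.get_traces(start_frame=0, end_frame=30_000)
memory_after = measure_memory_allocation()

# With the windowed memmap, the growth should be on the order of the requested
# one-second slice (traces.nbytes), not of the full ten-second file.
print(memory_after - memory_before, traces.nbytes)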
67 changes: 65 additions & 2 deletions src/spikeinterface/core/tests/test_binaryrecordingextractor.py
@@ -1,8 +1,11 @@
import pytest
import numpy as np
from pathlib import Path

from spikeinterface.core import BinaryRecordingExtractor
from spikeinterface.core.numpyextractors import NumpyRecording
from spikeinterface.core.core_tools import measure_memory_allocation
from spikeinterface.core.generate import NoiseGeneratorRecording


def test_BinaryRecordingExtractor(create_cache_folder):
@@ -51,15 +54,75 @@ def test_round_trip(tmp_path):
dtype=dtype,
)

# Test for full traces
assert np.allclose(recording.get_traces(), binary_recorder.get_traces())

start_frame = 200
end_frame = 500
# Test for a subset of the traces
start_frame = 20
end_frame = 40
smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame)

assert np.allclose(smaller_traces, binary_smaller_traces)


@pytest.fixture(scope="module")
def folder_with_binary_files(tmpdir_factory):
tmp_path = Path(tmpdir_factory.mktemp("spike_interface_test"))
folder = tmp_path / "test_binary_recording"
num_channels = 32
sampling_frequency = 30_000.0
dtype = "float32"
recording = NoiseGeneratorRecording(
durations=[1.0],
sampling_frequency=sampling_frequency,
num_channels=num_channels,
dtype=dtype,
)
dtype = recording.get_dtype()
recording.save(folder=folder, overwrite=True)

return folder


def test_sequential_reading_of_small_traces(folder_with_binary_files):
# Test that the memmap is read correctly when pointing to specific frames
folder = folder_with_binary_files
num_channels = 32
sampling_frequency = 30_000.0
dtype = "float32"

file_paths = [folder / "traces_cached_seg0.raw"]
recording = BinaryRecordingExtractor(
num_chan=num_channels,
file_paths=file_paths,
sampling_frequency=sampling_frequency,
dtype=dtype,
)

full_traces = recording.get_traces()

# Test for a subset of the traces
start_frame = 10
end_frame = 15
small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
expected_traces = full_traces[start_frame:end_frame, :]
assert np.allclose(small_traces, expected_traces)

# Test for a subset of the traces
start_frame = 1000
end_frame = 1100
small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
expected_traces = full_traces[start_frame:end_frame, :]
assert np.allclose(small_traces, expected_traces)

# Test for a subset of the traces
start_frame = 10_000
end_frame = 11_000
small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
expected_traces = full_traces[start_frame:end_frame, :]
assert np.allclose(small_traces, expected_traces)


if __name__ == "__main__":
test_BinaryRecordingExtractor()
