Skip to content

Commit

Permalink
better get_traces check
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jun 25, 2024
1 parent d963aa2 commit 458aa94
Show file tree
Hide file tree
Showing 40 changed files with 65 additions and 41 deletions.
24 changes: 24 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,7 @@ def get_num_samples(self) -> int:
# must be implemented in subclass
raise NotImplementedError


def get_traces(
self,
start_frame: int | None = None,
Expand All @@ -884,5 +885,28 @@ def get_traces(
traces : np.ndarray
Array of traces, num_samples x num_channels
"""
# @alessio
start_frame = int(start_frame) if start_frame is not None else 0
num_samples = self.get_num_samples()
end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples

# @ramon @alessio @zach @joe @paul
# https://github.com/SpikeInterface/spikeinterface/issues/1989
# here we can implement the strick mode
# now it is not activate yet
strict_mode = False
if strict_mode:
if start_frame < 0:
raise ValueError(f"get_traces() : wrong start_frame {start_frame}, must be positive. You should stop doing this otherwise you will be excommunicated")
if end_frame > num_samples:
raise ValueError(f"get_traces() : wrong end_frame {end_frame}, must be maximumnum samples {num_samples}. You should stop doing this otherwise you will be excommunicated")
if end_frame >= start_frame:
raise ValueError(f"get_traces() : wrong end_frame/start_frame : {start_frame} < {end_frame}. You should stop doing this otherwise you will be excommunicated")

self._get_traces(start_frame, end_frame, channel_indices)



def _get_traces(self, start_frame, end_frame, channel_indices) -> np.ndarray:
# must be implemented in subclass
raise NotImplementedError
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_num_samples(self) -> int:
"""
return self.num_samples

def get_traces(
def _get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_num_samples(self) -> int:
# num samples are all the same
return self._parent_segments[0].get_num_samples()

def get_traces(
def _get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/channelslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, parent_recording_segment, parent_channel_indices):
def get_num_samples(self) -> int:
return self._parent_recording_segment.get_num_samples()

def get_traces(
def _get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, parent_recording_segment, start_frame, end_frame):
def get_num_samples(self) -> int:
return self.end_frame - self.start_frame

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
parent_start = self.start_frame + start_frame
parent_end = self.start_frame + end_frame
traces = self._parent_recording_segment.get_traces(
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ def __init__(
def get_num_samples(self) -> int:
return self.num_samples

def get_traces(
def _get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
Expand Down Expand Up @@ -1801,7 +1801,7 @@ def __init__(
self.parent_recording = parent_recording_segment
self.num_samples = parent_recording_segment.get_num_frames() if num_samples is None else num_samples

def get_traces(
def _get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, traces, sampling_frequency, t_start):
def get_num_samples(self) -> int:
return self.num_samples

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = self._traces[start_frame:end_frame, :]
if channel_indices is not None:
traces = traces[:, channel_indices]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/segmentutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True):
def get_num_samples(self):
return self.total_length

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
# # Ensures that we won't request invalid segment indices
if (start_frame >= self.get_num_samples()) or (end_frame <= start_frame):
# Return (0 * num_channels) array of correct dtype
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_num_samples(self) -> int:
"""
return self._timeseries.shape[0]

def get_traces(
def _get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, cbuffer, sampling_frequency, load_sync_channel):
def get_num_samples(self):
return self._cbuffer.shape[0]

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/iblextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __init__(self, file_streamer, load_sync_channel: bool = False):
def get_num_samples(self):
return self._file_streamer.ns

def get_traces(self, start_frame: int, end_frame: int, channel_indices):
def _get_traces(self, start_frame: int, end_frame: int, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
traces = self._file_streamer.read(nsel=slice(start_frame, end_frame), volts=False)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/mcsh5extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, rf, stream_id, num_frames, sampling_frequency):
def get_num_samples(self):
return self._num_samples

def get_traces(self, start_frame=None, end_frame=None, channel_indices=None):
def _get_traces(self, start_frame=None, end_frame=None, channel_indices=None):
if isinstance(channel_indices, slice):
traces = self._stream.get("ChannelData")[channel_indices, start_frame:end_frame].T
else:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/mdaextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_num_samples(self):
"""
return self._num_samples

def get_traces(
def _get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def get_num_samples(self):

return int(num_samples)

def get_traces(
def _get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ def get_num_samples(self):
"""
return self._num_samples

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
electrical_series_data = self.electrical_series_data
if electrical_series_data.ndim == 1:
traces = electrical_series_data[start_frame:end_frame][:, np.newaxis]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/sinapsrecordingextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, rf, num_frames, sampling_frequency, num_bits):
def get_num_samples(self):
return self._num_samples

def get_traces(self, start_frame=None, end_frame=None, channel_indices=None):
def _get_traces(self, start_frame=None, end_frame=None, channel_indices=None):
if isinstance(channel_indices, slice):
traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T
else:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __init__(
self.displacement_indices = displacement_indices
self.templates_array_moved = templates_array_moved

def get_traces(
def _get_traces(
self,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.dtype = dtype
self.round = round

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
def get_num_samples(self):
return self.parent_recording_segment.get_num_samples()

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
parent_traces = self.parent_recording_segment.get_traces(
start_frame=start_frame,
end_frame=end_frame,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/basepreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def __init__(self, parent_recording_segment):
def get_num_samples(self):
return self.parent_recording_segment.get_num_samples()

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
raise NotImplementedError
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max)
self.a_max = a_max
self.value_max = value_max

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
traces = traces.copy()

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
self.dtype = dtype
self.operator_func = operator = np.mean if self.operator == "average" else np.median

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
# Let's do the case with group_indices equal None as that is easy
if self.group_indices is None:
# We need all the channels to calculate the reference
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_num_samples(self):
assert self._decimation_offset < parent_n_samp # Sanity check (already enforced). Formula changes otherwise
return int(np.ceil((parent_n_samp - self._decimation_offset) / self._decimation_factor))

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
# Account for offset and end when querying parent traces
parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor
parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
self.desired_shape = desired_shape
self.predict_workers = predict_workers

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
from .generators import SpikeInterfaceRecordingSegmentGenerator

n_frames = self.parent_recording_segment.get_num_samples()
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/directional_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
# so that geom_other_dims[i] == unique_pos_other_dims[column_inds[i]]
self.unique_pos_other_dims, self.column_inds = np.unique(geom_other_dims, axis=0, return_inverse=True)

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
parent_traces = self.parent_recording_segment.get_traces(
start_frame=start_frame,
end_frame=end_frame,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype,
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment,
start_frame,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/filter_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
sigmas.append(sf / (2 * np.pi * freq_max))
self.margin = 1 + int(max(sigmas) * margin_sd)

def get_traces(
def _get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/filter_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, parent_recording_segment, executor, margin):
self.executor = executor
self.margin = margin

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"
assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(
self.sos_filter = sos_filter
self.dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
if self.window is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_i
self._bad_channel_indices = bad_channel_indices
self._weights = weights

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/normalize_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, parent_recording_segment, gain, offset, dtype):
self.offset = offset
self._dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices) -> np.ndarray:
def _get_traces(self, start_frame, end_frame, channel_indices) -> np.ndarray:
# TODO when we are sure that BaseExtractors get_traces allocate their own buffer instead of just passing
# It along we should remove copies in preprocessors including the one in the next line

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/phase_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, parent_recording_segment, sample_shifts, margin, dtype, tmp_d
self.dtype = dtype
self.tmp_dtype = tmp_dtype

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/rectify.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class RectifyRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment):
BasePreprocessorSegment.__init__(self, parent_recording_segment)

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
return np.abs(traces)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/remove_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __init__(
self.time_pad = time_pad
self.sparsity = sparsity

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if self.mode in ["average", "median"]:
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))
else:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
def get_num_samples(self):
return int(self._parent_segment.get_num_samples() / self._parent_rate * self.sampling_frequency)

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
# get parent traces with margin
parent_start_frame, parent_end_frame = [
int((frame / self.sampling_frequency) * self._parent_rate) for frame in [start_frame, end_frame]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/silence_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg
self.seg_index = seg_index
self.noise_generator = noise_generator

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
traces = traces.copy()

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/unsigned_to_signed.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, parent_recording_segment, dtype_signed, bit_depth):
self.dtype_signed = dtype_signed
self.bit_depth = bit_depth

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale):
self.dtype = dtype
self.int_scale = int_scale

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))
traces_dtype = traces.dtype
# if uint --> force int
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(

super().__init__(parent_recording_segment=recording_segment)

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
# This contains the padded elements by default and we add the original traces if necessary
trace_size = end_frame - start_frame
if isinstance(channel_indices, (np.ndarray, list)):
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, c
self.num_channels = num_channels
self.channel_mapping = channel_mapping

def get_traces(self, start_frame, end_frame, channel_indices):
def _get_traces(self, start_frame, end_frame, channel_indices):
traces = np.zeros((end_frame - start_frame, self.num_channels))
traces[:, self.channel_mapping] = self.parent_recording_segment.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=self.channel_mapping
Expand Down
Loading

0 comments on commit 458aa94

Please sign in to comment.