From bde2f7d18939588e3e6275ff3424a836186dcdde Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 5 Mar 2024 12:17:19 +0100 Subject: [PATCH] Handle negative gain for neo rawio to make smarter user eperience. --- .../neoextractors/neobaseextractor.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index a5ed72d1c0..f9b7829e98 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -240,7 +240,7 @@ def __init__( chan_ids = signal_channels["id"] sampling_frequency = self.neo_reader.get_signal_sampling_rate(stream_index=self.stream_index) - dtype = signal_channels["dtype"][0] + dtype = np.dtype(signal_channels["dtype"][0]) BaseRecording.__init__(self, sampling_frequency, chan_ids, dtype) self.extra_requirements.append("neo") @@ -248,6 +248,14 @@ def __init__( gains = signal_channels["gain"] offsets = signal_channels["offset"] + if dtype.kind == "i" and np.all(gains <0) and np.all(offsets == 0): + # special hack when all channel have negative gain: we put back the gain positive + # this help the end user experience + self.inverted_gain = True + gains = -gains + else: + self.inverted_gain = False + units = signal_channels["units"] # mark that units are V, mV or uV @@ -288,7 +296,7 @@ def __init__( nseg = self.neo_reader.segment_count(block_index=self.block_index) for segment_index in range(nseg): - rec_segment = NeoRecordingSegment(self.neo_reader, self.block_index, segment_index, self.stream_index) + rec_segment = NeoRecordingSegment(self.neo_reader, self.block_index, segment_index, self.stream_index, self.inverted_gain) self.add_recording_segment(rec_segment) self._kwargs.update(kwargs) @@ -301,7 +309,7 @@ def get_num_blocks(cls, *args, **kwargs): class NeoRecordingSegment(BaseRecordingSegment): - def __init__(self, neo_reader, block_index, segment_index, stream_index): + def __init__(self, neo_reader, block_index, segment_index, stream_index, inverted_gain): sampling_frequency = neo_reader.get_signal_sampling_rate(stream_index=stream_index) t_start = neo_reader.get_signal_t_start(block_index, segment_index, stream_index=stream_index) BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) @@ -309,6 +317,7 @@ def __init__(self, neo_reader, block_index, segment_index, stream_index): self.segment_index = segment_index self.stream_index = stream_index self.block_index = block_index + self.inverted_gain = inverted_gain def get_num_samples(self): num_samples = self.neo_reader.get_signal_size( @@ -331,6 +340,8 @@ def get_traces( stream_index=self.stream_index, channel_indexes=channel_indices, ) + if self.inverted_gain: + raw_traces = -raw_traces return raw_traces