Skip to content

Commit

Permalink
Handle negative gain for neo rawio to make smarter user eperience.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Mar 5, 2024
1 parent 25bb836 commit bde2f7d
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/spikeinterface/extractors/neoextractors/neobaseextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,22 @@ def __init__(
chan_ids = signal_channels["id"]

sampling_frequency = self.neo_reader.get_signal_sampling_rate(stream_index=self.stream_index)
dtype = signal_channels["dtype"][0]
dtype = np.dtype(signal_channels["dtype"][0])
BaseRecording.__init__(self, sampling_frequency, chan_ids, dtype)
self.extra_requirements.append("neo")

# find the gain to uV
gains = signal_channels["gain"]
offsets = signal_channels["offset"]

if dtype.kind == "i" and np.all(gains <0) and np.all(offsets == 0):
# special hack when all channel have negative gain: we put back the gain positive
# this help the end user experience
self.inverted_gain = True
gains = -gains
else:
self.inverted_gain = False

units = signal_channels["units"]

# mark that units are V, mV or uV
Expand Down Expand Up @@ -288,7 +296,7 @@ def __init__(

nseg = self.neo_reader.segment_count(block_index=self.block_index)
for segment_index in range(nseg):
rec_segment = NeoRecordingSegment(self.neo_reader, self.block_index, segment_index, self.stream_index)
rec_segment = NeoRecordingSegment(self.neo_reader, self.block_index, segment_index, self.stream_index, self.inverted_gain)
self.add_recording_segment(rec_segment)

self._kwargs.update(kwargs)
Expand All @@ -301,14 +309,15 @@ def get_num_blocks(cls, *args, **kwargs):


class NeoRecordingSegment(BaseRecordingSegment):
def __init__(self, neo_reader, block_index, segment_index, stream_index):
def __init__(self, neo_reader, block_index, segment_index, stream_index, inverted_gain):
sampling_frequency = neo_reader.get_signal_sampling_rate(stream_index=stream_index)
t_start = neo_reader.get_signal_t_start(block_index, segment_index, stream_index=stream_index)
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
self.neo_reader = neo_reader
self.segment_index = segment_index
self.stream_index = stream_index
self.block_index = block_index
self.inverted_gain = inverted_gain

def get_num_samples(self):
num_samples = self.neo_reader.get_signal_size(
Expand All @@ -331,6 +340,8 @@ def get_traces(
stream_index=self.stream_index,
channel_indexes=channel_indices,
)
if self.inverted_gain:
raw_traces = -raw_traces
return raw_traces


Expand Down

0 comments on commit bde2f7d

Please sign in to comment.