Skip to content

Commit

Permalink
Merge pull request #3053 from h-mayorquin/add_scale_to_uV_preprocessing
Browse files Browse the repository at this point in the history
Add `scale_to_uV` preprocessing
  • Loading branch information
samuelgarcia authored Jul 3, 2024
2 parents 203f778 + 0b5c635 commit 5a7d890
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/spikeinterface/preprocessing/preprocessinglist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
CenterRecording,
center,
)
from .scale import scale_to_uV

from .whiten import WhitenRecording, whiten, compute_whitening_matrix
from .rectify import RectifyRecording, rectify
from .clip import BlankSaturationRecording, blank_staturation, ClipRecording, clip
Expand Down
43 changes: 43 additions & 0 deletions src/spikeinterface/preprocessing/scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import numpy as np

from spikeinterface.core import BaseRecording
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor


def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor:
"""
Scale raw traces to microvolts (µV).
This preprocessor uses the channel-specific gain and offset information
stored in the recording extractor to convert the raw traces to µV units.
Parameters
----------
recording : BaseRecording
The recording extractor to be scaled. The recording extractor must
have gains and offsets otherwise an error will be raised.
Raises
------
AssertionError
If the recording extractor does not have scaleable traces.
"""
# To avoid a circular import
from spikeinterface.preprocessing import ScaleRecording

if not recording.has_scaleable_traces():
error_msg = "Recording must have gains and offsets set to be scaled to µV"
raise RuntimeError(error_msg)

gain = recording.get_channel_gains()
offset = recording.get_channel_offsets()

scaled_to_uV_recording = ScaleRecording(recording, gain=gain, offset=offset, dtype="float32")

# We do this so when get_traces(return_scaled=True) is called, the return is the same.
scaled_to_uV_recording.set_channel_gains(gains=1.0)
scaled_to_uV_recording.set_channel_offsets(offsets=0.0)

return scaled_to_uV_recording
70 changes: 70 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
import numpy as np
from spikeinterface.core.testing_tools import generate_recording
from spikeinterface.preprocessing import scale_to_uV, CenterRecording


def test_scale_to_uV():
# Create a sample recording extractor with fake gains and offsets
num_channels = 4
sampling_frequency = 30_000.0
durations = [1.0, 1.0] # seconds
recording = generate_recording(
num_channels=num_channels,
durations=durations,
sampling_frequency=sampling_frequency,
)

rng = np.random.default_rng(0)
gains = rng.random(size=(num_channels)).astype(np.float32)
offsets = rng.random(size=(num_channels)).astype(np.float32)
recording.set_channel_gains(gains)
recording.set_channel_offsets(offsets)

# Apply the preprocessor
scaled_recording = scale_to_uV(recording=recording)

# Check if the traces are indeed scaled
expected_traces = recording.get_traces(return_scaled=True, segment_index=0)
scaled_traces = scaled_recording.get_traces(segment_index=0)

np.testing.assert_allclose(scaled_traces, expected_traces)

# Test for the error when recording doesn't have scaleable traces
recording.set_channel_gains(None) # Remove gains to make traces unscaleable
with pytest.raises(RuntimeError):
scale_to_uV(recording)


def test_scaling_in_preprocessing_chain():

# Create a sample recording extractor with fake gains and offsets
num_channels = 4
sampling_frequency = 30_000.0
durations = [1.0] # seconds
recording = generate_recording(
num_channels=num_channels,
durations=durations,
sampling_frequency=sampling_frequency,
)

rng = np.random.default_rng(0)
gains = rng.random(size=(num_channels)).astype(np.float32)
offsets = rng.random(size=(num_channels)).astype(np.float32)

recording.set_channel_gains(gains)
recording.set_channel_offsets(offsets)

centered_recording = CenterRecording(scale_to_uV(recording=recording))
traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True)

# Chain preprocessors
centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording))
traces_scaled_with_preprocessor = centered_recording_scaled.get_traces()

np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor)

# Test if the scaling is not done twice
traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True)

np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument)

0 comments on commit 5a7d890

Please sign in to comment.