-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3053 from h-mayorquin/add_scale_to_uV_preprocessing
Add `scale_to_uV` preprocessing
- Loading branch information
Showing
3 changed files
with
115 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from spikeinterface.core import BaseRecording | ||
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor | ||
|
||
|
||
def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor: | ||
""" | ||
Scale raw traces to microvolts (µV). | ||
This preprocessor uses the channel-specific gain and offset information | ||
stored in the recording extractor to convert the raw traces to µV units. | ||
Parameters | ||
---------- | ||
recording : BaseRecording | ||
The recording extractor to be scaled. The recording extractor must | ||
have gains and offsets otherwise an error will be raised. | ||
Raises | ||
------ | ||
AssertionError | ||
If the recording extractor does not have scaleable traces. | ||
""" | ||
# To avoid a circular import | ||
from spikeinterface.preprocessing import ScaleRecording | ||
|
||
if not recording.has_scaleable_traces(): | ||
error_msg = "Recording must have gains and offsets set to be scaled to µV" | ||
raise RuntimeError(error_msg) | ||
|
||
gain = recording.get_channel_gains() | ||
offset = recording.get_channel_offsets() | ||
|
||
scaled_to_uV_recording = ScaleRecording(recording, gain=gain, offset=offset, dtype="float32") | ||
|
||
# We do this so when get_traces(return_scaled=True) is called, the return is the same. | ||
scaled_to_uV_recording.set_channel_gains(gains=1.0) | ||
scaled_to_uV_recording.set_channel_offsets(offsets=0.0) | ||
|
||
return scaled_to_uV_recording |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pytest | ||
import numpy as np | ||
from spikeinterface.core.testing_tools import generate_recording | ||
from spikeinterface.preprocessing import scale_to_uV, CenterRecording | ||
|
||
|
||
def test_scale_to_uV(): | ||
# Create a sample recording extractor with fake gains and offsets | ||
num_channels = 4 | ||
sampling_frequency = 30_000.0 | ||
durations = [1.0, 1.0] # seconds | ||
recording = generate_recording( | ||
num_channels=num_channels, | ||
durations=durations, | ||
sampling_frequency=sampling_frequency, | ||
) | ||
|
||
rng = np.random.default_rng(0) | ||
gains = rng.random(size=(num_channels)).astype(np.float32) | ||
offsets = rng.random(size=(num_channels)).astype(np.float32) | ||
recording.set_channel_gains(gains) | ||
recording.set_channel_offsets(offsets) | ||
|
||
# Apply the preprocessor | ||
scaled_recording = scale_to_uV(recording=recording) | ||
|
||
# Check if the traces are indeed scaled | ||
expected_traces = recording.get_traces(return_scaled=True, segment_index=0) | ||
scaled_traces = scaled_recording.get_traces(segment_index=0) | ||
|
||
np.testing.assert_allclose(scaled_traces, expected_traces) | ||
|
||
# Test for the error when recording doesn't have scaleable traces | ||
recording.set_channel_gains(None) # Remove gains to make traces unscaleable | ||
with pytest.raises(RuntimeError): | ||
scale_to_uV(recording) | ||
|
||
|
||
def test_scaling_in_preprocessing_chain(): | ||
|
||
# Create a sample recording extractor with fake gains and offsets | ||
num_channels = 4 | ||
sampling_frequency = 30_000.0 | ||
durations = [1.0] # seconds | ||
recording = generate_recording( | ||
num_channels=num_channels, | ||
durations=durations, | ||
sampling_frequency=sampling_frequency, | ||
) | ||
|
||
rng = np.random.default_rng(0) | ||
gains = rng.random(size=(num_channels)).astype(np.float32) | ||
offsets = rng.random(size=(num_channels)).astype(np.float32) | ||
|
||
recording.set_channel_gains(gains) | ||
recording.set_channel_offsets(offsets) | ||
|
||
centered_recording = CenterRecording(scale_to_uV(recording=recording)) | ||
traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) | ||
|
||
# Chain preprocessors | ||
centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) | ||
traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() | ||
|
||
np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) | ||
|
||
# Test if the scaling is not done twice | ||
traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) | ||
|
||
np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) |