From 4ecc16482ae8850be2a3f8fdebfade0d28eda7f2 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 18:04:26 -0600 Subject: [PATCH 1/9] add scale to microvolts --- src/spikeinterface/preprocessing/preprocessinglist.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 1b28be9752..0c2ca0cb9a 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -24,6 +24,8 @@ CenterRecording, center, ) +from .scale import ScaleTouV, scale_to_uV + from .whiten import WhitenRecording, whiten, compute_whitening_matrix from .rectify import RectifyRecording, rectify from .clip import BlankSaturationRecording, blank_staturation, ClipRecording, clip @@ -54,6 +56,7 @@ ScaleRecording, CenterRecording, ZScoreRecording, + ScaleTouV, # decorrelation stuff WhitenRecording, # re-reference From 3dbb8c104dfeb942bfb5ce17f65d3b3b8c1a962c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 18:11:43 -0600 Subject: [PATCH 2/9] added untracked files, ups --- src/spikeinterface/preprocessing/scale.py | 46 +++++++++++++++++++ .../preprocessing/tests/test_scaling.py | 35 ++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 src/spikeinterface/preprocessing/scale.py create mode 100644 src/spikeinterface/preprocessing/tests/test_scaling.py diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py new file mode 100644 index 0000000000..a8837010ea --- /dev/null +++ b/src/spikeinterface/preprocessing/scale.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from spikeinterface.core import BaseRecording +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor + + +class ScaleTouV(BasePreprocessor): + """ + Scale raw traces to microvolts (µV). + + This preprocessor uses the channel-specific gain and offset information + stored in the recording extractor to convert the raw traces to µV units. + + Parameters + ---------- + recording : BaseRecording + The recording extractor to be scaled. The recording extractor must + have gains and offsets otherwise an error will be raised. + + Raises + ------ + AssertionError + If the recording extractor does not have scaleable traces. + """ + + name = "scale_to_uV" + + def __init__(self, recording: BaseRecording): + assert recording.has_scaleable_traces(), "Recording must have scaleable traces" + from spikeinterface.preprocessing.normalize_scale import ScaleRecordingSegment + + dtype = recording.get_dtype() + BasePreprocessor.__init__(self, recording, dtype=dtype) + + gain = recording.get_channel_gains()[None, :] + offset = recording.get_channel_offsets()[None, :] + for parent_segment in recording._recording_segments: + rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, self._dtype) + self.add_recording_segment(rec_segment) + + self._kwargs = dict( + recording=recording, + ) + + +scale_to_uV = ScaleTouV diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py new file mode 100644 index 0000000000..39fe3d0ddc --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -0,0 +1,35 @@ +import pytest +import numpy as np +from spikeinterface.core.testing_tools import generate_recording +from spikeinterface.preprocessing import ScaleTouV # Replace 'your_module' with your actual module name + + +def test_scale_to_uv(): + # Create a sample recording extractor with fake gains and offsets + num_channels = 4 + sampling_frequency = 30_000.0 + durations = [1] # seconds + recording = generate_recording( + num_channels=num_channels, + durations=durations, + sampling_frequency=sampling_frequency, + ) + + gains = np.ones(shape=(num_channels)) + offsets = np.zeros(shape=(num_channels)) + recording.set_channel_gains(gains) # Random gains + recording.set_channel_offsets(offsets) # Random offsets + + # Apply the preprocessor + scaled_recording = ScaleTouV(recording=recording) + + # Check if the traces are indeed scaled + expected_traces = recording.get_traces(return_scaled=True) + scaled_traces = scaled_recording.get_traces() + + np.testing.assert_allclose(scaled_traces, expected_traces) + + # Test for the error when recording doesn't have scaleable traces + recording.set_channel_gains(None) # Remove gains to make traces unscaleable + with pytest.raises(AssertionError): + ScaleTouV(recording) From be59dbe4413610f0212ded90dbde63b4da234e4d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 18:30:06 -0600 Subject: [PATCH 3/9] add more personality to the test --- src/spikeinterface/preprocessing/tests/test_scaling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 39fe3d0ddc..6dbc66591f 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -15,10 +15,11 @@ def test_scale_to_uv(): sampling_frequency=sampling_frequency, ) - gains = np.ones(shape=(num_channels)) - offsets = np.zeros(shape=(num_channels)) - recording.set_channel_gains(gains) # Random gains - recording.set_channel_offsets(offsets) # Random offsets + rng = np.random.default_rng(0) + gains = rng.random(size=(num_channels)).astype(np.float32) + offsets = rng.random(size=(num_channels)).astype(np.float32) + recording.set_channel_gains(gains) + recording.set_channel_offsets(offsets) # Apply the preprocessor scaled_recording = ScaleTouV(recording=recording) From a98e81af11f06c0ee47338b42e3132964f558d76 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 07:59:59 -0600 Subject: [PATCH 4/9] name changing --- src/spikeinterface/preprocessing/preprocessinglist.py | 4 ++-- src/spikeinterface/preprocessing/scale.py | 4 ++-- src/spikeinterface/preprocessing/tests/test_scaling.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 0c2ca0cb9a..7fc3bc0685 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -24,7 +24,7 @@ CenterRecording, center, ) -from .scale import ScaleTouV, scale_to_uV +from .scale import ScaleTouVRecording, scale_to_uV from .whiten import WhitenRecording, whiten, compute_whitening_matrix from .rectify import RectifyRecording, rectify @@ -56,7 +56,7 @@ ScaleRecording, CenterRecording, ZScoreRecording, - ScaleTouV, + ScaleTouVRecording, # decorrelation stuff WhitenRecording, # re-reference diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py index a8837010ea..99acd49981 100644 --- a/src/spikeinterface/preprocessing/scale.py +++ b/src/spikeinterface/preprocessing/scale.py @@ -4,7 +4,7 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor -class ScaleTouV(BasePreprocessor): +class ScaleTouVRecording(BasePreprocessor): """ Scale raw traces to microvolts (µV). @@ -43,4 +43,4 @@ def __init__(self, recording: BaseRecording): ) -scale_to_uV = ScaleTouV +scale_to_uV = ScaleTouVRecording diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 6dbc66591f..e66d36c613 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,7 +1,7 @@ import pytest import numpy as np from spikeinterface.core.testing_tools import generate_recording -from spikeinterface.preprocessing import ScaleTouV # Replace 'your_module' with your actual module name +from spikeinterface.preprocessing import ScaleTouVRecording # Replace 'your_module' with your actual module name def test_scale_to_uv(): @@ -22,7 +22,7 @@ def test_scale_to_uv(): recording.set_channel_offsets(offsets) # Apply the preprocessor - scaled_recording = ScaleTouV(recording=recording) + scaled_recording = ScaleTouVRecording(recording=recording) # Check if the traces are indeed scaled expected_traces = recording.get_traces(return_scaled=True) @@ -33,4 +33,4 @@ def test_scale_to_uv(): # Test for the error when recording doesn't have scaleable traces recording.set_channel_gains(None) # Remove gains to make traces unscaleable with pytest.raises(AssertionError): - ScaleTouV(recording) + ScaleTouVRecording(recording) From 511ebd5d21badab3416b6fba0b7781caf6b86db1 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 10:54:12 -0600 Subject: [PATCH 5/9] first comments --- src/spikeinterface/preprocessing/scale.py | 10 ++++++++-- src/spikeinterface/preprocessing/tests/test_scaling.py | 10 +++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py index 99acd49981..03ccee757b 100644 --- a/src/spikeinterface/preprocessing/scale.py +++ b/src/spikeinterface/preprocessing/scale.py @@ -1,5 +1,7 @@ from __future__ import annotations +import numpy as np + from spikeinterface.core import BaseRecording from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor @@ -26,12 +28,16 @@ class ScaleTouVRecording(BasePreprocessor): name = "scale_to_uV" def __init__(self, recording: BaseRecording): - assert recording.has_scaleable_traces(), "Recording must have scaleable traces" + # Importing inside to avoid a circular import from spikeinterface.preprocessing.normalize_scale import ScaleRecordingSegment - dtype = recording.get_dtype() + dtype = np.dtype("float32") BasePreprocessor.__init__(self, recording, dtype=dtype) + if not recording.has_scaleable_traces(): + error_msg = "Recording must have gains and offsets set to be scaled to µV" + raise RuntimeError(error_msg) + gain = recording.get_channel_gains()[None, :] offset = recording.get_channel_offsets()[None, :] for parent_segment in recording._recording_segments: diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index e66d36c613..7079e6f6ae 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,14 +1,14 @@ import pytest import numpy as np from spikeinterface.core.testing_tools import generate_recording -from spikeinterface.preprocessing import ScaleTouVRecording # Replace 'your_module' with your actual module name +from spikeinterface.preprocessing import ScaleTouVRecording def test_scale_to_uv(): # Create a sample recording extractor with fake gains and offsets num_channels = 4 sampling_frequency = 30_000.0 - durations = [1] # seconds + durations = [1.0, 1.0] # seconds recording = generate_recording( num_channels=num_channels, durations=durations, @@ -25,12 +25,12 @@ def test_scale_to_uv(): scaled_recording = ScaleTouVRecording(recording=recording) # Check if the traces are indeed scaled - expected_traces = recording.get_traces(return_scaled=True) - scaled_traces = scaled_recording.get_traces() + expected_traces = recording.get_traces(return_scaled=True, segment_index=0) + scaled_traces = scaled_recording.get_traces(segment_index=0) np.testing.assert_allclose(scaled_traces, expected_traces) # Test for the error when recording doesn't have scaleable traces recording.set_channel_gains(None) # Remove gains to make traces unscaleable - with pytest.raises(AssertionError): + with pytest.raises(RuntimeError): ScaleTouVRecording(recording) From 62485d5ee5c9dbb30085b3eaad9ea07a448d2ace Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 11:05:50 -0600 Subject: [PATCH 6/9] add failing test --- .../preprocessing/tests/test_scaling.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 7079e6f6ae..098e77caad 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,7 +1,7 @@ import pytest import numpy as np from spikeinterface.core.testing_tools import generate_recording -from spikeinterface.preprocessing import ScaleTouVRecording +from spikeinterface.preprocessing import ScaleTouVRecording, CenterRecording def test_scale_to_uv(): @@ -34,3 +34,37 @@ def test_scale_to_uv(): recording.set_channel_gains(None) # Remove gains to make traces unscaleable with pytest.raises(RuntimeError): ScaleTouVRecording(recording) + + +def test_scaling_in_preprocessing_chain(): + + # Create a sample recording extractor with fake gains and offsets + num_channels = 4 + sampling_frequency = 30_000.0 + durations = [1.0] # seconds + recording = generate_recording( + num_channels=num_channels, + durations=durations, + sampling_frequency=sampling_frequency, + ) + + rng = np.random.default_rng(0) + gains = rng.random(size=(num_channels)).astype(np.float32) + offsets = rng.random(size=(num_channels)).astype(np.float32) + + recording.set_channel_gains(gains) + recording.set_channel_offsets(offsets) + + centered_recording = CenterRecording(ScaleTouVRecording(recording=recording)) + traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) + + # Chain preprocessors + centered_recording_scaled = CenterRecording(ScaleTouVRecording(recording=recording)) + traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() + + np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) + + # Test if the scaling is not done twice + traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) + + np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) From bf3cc4b21faf0829521af29187df0cfb4fcfd445 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 28 Jun 2024 16:08:40 -0600 Subject: [PATCH 7/9] tests are passing --- src/spikeinterface/preprocessing/scale.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py index 03ccee757b..45ffe383d4 100644 --- a/src/spikeinterface/preprocessing/scale.py +++ b/src/spikeinterface/preprocessing/scale.py @@ -38,10 +38,13 @@ def __init__(self, recording: BaseRecording): error_msg = "Recording must have gains and offsets set to be scaled to µV" raise RuntimeError(error_msg) + self.set_channel_gains(gains=1.0) + self.set_channel_offsets(offsets=0.0) + gain = recording.get_channel_gains()[None, :] offset = recording.get_channel_offsets()[None, :] for parent_segment in recording._recording_segments: - rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, self._dtype) + rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype) self.add_recording_segment(rec_segment) self._kwargs = dict( From c6a521b8569ef062edf277a71b9aa74af0e7b1ea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 28 Jun 2024 16:20:46 -0600 Subject: [PATCH 8/9] @alejo91 suggestion --- src/spikeinterface/preprocessing/scale.py | 38 +++++++------------ .../preprocessing/tests/test_scaling.py | 12 +++--- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py index 45ffe383d4..bc77577ce0 100644 --- a/src/spikeinterface/preprocessing/scale.py +++ b/src/spikeinterface/preprocessing/scale.py @@ -6,7 +6,7 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor -class ScaleTouVRecording(BasePreprocessor): +def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor: """ Scale raw traces to microvolts (µV). @@ -24,32 +24,20 @@ class ScaleTouVRecording(BasePreprocessor): AssertionError If the recording extractor does not have scaleable traces. """ + # To avoid a circular import + from spikeinterface.preprocessing import ScaleRecording - name = "scale_to_uV" + if not recording.has_scaleable_traces(): + error_msg = "Recording must have gains and offsets set to be scaled to µV" + raise RuntimeError(error_msg) - def __init__(self, recording: BaseRecording): - # Importing inside to avoid a circular import - from spikeinterface.preprocessing.normalize_scale import ScaleRecordingSegment + gain = recording.get_channel_gains() + offset = recording.get_channel_offsets() - dtype = np.dtype("float32") - BasePreprocessor.__init__(self, recording, dtype=dtype) + scaled_to_uV_recording = ScaleRecording(recording, gain=gain, offset=offset, dtype="float32") - if not recording.has_scaleable_traces(): - error_msg = "Recording must have gains and offsets set to be scaled to µV" - raise RuntimeError(error_msg) + # We do this so when get_traces(return_scaled=True) is called, the return is the same. + scaled_to_uV_recording.set_channel_gains(gains=1.0) + scaled_to_uV_recording.set_channel_offsets(offsets=0.0) - self.set_channel_gains(gains=1.0) - self.set_channel_offsets(offsets=0.0) - - gain = recording.get_channel_gains()[None, :] - offset = recording.get_channel_offsets()[None, :] - for parent_segment in recording._recording_segments: - rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype) - self.add_recording_segment(rec_segment) - - self._kwargs = dict( - recording=recording, - ) - - -scale_to_uV = ScaleTouVRecording + return scaled_to_uV_recording diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 098e77caad..321d7c9df2 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,10 +1,10 @@ import pytest import numpy as np from spikeinterface.core.testing_tools import generate_recording -from spikeinterface.preprocessing import ScaleTouVRecording, CenterRecording +from spikeinterface.preprocessing import scale_to_uV, CenterRecording -def test_scale_to_uv(): +def test_scale_to_uV(): # Create a sample recording extractor with fake gains and offsets num_channels = 4 sampling_frequency = 30_000.0 @@ -22,7 +22,7 @@ def test_scale_to_uv(): recording.set_channel_offsets(offsets) # Apply the preprocessor - scaled_recording = ScaleTouVRecording(recording=recording) + scaled_recording = scale_to_uV(recording=recording) # Check if the traces are indeed scaled expected_traces = recording.get_traces(return_scaled=True, segment_index=0) @@ -33,7 +33,7 @@ def test_scale_to_uv(): # Test for the error when recording doesn't have scaleable traces recording.set_channel_gains(None) # Remove gains to make traces unscaleable with pytest.raises(RuntimeError): - ScaleTouVRecording(recording) + scale_to_uV(recording) def test_scaling_in_preprocessing_chain(): @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(ScaleTouVRecording(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording)) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(ScaleTouVRecording(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) From 0b5c6358c674b4f5d3af3295bfbdb3bbfda10c13 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 28 Jun 2024 16:23:56 -0600 Subject: [PATCH 9/9] fix imports --- src/spikeinterface/preprocessing/preprocessinglist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 7fc3bc0685..8f3729b49b 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -24,7 +24,7 @@ CenterRecording, center, ) -from .scale import ScaleTouVRecording, scale_to_uV +from .scale import scale_to_uV from .whiten import WhitenRecording, whiten, compute_whitening_matrix from .rectify import RectifyRecording, rectify @@ -56,7 +56,6 @@ ScaleRecording, CenterRecording, ZScoreRecording, - ScaleTouVRecording, # decorrelation stuff WhitenRecording, # re-reference