Skip to content

Commit

Permalink
Merge pull request #2389 from DradeAW/patch-1
Browse files Browse the repository at this point in the history
Add `margin_sd` to gaussian filtering
  • Loading branch information
samuelgarcia authored Jan 9, 2024
2 parents 301b1ce + 6730f1e commit 03126a3
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/spikeinterface/preprocessing/filter_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class GaussianBandpassFilterRecording(BasePreprocessor):
The lower frequency cutoff for the bandpass filter.
freq_max: float
The higher frequency cutoff for the bandpass filter.
margin_sd: float, default: 5.0
The number of standard deviation to take for margins.
Returns
-------
Expand All @@ -32,19 +34,23 @@ class GaussianBandpassFilterRecording(BasePreprocessor):

name = "gaussian_bandpass_filter"

def __init__(self, recording: BaseRecording, freq_min: float = 300.0, freq_max: float = 5000.0):
def __init__(
self, recording: BaseRecording, freq_min: float = 300.0, freq_max: float = 5000.0, margin_sd: float = 5.0
):
sf = recording.sampling_frequency
BasePreprocessor.__init__(self, recording)
self.annotate(is_filtered=True)

for parent_segment in recording._recording_segments:
self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max))
self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd))

self._kwargs = {"recording": recording, "freq_min": freq_min, "freq_max": freq_max}


class GaussianFilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment: BaseRecordingSegment, freq_min: float, freq_max: float):
def __init__(
self, parent_recording_segment: BaseRecordingSegment, freq_min: float, freq_max: float, margin_sd: float = 5.0
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)

self.freq_min = freq_min
Expand All @@ -54,7 +60,7 @@ def __init__(self, parent_recording_segment: BaseRecordingSegment, freq_min: flo
sf = parent_recording_segment.sampling_frequency
low_sigma = sf / (2 * np.pi * freq_min)
high_sigma = sf / (2 * np.pi * freq_max)
self.margin = int(max(low_sigma, high_sigma) * 6.0 + 1)
self.margin = 1 + int(max(low_sigma, high_sigma) * margin_sd)

def get_traces(
self,
Expand Down

0 comments on commit 03126a3

Please sign in to comment.