Skip to content

Commit

Permalink
Switched all filtering to Gaussian filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
DradeAW committed Dec 18, 2023
1 parent dfa64bc commit 5c22d8d
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 72 deletions.
14 changes: 5 additions & 9 deletions docs/source/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ This module will label units as belonging to a certain category if they meet som

This module takes as a key the name of the category, and as a value a dictionary containing the criteria. Each criterion return a value for each unit, and a minimum and/or maximum can be set.

For some parameters, you can specify the parameters for :code:`wvf_extraction` and :code:`filter` (parameters given to the SpikeInterface method :code:`core.extract_waveforms` and :code:`preprocessing.filter` respectively).
For some parameters, you can specify the parameters for :code:`wvf_extraction` (parameters given to the SpikeInterface method :code:`core.extract_waveforms`) and :code:`filter` (a list :code:`[freq_min, freq_max]` for Gaussian bandpass filtering).

- :code:`firing_rate`: returns the mean firing rate of the unit (in Hz).
- :code:`contamination`: returns the estimated contamination of the unit (between 0 and 1; 0 being pure). The :code:`refractory_period = [censored_period, refractory_period]` has to be set (in ms).
- :code:`amplitude`: returns the mean amplitude of the unit's template (in µV if the recording object can be scaled to µV). Optional parameters can be set to the function :code:`spikeinterface.core.get_template_extremum_amplitude`.
- :code:`SNR`: returns the signal-to-noise ratio of the unit. Optional parameters can be set to the function :code:`spikeinterface.qualitymetrics.compute_snrs`
- :code:`amplitude_std`: returns the standard deviation in the spike amplitudes for the unit. Optional parameters can be set to the function :code:`spikeinterface.postprocessing.compute_spike_amplitudes`
- :code:`sd_ratio`: returns the standard deviation in the spike amplitudes for the unit divided by the standard deviation on the same channel. Optional parameters can be set to the function :code:`spikeinterface.postprocessing.compute_spike_amplitudes` as a dictionary with key :code:`spike_amplitudes_kwargs`. Optional parameters can be set to the function :code:`spikeinterface.qualitymetrics.compute_sd_ratio` as a dictionary with key :code:`sd_ratio_kwargs`.
- :code:`ISI_portion`: Returns the fraction (between 0 and 1) of inter-spike intervals inside a time range. the :code:`range = [min_t, max_t]` has to be set (in ms).


Expand Down Expand Up @@ -58,11 +58,7 @@ Here is an example for categorizing complex-spikes from more regular spikes (cer
"ms_after": 1.0,
"max_spikes_per_unit": 500
},
"filter": { // Parameters for spikeinterface.preprocessing.filter
"band": [150, 7_000],
"filter_order": 2,
"ftype": "bessel"
}
"filter": [150, 7_000] // Gaussian bandpass filter with cutoffs at 150 and 7,000 Hz.
}
}
}
Expand Down Expand Up @@ -140,8 +136,8 @@ Example of units removal
"firing_rate": {
"min": 1.0
},
"amplitude_std": {
"max": 80.0
"sd_ratio": {
"max": 2.0
}
}
Expand Down
18 changes: 3 additions & 15 deletions params_examples/params_cerebellar_cortex[beta].json
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,7 @@
},
"spikes": {
"amplitude": {
"filter": {
"band": [450, 9000],
"filter_order": 2,
"ftype": "bessel"
},
"filter": [450, 9000],
"min": 20,
"wvf_extraction": {
"ms_before": 2.0,
Expand Down Expand Up @@ -175,11 +171,7 @@
"ms_before": 1.5,
"ms_after": 3.5,
"max_spikes_per_unit": 2000,
"filter": {
"band": [150, 3000],
"filter_order": 2,
"ftype": "bessel"
}
"filter": [150, 3000]
},
"auto_merge_params": {
"minimum_spikes": 300,
Expand All @@ -199,11 +191,7 @@
"ms_before": 1.2,
"ms_after": 1.2,
"max_spikes_per_unit": 2000,
"filter": {
"band": [300, 6000],
"filter_order": 2,
"ftype": "bessel"
}
"filter": [300, 6000]
},
"auto_merge_params": {
"bin_ms": 0.05,
Expand Down
35 changes: 17 additions & 18 deletions src/lussac/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def tmp_folder(self) -> pathlib.Path:
def run(self, params: dict[str, Any]) -> si.BaseSorting:
...

def extract_waveforms(self, sorting: si.BaseSorting | None = None, sub_folder: str | None = None, filter: dict[str, Any] | None = None, **params) -> si.WaveformExtractor:
def extract_waveforms(self, sorting: si.BaseSorting | None = None, sub_folder: str | None = None, filter: list[float, float] | None = None, **params) -> si.WaveformExtractor:
"""
Creates the WaveformExtractor object and returns it.
Expand All @@ -168,8 +168,8 @@ def extract_waveforms(self, sorting: si.BaseSorting | None = None, sub_folder: s
The sub-folder where to save the waveforms.
@param params
The parameters for the waveform extractor.
@param filter: dict | None
The filter to apply to the recording.
@param filter: list[float, float] | None
The cutoff frequencies for the Gaussian bandpass filter to apply to the recording.
@return wvf_extractor: WaveformExtractor
The waveform extractor object.
"""
Expand All @@ -180,7 +180,7 @@ def extract_waveforms(self, sorting: si.BaseSorting | None = None, sub_folder: s

recording = self.recording
if filter is not None:
recording = spre.filter(recording, **filter)
recording = spre.gaussian_bandpass_filter(recording, *filter)

sorting = self.sorting if sorting is None else sorting
return si.extract_waveforms(recording, sorting, folder_path, allow_unfiltered=True, **params)
Expand Down Expand Up @@ -250,22 +250,22 @@ def get_units_attribute(self, attribute: str, params: dict) -> dict:
'firing_rate': {},
'contamination': {},
'amplitude': {
'wvf_extraction': {'ms_before': 1.0, 'ms_after': 1.0, 'max_spikes_per_unit': 500},
'peak_sign': "both",
'mode': "extremum",
'wvf_extraction': {'ms_before': 1.0, 'ms_after': 1.0, 'max_spikes_per_unit': 500},
'filter': {'band': [100, 9_000], 'filter_order': 2, 'ftype': "bessel"}
'filter': [100, 9_000]
},
'SNR': {
'wvf_extraction': {'ms_before': 1.0, 'ms_after': 1.0, 'max_spikes_per_unit': 500},
'peak_sign': "both",
'mode': "extremum",
'wvf_extraction': {'ms_before': 1.0, 'ms_after': 1.0, 'max_spikes_per_unit': 500},
'filter': {'band': [100, 9_000], 'filter_order': 2, 'ftype': "bessel"}
'filter': [100, 9_000]
},
'amplitude_std': {
'peak_sign': "both",
'return_scaled': True,
'sd_ratio': {
'wvf_extraction': {'ms_before': 1.0, 'ms_after': 1.0, 'max_spikes_per_unit': 500},
'filter': {'band': [100, 9_000], 'filter_order': 2, 'ftype': "bessel"}
'spike_amplitudes_kwargs': {'peak_sign': "both"},
'sd_ratio_kwargs': {},
'filter': [100, 9_000]
},
'ISI_portion': {}
}
Expand All @@ -277,7 +277,7 @@ def get_units_attribute(self, attribute: str, params: dict) -> dict:
recording = self.data.recording
sorting = self.sorting
if 'filter' in params:
recording = spre.filter(recording, **params['filter'])
recording = spre.gaussian_bandpass_filter(recording, *params['filter'])

wvf_extractor = self.extract_waveforms(sub_folder=attribute, **params['wvf_extraction']) if 'wvf_extraction' in params \
else si.WaveformExtractor(recording, sorting, allow_unfiltered=True)
Expand All @@ -303,11 +303,10 @@ def get_units_attribute(self, attribute: str, params: dict) -> dict:
SNRs = sqm.compute_snrs(wvf_extractor, **params)
return SNRs

case "amplitude_std": # Returns the standard deviation of the amplitude of spikes.
params = utils.filter_kwargs(params, spost.compute_spike_amplitudes)
amplitudes = spost.compute_spike_amplitudes(wvf_extractor, outputs='by_unit', **params)[0]
std_amplitudes = {unit_id: np.std(amp) for unit_id, amp in amplitudes.items()}
return std_amplitudes
case "sd_ratio": # Returns the standard deviation of the amplitude of spikes divided by the standard deviation on the same channel.
_ = spost.compute_spike_amplitudes(wvf_extractor, **params['spike_amplitudes_kwargs'])
sd_ratio = sqm.compute_sd_ratio(wvf_extractor, **params['sd_ratio_kwargs'])
return sd_ratio

case "ISI_portion": # Returns the portion of consecutive spikes that are between a certain range (in ms).
low, high = np.array(params['range']) * recording.sampling_frequency * 1e-3
Expand Down
6 changes: 1 addition & 5 deletions src/lussac/modules/merge_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def default_params(self) -> dict[str, Any]:
'ms_after': 1.5,
'max_spikes_per_unit': 2_000,
'sparse': False,
'filter': {
'band': [100, 9000],
'filter_order': 2,
'ftype': 'bessel'
}
'filter': [100, 9000]
},
'auto_merge_params': {
'bin_ms': 0.05,
Expand Down
22 changes: 8 additions & 14 deletions tests/core/test_mono_sorting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
"ms_after": 1.0,
"max_spikes_per_unit": 10
},
"filter": {
"band": [200, 5000]
},
"filter": [200, 5000],
"min": 20
},
"SNR": {
Expand All @@ -29,20 +27,16 @@
"ms_after": 1.0,
"max_spikes_per_unit": 10
},
"filter": {
"band": [300, 6000]
},
"filter": [300, 6000],
"min": 1.2
},
"amplitude_std": {
"sd_ratio": {
"wvf_extraction": {
"ms_before": 1.0,
"ms_after": 1.0,
},
"filter": {
"band": [200, 5000]
},
"max": 140
"filter": [200, 5000],
"max": 2.0
}
}

Expand Down Expand Up @@ -142,9 +136,9 @@ def test_get_units_attribute(mono_sorting_data: MonoSortingData) -> None:
assert isinstance(SNRs, np.ndarray)
assert SNRs.shape == (num_units, )

amplitude_std = module.get_units_attribute_arr("amplitude_std", params['amplitude_std'])
assert isinstance(amplitude_std, np.ndarray)
assert amplitude_std.shape == (num_units, )
sd_ratio = module.get_units_attribute_arr("sd_ratio", params['sd_ratio'])
assert isinstance(sd_ratio, np.ndarray)
assert sd_ratio.shape == (num_units, )

with pytest.raises(ValueError):
module.get_units_attribute("test", {})
Expand Down
16 changes: 5 additions & 11 deletions tests/modules/test_remove_bad_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
"ms_after": 1.0,
"max_spikes_per_unit": 10
},
"filter": {
"band": [200, 5000]
},
"filter": [200, 5000],
"min": 20
},
"SNR": {
Expand All @@ -28,20 +26,16 @@
"ms_after": 1.0,
"max_spikes_per_unit": 10
},
"filter": {
"band": [300, 6000]
},
"filter": [300, 6000],
"min": 1.2
},
"amplitude_std": {
"sd_ratio": {
"wvf_extraction": {
"ms_before": 1.0,
"ms_after": 1.0,
},
"filter": {
"band": [200, 5000]
},
"max": 140
"filter": [200, 5000],
"max": 2.0
}
}

Expand Down

0 comments on commit 5c22d8d

Please sign in to comment.