From d1d65f6ca6338ac2dd8d6f9c99ee657f0db76d21 Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 11:58:23 +0100
Subject: [PATCH 1/6] estimate_sparsity arg ordering

---
 src/spikeinterface/core/sortinganalyzer.py | 2 +-
 src/spikeinterface/core/sparsity.py | 6 +++---
 src/spikeinterface/core/tests/test_sparsity.py | 4 ++--
 .../postprocessing/tests/common_extension_tests.py | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 53e060262b..62b7f9e7c0 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -127,7 +127,7 @@ def create_sorting_analyzer(
             recording.channel_ids, sparsity.channel_ids
         ), "create_sorting_analyzer(): if external sparsity is given channel_ids must correspond"
     elif sparse:
-        sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
+        sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
     else:
         sparsity = None

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index cefd7bd950..1cd7822f99 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -539,8 +539,8 @@ def compute_sparsity(


 def estimate_sparsity(
-    recording: BaseRecording,
     sorting: BaseSorting,
+    recording: BaseRecording,
     num_spikes_for_sparsity: int = 100,
     ms_before: float = 1.0,
     ms_after: float = 2.5,
@@ -563,10 +563,10 @@ def estimate_sparsity(

     Parameters
     ----------
-    recording: BaseRecording
-        The recording
     sorting: BaseSorting
         The sorting
+    recording: BaseRecording
+        The recording
     num_spikes_for_sparsity: int, default: 100
         How many spikes per unit to compute the sparsity
     ms_before: float, default: 1.0
diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py
index 98d033d8ea..a192d90502 100644
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -166,8 +166,8 @@ def test_estimate_sparsity():

     # a small radius should give a very sparse result: one channel per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
@@ -182,8 +182,8 @@ def test_estimate_sparsity():

     # best_channel: the mask should have exactly 3 channels per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index bf462a9466..8c46fa5e24 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -79,7 +79,7 @@ class AnalyzerExtensionCommonTestSuite:
     def setUpClass(cls):
         cls.recording, cls.sorting = get_dataset()
         # sparsity is computed once for all cases to save processing time and force a small radius
-        cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20)
+        cls.sparsity = estimate_sparsity(cls.sorting, cls.recording, method="radius", radius_um=20)

     @property
     def extension_name(self):

From 61060781eef87597461241aec077aac27baff69b Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 15:15:14 +0100
Subject: [PATCH 2/6] SpikeRetriever arg switch

---
 src/spikeinterface/core/node_pipeline.py | 16 +--
 .../core/tests/test_node_pipeline.py | 4 +-
 .../tests/test_train_manual_curation.py | 120 ++++++++++++++++++
 .../postprocessing/amplitude_scalings.py | 2 +-
 .../postprocessing/spike_amplitudes.py | 2 +-
 .../postprocessing/spike_locations.py | 2 +-
 6 files changed, 133 insertions(+), 13 deletions(-)
 create mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 1c0107d235..0722ede23f 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
         * compute_spike_amplitudes()
         * compute_principal_components()

+    sorting : BaseSorting
+        The sorting object.
     recording : BaseRecording
         The recording object.
-    sorting: BaseSorting
-        The sorting object.
-    channel_from_template: bool, default: True
+    channel_from_template : bool, default: True
         If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
         If False, the max channel is computed for each spike given a radius around the template max channel.
-    extremum_channel_inds: dict of int | None, default: None
+    extremum_channel_inds : dict of int | None, default: None
         The extremum channel index dict given from the template.
-    radius_um: float, default: 50
+    radius_um : float, default: 50
         The radius to find the real max channel.
         Used only when channel_from_template=False
-    peak_sign: "neg" | "pos", default: "neg"
+    peak_sign : "neg" | "pos", default: "neg"
         Peak sign to find the max channel.
         Used only when channel_from_template=False
-    include_spikes_in_margin: bool, default False
+    include_spikes_in_margin : bool, default: False
         If True, then spikes in the margin are added and an extra field is added to the dtype
     """

     def __init__(
         self,
-        recording,
         sorting,
+        recording,
         channel_from_template=True,
         extremum_channel_inds=None,
         radius_um=50,
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index 03acc9fed1..8d788acbad 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
     spike_retriever_T = SpikeRetriever(
-        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
+        sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
     )
     # channel index is per spike
     spike_retriever_S = SpikeRetriever(
-        recording,
         sorting,
+        recording,
         channel_from_template=False,
         extremum_channel_inds=extremum_channel_inds,
         radius_um=50,
diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py
new file mode 100644
index 0000000000..f0f9ff4d75
--- /dev/null
+++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -0,0 +1,120 @@
+import pytest
+import pandas as pd
+import os
+import shutil
+
+from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model
+
+# Sample data for testing
+data = {
+    'num_spikes': [1, 2, 3, 4, 5, 6],
+    'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4],
+    'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06],
+    'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'amplitude_median': [0.2, 0.3, 0.4, 0.5,
0.6, 0.7], + 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'firing_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'num_positive_peaks': [1, 2, 3, 4, 5, 6], + 'num_negative_peaks': [1, 2, 3, 4, 5, 6], + 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'is_noise': [0, 1, 0, 1, 0, 1], + 'is_sua': [1, 0, 1, 0, 1, 0], + 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] +} + +df = pd.DataFrame(data) + +# Test initialization +def test_initialization(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + assert trainer.output_folder == '/tmp' + assert trainer.curator_column == 'num_spikes' + assert trainer.imputation_strategies is not None + assert trainer.scaling_techniques is not None + +# Test load_data_file +def test_load_data_file(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + df.to_csv('/tmp/test.csv', index=False) + trainer.load_data_file('/tmp/test.csv') + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Test process_test_data_for_classification +def test_process_test_data_for_classification(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + trainer.testing_metrics = {0: df} + trainer.process_test_data_for_classification() + assert trainer.noise_test is not None + assert trainer.sua_mua_test is not None + +# Test apply_scaling_imputation +def test_apply_scaling_imputation(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + y_train = df['is_noise'] + y_val = df['is_noise'] + result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) + assert result is not None + +# Test get_classifier_search_space +def test_get_classifier_search_space(): + from sklearn.linear_model import LogisticRegression + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + model, param_space = trainer.get_classifier_search_space(LogisticRegression) + assert model is not None + assert param_space is not None + +# Test Objective Enum +def test_objective_enum(): + assert Objective.Noise == Objective(1) + assert Objective.SUA == Objective(2) + assert str(Objective.Noise) == "Objective.Noise" + assert str(Objective.SUA) == "Objective.SUA" + +# Test train_model function +def 
test_train_model(monkeypatch): + output_folder = '/tmp/output' + os.makedirs(output_folder, exist_ok=True) + df.to_csv('/tmp/metrics.csv', index=False) + + def mock_load_and_preprocess_full(self, path): + self.testing_metrics = {0: df} + self.process_test_data_for_classification() + + monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) + + trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') + assert trainer is not None + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Clean up temporary files +@pytest.fixture(scope="module", autouse=True) +def cleanup(request): + def remove_tmp(): + shutil.rmtree('/tmp', ignore_errors=True) + request.addfinalizer(remove_tmp) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 2e544d086b..8ff9cc5666 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -170,8 +170,8 @@ def _get_pipeline_nodes(self): sparsity_mask = sparsity.mask spike_retriever_node = SpikeRetriever( - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices, include_spikes_in_margin=True, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index aebfd1fd78..72cbcb651f 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -95,7 +95,7 @@ def _get_pipeline_nodes(self): peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) spike_retriever_node = SpikeRetriever( - recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices + sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices ) spike_amplitudes_node = SpikeAmplitudeNode( recording, diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 52a91342b6..23301292e5 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -103,8 +103,8 @@ def _get_pipeline_nodes(self): ) retriever = SpikeRetriever( - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices, ) From 722c313382b6ac225a2c9119c676bc1bcab6e480 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:17:43 +0100 Subject: [PATCH 3/6] has_exceeding_spikes arg switch --- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/frameslicesorting.py | 2 +- src/spikeinterface/core/waveform_tools.py | 2 +- src/spikeinterface/curation/remove_excess_spikes.py | 2 +- .../curation/tests/test_remove_excess_spikes.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fd68df9dda..d9a567dedf 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True): self.get_num_segments() == recording.get_num_segments() ), "The recording has a different number of segments than the sorting!" 
         if check_spike_frames:
-            if has_exceeding_spikes(recording, self):
+            if has_exceeding_spikes(self, recording):
                 warnings.warn(
                     "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py
index ffd8af5fd8..f3ec449ab0 100644
--- a/src/spikeinterface/core/frameslicesorting.py
+++ b/src/spikeinterface/core/frameslicesorting.py
@@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
         assert (
             start_frame <= parent_n_samples
         ), "`start_frame` should be smaller than the sorting's total number of samples."
-        if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
+        if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
             raise ValueError(
                 "The sorting object has spikes whose times go beyond the recording duration. "
                 "This could indicate a bug in the sorter. "
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index befc49d034..4543074872 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
     return waveforms_by_units


-def has_exceeding_spikes(recording, sorting) -> bool:
+def has_exceeding_spikes(sorting, recording) -> bool:
     """
     Check if the sorting object has spikes exceeding the recording's number of samples, for all segments
diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py
index 0ae7a59fc6..d1d6b7f3cb 100644
--- a/src/spikeinterface/curation/remove_excess_spikes.py
+++ b/src/spikeinterface/curation/remove_excess_spikes.py
@@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording):
     sorting_without_excess_spikes : Sorting
         The sorting without any excess spikes.
""" - if has_exceeding_spikes(recording=recording, sorting=sorting): + if has_exceeding_spikes(sorting=sorting, recording=recording): return RemoveExcessSpikesSorting(sorting=sorting, recording=recording) else: return sorting diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index 69edbaba4c..141cc4c34e 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -39,10 +39,10 @@ def test_remove_excess_spikes(): labels.append(labels_segment) sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency) - assert has_exceeding_spikes(recording, sorting) + assert has_exceeding_spikes(sorting, recording) sorting_corrected = remove_excess_spikes(sorting, recording) - assert not has_exceeding_spikes(recording, sorting_corrected) + assert not has_exceeding_spikes(sorting_corrected, recording) for u in sorting.unit_ids: for segment_index in range(sorting.get_num_segments()): From d0968c4c941e290488848d14c6881c7a2cdf9c8c Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:19:24 +0100 Subject: [PATCH 4/6] removed accidental commit --- .../tests/test_train_manual_curation.py | 120 ------------------ 1 file changed, 120 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py deleted file mode 100644 index f0f9ff4d75..0000000000 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import pandas as pd -import os -import shutil - -from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model - -# Sample data for testing -data = { - 'num_spikes': [1, 2, 3, 4, 5, 6], - 'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], - 'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06], - 'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], - 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'firing_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'num_positive_peaks': [1, 2, 3, 4, 5, 6], - 'num_negative_peaks': [1, 2, 3, 4, 5, 6], - 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 
0.6], - 'is_noise': [0, 1, 0, 1, 0, 1], - 'is_sua': [1, 0, 1, 0, 1, 0], - 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] -} - -df = pd.DataFrame(data) - -# Test initialization -def test_initialization(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - assert trainer.output_folder == '/tmp' - assert trainer.curator_column == 'num_spikes' - assert trainer.imputation_strategies is not None - assert trainer.scaling_techniques is not None - -# Test load_data_file -def test_load_data_file(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - df.to_csv('/tmp/test.csv', index=False) - trainer.load_data_file('/tmp/test.csv') - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Test process_test_data_for_classification -def test_process_test_data_for_classification(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - trainer.testing_metrics = {0: df} - trainer.process_test_data_for_classification() - assert trainer.noise_test is not None - assert trainer.sua_mua_test is not None - -# Test apply_scaling_imputation -def test_apply_scaling_imputation(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - y_train = df['is_noise'] - y_val = df['is_noise'] - result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) - assert result is not None - -# Test get_classifier_search_space -def test_get_classifier_search_space(): - from sklearn.linear_model import LogisticRegression - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - model, param_space = trainer.get_classifier_search_space(LogisticRegression) - assert model is not None - assert param_space is not None - -# Test Objective Enum -def test_objective_enum(): - assert Objective.Noise == Objective(1) - assert Objective.SUA == Objective(2) - assert str(Objective.Noise) == "Objective.Noise" - assert str(Objective.SUA) == "Objective.SUA" - -# Test train_model function -def test_train_model(monkeypatch): - output_folder = '/tmp/output' - os.makedirs(output_folder, exist_ok=True) - df.to_csv('/tmp/metrics.csv', index=False) - - def mock_load_and_preprocess_full(self, path): - self.testing_metrics = {0: df} - self.process_test_data_for_classification() - - monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) - - trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') - assert trainer is not None - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Clean up temporary files -@pytest.fixture(scope="module", autouse=True) -def cleanup(request): - def remove_tmp(): - shutil.rmtree('/tmp', ignore_errors=True) - request.addfinalizer(remove_tmp) From f687c2c2fe9b70a970cfd39d6dd7b134c15e065f Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:20:32 +0100 Subject: [PATCH 5/6] docs --- src/spikeinterface/core/waveform_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 4543074872..98380e955f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -685,10 +685,10 @@ def 
has_exceeding_spikes(sorting, recording) -> bool: Parameters ---------- - recording : BaseRecording - The recording object sorting : BaseSorting The sorting object + recording : BaseRecording + The recording object Returns ------- From b8c8fa83ba8695545b420d135c92f5167d7d2de1 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:54:59 +0100 Subject: [PATCH 6/6] Missed one --- .../postprocessing/tests/common_extension_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bb2f5aaafd..52dbaf23d4 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -73,7 +73,7 @@ class instance is used for each. In this case, we have to set self.__class__.recording, self.__class__.sorting = get_dataset() self.__class__.sparsity = estimate_sparsity( - self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20 ) self.__class__.cache_folder = create_cache_folder
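
Usage note: after this series, every public helper touched here takes `sorting` first and `recording` second. Below is a minimal sketch of the updated call order; it is not part of the commits above, it assumes a spikeinterface build that already includes these changes, and it uses `generate_ground_truth_recording` purely as a convenient toy dataset (any recording/sorting pair works the same way).

    from spikeinterface.core import estimate_sparsity, generate_ground_truth_recording
    from spikeinterface.core.waveform_tools import has_exceeding_spikes

    # Toy recording/sorting pair for illustration only.
    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)

    # Sorting-first argument order introduced by PATCH 1/6 and 6/6;
    # method and radius_um mirror the values used in the tests above.
    sparsity = estimate_sparsity(sorting, recording, method="radius", radius_um=20.0)
    print(sparsity)

    # Same sorting-first order for the check switched in PATCH 3/6.
    print(has_exceeding_spikes(sorting, recording))

The same order applies to the `SpikeRetriever(sorting, recording, ...)` node changed in PATCH 2/6; callers still using the old `(recording, sorting)` order need updating when picking up these commits.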