From 257950d5859521730ed1da746b6fd32b7b6335bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 06:36:40 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 4 ++- .../benchmark/benchmark_motion_estimation.py | 20 ++++++------- .../benchmark_motion_interpolation.py | 4 +-- .../tests/test_benchmark_motion_estimation.py | 1 + .../test_benchmark_motion_interpolation.py | 1 + .../sortingcomponents/motion_utils.py | 4 +-- .../tests/test_motion_utils.py | 11 ++++--- src/spikeinterface/widgets/motion.py | 30 ++++++++----------- .../widgets/tests/test_widgets.py | 21 +++++-------- 9 files changed, 41 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 1a064dcb31..b5df0f1059 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -314,7 +314,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): motion_info = load_motion_info(motion_folder) motion = motion_info["motion"] - max_motion = max(np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))) + max_motion = max( + np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) + ) merging_params["maximum_distance_um"] = max(50, 2 * max_motion) # peak_sign = params['detection'].get('peak_sign', 'neg') diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 1408a5cb32..55ef21de9d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -55,13 +55,9 @@ def get_gt_motion_from_unit_displacement( gt_displacement[t, :] = f(spatial_bins_um) gt_motion = Motion( - gt_displacement, - temporal_bins_s, - spatial_bins_um, - direction="xyz"[direction_dim], - interpolation_method="linear" + gt_displacement, temporal_bins_s, spatial_bins_um, direction="xyz"[direction_dim], interpolation_method="linear" ) - + return gt_motion @@ -102,9 +98,7 @@ def run(self, **job_kwargs): t2 = time.perf_counter() peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() - motion = estimate_motion( - self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] - ) + motion = estimate_motion(self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"]) t4 = time.perf_counter() step_run_times = dict( @@ -263,8 +257,12 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(motion.temporal_bins_s[0][0], motion.temporal_bins_s[0][-1], - motion.spatial_bins_um[0], motion.spatial_bins_um[-1]), + extent=( + motion.temporal_bins_s[0][0], + motion.temporal_bins_s[0][-1], + motion.spatial_bins_um[0], + motion.spatial_bins_um[-1], + ), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 5688d2eaf3..a6ff05fc55 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -44,9 +44,7 @@ def run(self, **job_kwargs): recording = self.drifting_recording elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] - recording = InterpolateMotionRecording( - self.drifting_recording, self.motion, **correct_motion_kwargs - ) + recording = InterpolateMotionRecording(self.drifting_recording, self.motion, **correct_motion_kwargs) else: raise ValueError("recording_source") diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 14a5fe9138..526cc2e92f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -70,6 +70,7 @@ def test_benchmark_motion_estimaton(create_cache_folder): study.plot_summary_errors() import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 07eb35b693..6d80d027f2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -134,6 +134,7 @@ def test_benchmark_motion_interpolation(create_cache_folder): study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index d4f0bb93b5..26d4b35b1a 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -220,11 +220,11 @@ def __eq__(self, other): return False return True - + def copy(self): return Motion( self.displacement.copy(), self.temporal_bins_s.copy(), self.spatial_bins_um.copy(), - interpolation_method=self.interpolation_method + interpolation_method=self.interpolation_method, ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 1542c8531a..0b67be39c0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -14,29 +14,28 @@ def make_fake_motion(): - displacement_sampling_frequency = 5. - spatial_bins_um = np.array([100.0, 200.0, 300., 400.]) + displacement_sampling_frequency = 5.0 + spatial_bins_um = np.array([100.0, 200.0, 300.0, 400.0]) displacement_vector = make_one_displacement_vector( drift_mode="zigzag", duration=50.0, amplitude_factor=1.0, displacement_sampling_frequency=displacement_sampling_frequency, - period_s=25., + period_s=25.0, ) temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) - + n = spatial_bins_um.size for i in range(n): - displacement[:, i] = displacement_vector * ((i +1 ) / n) + displacement[:, i] = displacement_vector * ((i + 1) / n) motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") return motion - def test_Motion(): temporal_bins_s = np.arange(0.0, 10.0, 1.0) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index b98e619bf7..dcb7b26f7e 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -4,6 +4,7 @@ from .base import BaseWidget, to_attr + class MotionWidget(BaseWidget): """ Plot the Motion object @@ -18,6 +19,7 @@ class MotionWidget(BaseWidget): How to plot map or lines. "auto" make it automatic if the number of depth is too high. """ + def __init__( self, motion, @@ -26,10 +28,11 @@ def __init__( motion_lim=None, backend=None, **backend_kwargs, - - ): + ): if isinstance(motion, dict): - raise ValueError("The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)") + raise ValueError( + "The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)" + ) if segment_index is None: if len(motion.displacement) == 1: @@ -43,7 +46,7 @@ def __init__( mode=mode, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt @@ -59,7 +62,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - displacement = motion.displacement[dp.segment_index] temporal_bins_s = motion.temporal_bins_s[dp.segment_index] depth = motion.spatial_bins_um @@ -69,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim - ax = self.ax fig = self.figure if dp.mode == "line": @@ -84,10 +85,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): aspect="auto", origin="lower", extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]), - cmap="PiYG" + cmap="PiYG", ) im.set_clim(-motion_lim, motion_lim) - + cbar = fig.colorbar(im) cbar.ax.set_ylabel("motion [um]") ax.set_xlabel("Times [s]") @@ -106,7 +107,7 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() - segment_index: + segment_index: recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) @@ -145,7 +146,7 @@ def __init__( backend=None, **backend_kwargs, ): - + motion = motion_info["motion"] if segment_index is None: if len(motion.displacement) == 1: @@ -193,7 +194,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): motion = dp.motion - displacement = motion.displacement[dp.segment_index] temporal_bins_s = motion.temporal_bins_s[dp.segment_index] spatial_bins_um = motion.spatial_bins_um @@ -203,7 +203,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim - is_rigid = displacement.shape[1] == 1 gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5) @@ -223,12 +222,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # temporal_bins_plot = dp.temporal_bins + dp.times[0] x = dp.times[dp.peaks["sample_index"]] - corrected_location = correct_motion_on_peaks( - dp.peaks, - dp.peak_locations, - dp.recording, - dp.motion - ) + corrected_location = correct_motion_on_peaks(dp.peaks, dp.peak_locations, dp.recording, dp.motion) dim = ["x", "y", "z"][dp.motion.dim] y = dp.peak_locations[motion.direction] diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 3e3e432817..0198e24626 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -579,47 +579,40 @@ def test_plot_multicomparison(self): _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) - + def test_plot_motion(self): from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + motion = make_fake_motion() possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_motion(motion, backend=backend, mode='line') - sw.plot_motion(motion, backend=backend, mode='map') + sw.plot_motion(motion, backend=backend, mode="line") + sw.plot_motion(motion, backend=backend, mode="map") def test_plot_motion_info(self): from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - motion = make_fake_motion() rng = np.random.default_rng(seed=2205) peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) - peak_locations['y'] = rng.uniform(motion.spatial_bins_um[0], - motion.spatial_bins_um[-1], - size=self.peaks.size) - + peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) + motion_info = dict( motion=motion, - parameters=dict(sampling_frequency=30000.), + parameters=dict(sampling_frequency=30000.0), run_times=dict(), peaks=self.peaks, peak_locations=peak_locations, ) - possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) - - - - if __name__ == "__main__": # unittest.main() import matplotlib.pyplot as plt