Skip to content

Commit

Permalink
conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jun 19, 2024
2 parents 225269d + 257950d commit ee6af73
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 48 deletions.
4 changes: 3 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

motion_info = load_motion_info(motion_folder)
motion = motion_info["motion"]
max_motion = max(np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)))
max_motion = max(
np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))
)
merging_params["maximum_distance_um"] = max(50, 2 * max_motion)

# peak_sign = params['detection'].get('peak_sign', 'neg')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ def get_gt_motion_from_unit_displacement(
gt_displacement[t, :] = f(spatial_bins_um)

gt_motion = Motion(
gt_displacement,
temporal_bins_s,
spatial_bins_um,
direction="xyz"[direction_dim],
interpolation_method="linear"
gt_displacement, temporal_bins_s, spatial_bins_um, direction="xyz"[direction_dim], interpolation_method="linear"
)

return gt_motion


Expand Down Expand Up @@ -102,9 +98,7 @@ def run(self, **job_kwargs):
t2 = time.perf_counter()
peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs)
t3 = time.perf_counter()
motion = estimate_motion(
self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"]
)
motion = estimate_motion(self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"])
t4 = time.perf_counter()

step_run_times = dict(
Expand Down Expand Up @@ -263,8 +257,12 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None):
aspect="auto",
interpolation="nearest",
origin="lower",
extent=(motion.temporal_bins_s[0][0], motion.temporal_bins_s[0][-1],
motion.spatial_bins_um[0], motion.spatial_bins_um[-1]),
extent=(
motion.temporal_bins_s[0][0],
motion.temporal_bins_s[0][-1],
motion.spatial_bins_um[0],
motion.spatial_bins_um[-1],
),
)
plt.colorbar(im, ax=ax, label="error")
ax.set_ylabel("depth (um)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def run(self, **job_kwargs):
recording = self.drifting_recording
elif self.params["recording_source"] == "corrected":
correct_motion_kwargs = self.params["correct_motion_kwargs"]
recording = InterpolateMotionRecording(
self.drifting_recording, self.motion, **correct_motion_kwargs
)
recording = InterpolateMotionRecording(self.drifting_recording, self.motion, **correct_motion_kwargs)
else:
raise ValueError("recording_source")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_benchmark_motion_estimaton(create_cache_folder):
study.plot_summary_errors()

import matplotlib.pyplot as plt

plt.show()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_benchmark_motion_interpolation(create_cache_folder):
study.plot_sorting_accuracy(mode="depth", mode_best_merge=True)

import matplotlib.pyplot as plt

plt.show()


Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def __eq__(self, other):
return False

return True

def copy(self):
return Motion(
self.displacement.copy(),
self.temporal_bins_s.copy(),
self.spatial_bins_um.copy(),
interpolation_method=self.interpolation_method
interpolation_method=self.interpolation_method,
)
11 changes: 5 additions & 6 deletions src/spikeinterface/sortingcomponents/tests/test_motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,28 @@


def make_fake_motion():
displacement_sampling_frequency = 5.
spatial_bins_um = np.array([100.0, 200.0, 300., 400.])
displacement_sampling_frequency = 5.0
spatial_bins_um = np.array([100.0, 200.0, 300.0, 400.0])

displacement_vector = make_one_displacement_vector(
drift_mode="zigzag",
duration=50.0,
amplitude_factor=1.0,
displacement_sampling_frequency=displacement_sampling_frequency,
period_s=25.,
period_s=25.0,
)
temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency
displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size))

n = spatial_bins_um.size
for i in range(n):
displacement[:, i] = displacement_vector * ((i +1 ) / n)
displacement[:, i] = displacement_vector * ((i + 1) / n)

motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")

return motion



def test_Motion():

temporal_bins_s = np.arange(0.0, 10.0, 1.0)
Expand Down
26 changes: 15 additions & 11 deletions src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .base import BaseWidget, to_attr


class MotionWidget(BaseWidget):
"""
Plot the Motion object
Expand All @@ -18,6 +19,7 @@ class MotionWidget(BaseWidget):
How to plot map or lines.
"auto" make it automatic if the number of depth is too high.
"""

def __init__(
self,
motion,
Expand All @@ -26,10 +28,11 @@ def __init__(
motion_lim=None,
backend=None,
**backend_kwargs,

):
):
if isinstance(motion, dict):
raise ValueError("The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)")
raise ValueError(
"The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)"
)

if segment_index is None:
if len(motion.displacement) == 1:
Expand All @@ -43,7 +46,7 @@ def __init__(
mode=mode,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
Expand All @@ -59,7 +62,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)


displacement = motion.displacement[dp.segment_index]
temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
depth = motion.spatial_bins_um
Expand All @@ -69,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
else:
motion_lim = dp.motion_lim


ax = self.ax
fig = self.figure
if dp.mode == "line":
Expand All @@ -84,10 +85,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
aspect="auto",
origin="lower",
extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]),
cmap="PiYG"
cmap="PiYG",
)
im.set_clim(-motion_lim, motion_lim)

cbar = fig.colorbar(im)
cbar.ax.set_ylabel("motion [um]")
ax.set_xlabel("Times [s]")
Expand All @@ -106,8 +107,13 @@ class MotionInfoWidget(BaseWidget):
----------
motion_info : dict
The motion info return by correct_motion() or load back with load_motion_info()
<<<<<<< HEAD
segment_index : int, default: None
The segment index to display.
=======
segment_index:
>>>>>>> 257950d5859521730ed1da746b6fd32b7b6335bb
recording : RecordingExtractor, default: None
The recording extractor object (only used to get "real" times)
segment_index : int, default: 0
Expand Down Expand Up @@ -145,7 +151,7 @@ def __init__(
backend=None,
**backend_kwargs,
):

motion = motion_info["motion"]
if segment_index is None:
if len(motion.displacement) == 1:
Expand Down Expand Up @@ -192,7 +198,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

motion = dp.motion


displacement = motion.displacement[dp.segment_index]
temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
spatial_bins_um = motion.spatial_bins_um
Expand All @@ -202,7 +207,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
else:
motion_lim = dp.motion_lim


is_rigid = displacement.shape[1] == 1

gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5)
Expand Down
21 changes: 7 additions & 14 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,47 +579,40 @@ def test_plot_multicomparison(self):

_, axes = plt.subplots(len(mcmp.object_list), 1)
sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes)

def test_plot_motion(self):
from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion

motion = make_fake_motion()

possible_backends = list(sw.MotionWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_motion(motion, backend=backend, mode='line')
sw.plot_motion(motion, backend=backend, mode='map')
sw.plot_motion(motion, backend=backend, mode="line")
sw.plot_motion(motion, backend=backend, mode="map")

def test_plot_motion_info(self):
from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion


motion = make_fake_motion()
rng = np.random.default_rng(seed=2205)
peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")])
peak_locations['y'] = rng.uniform(motion.spatial_bins_um[0],
motion.spatial_bins_um[-1],
size=self.peaks.size)

peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size)

motion_info = dict(
motion=motion,
parameters=dict(sampling_frequency=30000.),
parameters=dict(sampling_frequency=30000.0),
run_times=dict(),
peaks=self.peaks,
peak_locations=peak_locations,
)


possible_backends = list(sw.MotionWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_motion_info(motion_info, recording=self.recording, backend=backend)






if __name__ == "__main__":
# unittest.main()
import matplotlib.pyplot as plt
Expand Down

0 comments on commit ee6af73

Please sign in to comment.