Skip to content

Commit

Permalink
Merge pull request SpikeInterface#2263 from alejoe91/optimize_motion
Browse files Browse the repository at this point in the history
Fix memory leak in lsmr solver and optimize correct_motion
  • Loading branch information
samuelgarcia authored Dec 4, 2023
2 parents a39e329 + 8852629 commit bf0b055
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 32 deletions.
62 changes: 31 additions & 31 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
exclude_sweep_ms=0.1,
radius_um=50,
),
"select_kwargs": None,
"select_kwargs": dict(),
"localize_peaks_kwargs": dict(
method="monopolar_triangulation",
radius_um=75.0,
Expand Down Expand Up @@ -83,7 +83,7 @@
exclude_sweep_ms=0.1,
radius_um=50,
),
"select_kwargs": None,
"select_kwargs": dict(),
"localize_peaks_kwargs": dict(
method="center_of_mass",
radius_um=75.0,
Expand Down Expand Up @@ -111,7 +111,7 @@
exclude_sweep_ms=0.1,
radius_um=50,
),
"select_kwargs": None,
"select_kwargs": dict(),
"localize_peaks_kwargs": dict(
method="grid_convolution",
radius_um=40.0,
Expand Down Expand Up @@ -157,7 +157,7 @@ def correct_motion(
folder=None,
output_motion_info=False,
detect_kwargs={},
select_kwargs=None,
select_kwargs={},
localize_peaks_kwargs={},
estimate_motion_kwargs={},
interpolate_motion_kwargs={},
Expand Down Expand Up @@ -241,27 +241,42 @@ def correct_motion(
# get preset params and update if necessary
params = motion_options_preset[preset]
detect_kwargs = dict(params["detect_kwargs"], **detect_kwargs)
if params["select_kwargs"] is None:
select_kwargs = None
else:
select_kwargs = dict(params["select_kwargs"], **select_kwargs)
select_kwargs = dict(params["select_kwargs"], **select_kwargs)
localize_peaks_kwargs = dict(params["localize_peaks_kwargs"], **localize_peaks_kwargs)
estimate_motion_kwargs = dict(params["estimate_motion_kwargs"], **estimate_motion_kwargs)
interpolate_motion_kwargs = dict(params["interpolate_motion_kwargs"], **interpolate_motion_kwargs)
do_selection = len(select_kwargs) > 0

# params
parameters = dict(
detect_kwargs=detect_kwargs,
select_kwargs=select_kwargs,
localize_peaks_kwargs=localize_peaks_kwargs,
estimate_motion_kwargs=estimate_motion_kwargs,
interpolate_motion_kwargs=interpolate_motion_kwargs,
job_kwargs=job_kwargs,
sampling_frequency=recording.sampling_frequency,
)

if output_motion_info:
motion_info = {}
else:
motion_info = None

job_kwargs = fix_job_kwargs(job_kwargs)

noise_levels = get_noise_levels(recording, return_scaled=False)

if select_kwargs is None:
# maybe do this directly in the folder when not None
gather_mode = "memory"
if folder is not None:
folder = Path(folder)
folder.mkdir(exist_ok=True, parents=True)

(folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
if recording.check_serializability("json"):
recording.dump_to_json(folder / "recording.json")

if not do_selection:
# maybe do this directly in the folder when not None, but might be slow on external storage
gather_mode = "memory"
# node detect
method = detect_kwargs.pop("method", "locally_exclusive")
method_class = detect_peak_methods[method]
Expand All @@ -281,6 +296,7 @@ def correct_motion(
job_kwargs,
job_name="detect and localize",
gather_mode=gather_mode,
gather_kwargs=None,
squeeze_output=False,
folder=None,
names=None,
Expand All @@ -307,6 +323,9 @@ def correct_motion(
select_peaks=t2 - t1,
localize_peaks=t3 - t2,
)
if folder is not None:
np.save(folder / "peaks.npy", peaks)
np.save(folder / "peak_locations.npy", peak_locations)

t0 = time.perf_counter()
motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
Expand All @@ -318,29 +337,10 @@ def correct_motion(
)

if folder is not None:
folder = Path(folder)
folder.mkdir(exist_ok=True, parents=True)

# params and run times
parameters = dict(
detect_kwargs=detect_kwargs,
select_kwargs=select_kwargs,
localize_peaks_kwargs=localize_peaks_kwargs,
estimate_motion_kwargs=estimate_motion_kwargs,
interpolate_motion_kwargs=interpolate_motion_kwargs,
job_kwargs=job_kwargs,
sampling_frequency=recording.sampling_frequency,
)
(folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
(folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
if recording.check_serializability("json"):
recording.dump_to_json(folder / "recording.json")

np.save(folder / "peaks.npy", peaks)
np.save(folder / "peak_locations.npy", peak_locations)
np.save(folder / "temporal_bins.npy", temporal_bins)
np.save(folder / "motion.npy", motion)
np.save(folder / "peak_locations.npy", peak_locations)
if spatial_bins is not None:
np.save(folder / "spatial_bins.npy", spatial_bins)

Expand Down
6 changes: 5 additions & 1 deletion src/spikeinterface/sortingcomponents/motion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class DecentralizedRegistration:
pairwise_displacement_method: "conv" or "phase_cross_correlation"
How to estimate the displacement in the pairwise matrix.
max_displacement_um: float
Maximum possible discplacement in micrometers.
Maximum possible displacement in micrometers.
weight_scale: "linear" or "exp"
For parwaise displacement, how to to rescale the associated weight matrix.
error_sigma: float, default: 0.2
Expand Down Expand Up @@ -1039,6 +1039,7 @@ def jac(p):
displacement = p

elif convergence_method == "lsmr":
import gc
from scipy import sparse
from scipy.stats import zscore

Expand Down Expand Up @@ -1170,6 +1171,9 @@ def jac(p):

# warm start next iteration
p0 = displacement
# Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy)
# TODO: check if this gets fixed in scipy
gc.collect()

displacement = displacement.reshape(B, T).T
else:
Expand Down

0 comments on commit bf0b055

Please sign in to comment.