Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a minimum distance in generate_unit_locations. #2147

Merged
merged 3 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,15 +1336,49 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
return channel_locations


def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None):
def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0,
minimum_distance=20., max_iteration=100, distance_strict=False, seed=None):
rng = np.random.default_rng(seed=seed)
units_locations = np.zeros((num_units, 3), dtype="float32")
for dim in (0, 1):
lim0 = np.min(channel_locations[:, dim]) - margin_um
lim1 = np.max(channel_locations[:, dim]) + margin_um
units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units)

minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um
minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um

units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)

if minimum_distance is not None:
solution_found = False
renew_inds = None
for i in range(max_iteration):
distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)
inds0, inds1 = np.nonzero(distances < minimum_distance)
mask = inds0 != inds1
inds0 = inds0[mask]
inds1 = inds1[mask]

if inds0.size > 0:
if renew_inds is None:
renew_inds = np.unique(inds0)
else:
# random only bad ones in the previous set
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]

units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)
else:
solution_found = True
break

if not solution_found:
if distance_strict:
raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance")
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")

return units_locations


Expand All @@ -1369,7 +1403,7 @@ def generate_ground_truth_recording(
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=dict(),
dtype="float32",
seed=None,
Expand Down
36 changes: 34 additions & 2 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np

from spikeinterface.core import load_extractor, extract_waveforms

from probeinterface import generate_multi_columns_probe
from spikeinterface.core.generate import (
generate_recording,
generate_sorting,
Expand Down Expand Up @@ -289,6 +291,35 @@ def test_generate_single_fake_waveform():
# plt.show()


def test_generate_unit_locations():

seed = 0

probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)
channel_locations = probe.contact_positions

num_units = 100
minimum_distance = 20.
unit_locations = generate_unit_locations(num_units, channel_locations,
margin_um=20.0, minimum_z=5.0, maximum_z=40.0,
minimum_distance=minimum_distance, max_iteration=500,
distance_strict=False, seed=seed)
distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)
dist_flat = np.triu(distances, k=1).flatten()
dist_flat = dist_flat[dist_flat>0]
assert np.all(dist_flat > minimum_distance)


# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.hist(dist_flat, bins = np.arange(0, 400, 10))
# fig, ax = plt.subplots()
# from probeinterface.plotting import plot_probe
# plot_probe(probe, ax=ax)
# ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)
# plt.show()


def test_generate_templates():
seed = 0

Expand All @@ -297,7 +328,7 @@ def test_generate_templates():
num_units = 10
margin_um = 15.0
channel_locations = generate_channel_locations(num_chans, num_columns, 20.0)
unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed)
unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed)

sampling_frequency = 30000.0
ms_before = 1.0
Expand Down Expand Up @@ -436,7 +467,8 @@ def test_generate_ground_truth_recording():
# test_noise_generator_consistency_after_dump(strategy, None)
# test_generate_recording()
# test_generate_single_fake_waveform()
test_generate_unit_locations()
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
test_generate_sorting_with_spikes_on_borders()
# test_generate_sorting_with_spikes_on_borders()
Loading