Skip to content

Commit

Permalink
Merge pull request #2147 from samuelgarcia/improve_generate
Browse files Browse the repository at this point in the history
Add a minimum distance in generate_unit_locations.
  • Loading branch information
alejoe91 authored Nov 16, 2023
2 parents 6b170c5 + d692439 commit e7be6b6
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 8 deletions.
57 changes: 51 additions & 6 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,15 +1336,60 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
return channel_locations


def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None):
def generate_unit_locations(
num_units,
channel_locations,
margin_um=20.0,
minimum_z=5.0,
maximum_z=40.0,
minimum_distance=20.0,
max_iteration=100,
distance_strict=False,
seed=None,
):
rng = np.random.default_rng(seed=seed)
units_locations = np.zeros((num_units, 3), dtype="float32")
for dim in (0, 1):
lim0 = np.min(channel_locations[:, dim]) - margin_um
lim1 = np.max(channel_locations[:, dim]) + margin_um
units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units)

minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um
minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um

units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)

if minimum_distance is not None:
solution_found = False
renew_inds = None
for i in range(max_iteration):
distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)
inds0, inds1 = np.nonzero(distances < minimum_distance)
mask = inds0 != inds1
inds0 = inds0[mask]
inds1 = inds1[mask]

if inds0.size > 0:
if renew_inds is None:
renew_inds = np.unique(inds0)
else:
# random only bad ones in the previous set
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]

units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)
else:
solution_found = True
break

if not solution_found:
if distance_strict:
raise ValueError(
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance"
)
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")

return units_locations


Expand All @@ -1369,7 +1414,7 @@ def generate_ground_truth_recording(
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=dict(),
dtype="float32",
seed=None,
Expand Down
41 changes: 39 additions & 2 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np

from spikeinterface.core import load_extractor, extract_waveforms

from probeinterface import generate_multi_columns_probe
from spikeinterface.core.generate import (
generate_recording,
generate_sorting,
Expand Down Expand Up @@ -289,6 +291,40 @@ def test_generate_single_fake_waveform():
# plt.show()


def test_generate_unit_locations():
seed = 0

probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)
channel_locations = probe.contact_positions

num_units = 100
minimum_distance = 20.0
unit_locations = generate_unit_locations(
num_units,
channel_locations,
margin_um=20.0,
minimum_z=5.0,
maximum_z=40.0,
minimum_distance=minimum_distance,
max_iteration=500,
distance_strict=False,
seed=seed,
)
distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)
dist_flat = np.triu(distances, k=1).flatten()
dist_flat = dist_flat[dist_flat > 0]
assert np.all(dist_flat > minimum_distance)

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.hist(dist_flat, bins = np.arange(0, 400, 10))
# fig, ax = plt.subplots()
# from probeinterface.plotting import plot_probe
# plot_probe(probe, ax=ax)
# ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)
# plt.show()


def test_generate_templates():
seed = 0

Expand All @@ -297,7 +333,7 @@ def test_generate_templates():
num_units = 10
margin_um = 15.0
channel_locations = generate_channel_locations(num_chans, num_columns, 20.0)
unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed)
unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed)

sampling_frequency = 30000.0
ms_before = 1.0
Expand Down Expand Up @@ -436,7 +472,8 @@ def test_generate_ground_truth_recording():
# test_noise_generator_consistency_after_dump(strategy, None)
# test_generate_recording()
# test_generate_single_fake_waveform()
test_generate_unit_locations()
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
test_generate_sorting_with_spikes_on_borders()
# test_generate_sorting_with_spikes_on_borders()

0 comments on commit e7be6b6

Please sign in to comment.