Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 15, 2023
1 parent acd87c5 commit d692439
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
21 changes: 16 additions & 5 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,17 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
return channel_locations


def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0,
minimum_distance=20., max_iteration=100, distance_strict=False, seed=None):
def generate_unit_locations(
num_units,
channel_locations,
margin_um=20.0,
minimum_z=5.0,
maximum_z=40.0,
minimum_distance=20.0,
max_iteration=100,
distance_strict=False,
seed=None,
):
rng = np.random.default_rng(seed=seed)
units_locations = np.zeros((num_units, 3), dtype="float32")

Expand All @@ -1364,7 +1373,7 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu
else:
# random only bad ones in the previous set
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]

units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)
Expand All @@ -1374,8 +1383,10 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu

if not solution_found:
if distance_strict:
raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance")
raise ValueError(
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance"
)
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")

Expand Down
21 changes: 13 additions & 8 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,24 +292,29 @@ def test_generate_single_fake_waveform():


def test_generate_unit_locations():

seed = 0

probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)
channel_locations = probe.contact_positions

num_units = 100
minimum_distance = 20.
unit_locations = generate_unit_locations(num_units, channel_locations,
margin_um=20.0, minimum_z=5.0, maximum_z=40.0,
minimum_distance=minimum_distance, max_iteration=500,
distance_strict=False, seed=seed)
minimum_distance = 20.0
unit_locations = generate_unit_locations(
num_units,
channel_locations,
margin_um=20.0,
minimum_z=5.0,
maximum_z=40.0,
minimum_distance=minimum_distance,
max_iteration=500,
distance_strict=False,
seed=seed,
)
distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)
dist_flat = np.triu(distances, k=1).flatten()
dist_flat = dist_flat[dist_flat>0]
dist_flat = dist_flat[dist_flat > 0]
assert np.all(dist_flat > minimum_distance)


# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.hist(dist_flat, bins = np.arange(0, 400, 10))
Expand Down

0 comments on commit d692439

Please sign in to comment.