Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add session displacement generation #3231

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
33 changes: 25 additions & 8 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def generate_sorting(
add_spikes_on_borders=False,
num_spikes_per_border=3,
border_size_samples=20,
extra_outputs=False,
seed=None,
):
"""
Expand Down Expand Up @@ -135,10 +136,14 @@ def generate_sorting(
num_segments = len(durations)
unit_ids = np.arange(num_units)

extra_outputs_dict = {
"firing_rates": [],
}

spikes = []
for segment_index in range(num_segments):
num_samples = int(sampling_frequency * durations[segment_index])
samples, labels = synthesize_poisson_spike_vector(
samples, labels, firing_rates_array = synthesize_poisson_spike_vector(
num_units=num_units,
sampling_frequency=sampling_frequency,
duration=durations[segment_index],
Expand Down Expand Up @@ -172,12 +177,17 @@ def generate_sorting(
)
spikes.append(spikes_on_borders)

extra_outputs_dict["firing_rates"].append(firing_rates_array)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

return sorting
if extra_outputs:
return sorting, extra_outputs_dict
else:
return sorting


def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
Expand Down Expand Up @@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector(
unit_indices = unit_indices[sort_indices]
spike_frames = spike_frames[sort_indices]

return spike_frames, unit_indices
return spike_frames, unit_indices, firing_rates


def synthesize_random_firings(
Expand Down Expand Up @@ -2188,12 +2198,19 @@ def generate_ground_truth_recording(
parent_recording=noise_rec,
upsample_vector=upsample_vector,
)
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

setup_inject_templates_recording(recording, probe)
recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

return recording, sorting


def setup_inject_templates_recording(recording: BaseRecording, probe: Probe) -> None:
"""
Convenience function to modify a generated
recording in-place with annotation and probe details
"""
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)
124 changes: 105 additions & 19 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

"""

from __future__ import annotations
import numpy as np

from probeinterface import generate_multi_columns_probe
Expand All @@ -21,6 +22,7 @@
)
from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording
from .noise_tools import generate_noise
from probeinterface import Probe


# this should be moved in probeinterface but later
Expand Down Expand Up @@ -181,7 +183,7 @@ def generate_displacement_vector(
duration : float
Duration of the displacement vector in seconds
unit_locations : np.array
The unit location with shape (num_units, 3)
The unit location with shape (num_units, 2)
displacement_sampling_frequency : float, default: 5.
The sampling frequency of the displacement vector
drift_start_um : list of float, default: [0, 20.]
Expand Down Expand Up @@ -238,22 +240,64 @@ def generate_displacement_vector(
if non_rigid_gradient is None:
displacement_unit_factor[:, m] = 1
else:
gradient_direction = drift_stop_um - drift_start_um
gradient_direction /= np.linalg.norm(gradient_direction)

proj = np.dot(unit_locations, gradient_direction).squeeze()
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj))
if non_rigid_gradient < 0:
# reverse
factors = 1 - factors
f = np.abs(non_rigid_gradient)
displacement_unit_factor[:, m] = factors * (1 - f) + f
displacement_unit_factor[:, m] = calculate_displacement_unit_factor(
non_rigid_gradient, unit_locations, drift_start_um, drift_stop_um
)

displacement_vectors = np.concatenate(displacement_vectors, axis=2)

return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps


def calculate_displacement_unit_factor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this could be called something like "simulate_linear_gradient_drift"? i was a bit confused reading it, but it seems to be generating drift which is 0 at the top of the probe and something not zero at the bottom?

maybe someone can help explain what exactly

displacement_unit_factor = factors * (1 - f) + f

ends up producing... is it like there is some global drift plus per-unit linear drift?

non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array
) -> np.array:
"""
In the case of introducing non-rigid drift, a set of scaling
factors (one per unit) is generated for scaling the displacement
as a function of unit position.

The projections of the gradient vector (x, y)
and unit locations (x, y) are normalised to range between
0 and 1 (i.e. based on relative location to the gradient).
These factors are scaled by `non_rigid_gradient`.

Parameters
----------

non_rigid_gradient : float
A number in the range [0, 1] by which to scale the scaling factors
that are based on unit location. This sets the weighting given to the factors
based on unit locations. When 1, the factors will all equal 1 (no effect),
when 0, the scaling factor based on unit location will be used directly.
unit_locations : np.array
The unit location with shape (num_units, 2)
drift_start_um : np.array
The start boundary of the motion in the x and y direction.
drift_stop_um : np.array
The stop boundary of the motion in the x and y direction.

Returns
-------
displacement_unit_factor : np.array
An array of scaling factors (one per unit) by which
to scale the displacement.
"""
gradient_direction = drift_stop_um - drift_start_um
gradient_direction /= np.linalg.norm(gradient_direction)

proj = np.dot(unit_locations, gradient_direction).squeeze()
factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj))

if non_rigid_gradient < 0: # reverse
factors = 1 - factors

f = np.abs(non_rigid_gradient)
displacement_unit_factor = factors * (1 - f) + f

return displacement_unit_factor


def generate_drifting_recording(
num_units=250,
duration=600.0,
Expand Down Expand Up @@ -352,12 +396,9 @@ def generate_drifting_recording(
rng = np.random.default_rng(seed=seed)

# probe
if generate_probe_kwargs is None:
generate_probe_kwargs = _toy_probes[probe_name]
probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))
probe = generate_probe(generate_probe_kwargs, probe_name)
channel_locations = probe.contact_positions

# import matplotlib.pyplot as plt
# import probeinterface.plotting
# fig, ax = plt.subplots()
Expand Down Expand Up @@ -385,9 +426,7 @@ def generate_drifting_recording(
unit_displacements[:, :, direction] += m

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
generate_templates_kwargs["unit_params"] = unit_params
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

# generate templates
templates_array = generate_templates(
Expand Down Expand Up @@ -479,3 +518,50 @@ def generate_drifting_recording(
return static_recording, drifting_recording, sorting, extra_infos
else:
return static_recording, drifting_recording, sorting


def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe:
"""
Generate a probe for use in certain ground-truth recordings.

Parameters
----------

generate_probe_kwargs : dict
The kwargs to pass to `generate_multi_columns_probe()`
probe_name : str | None
The probe type if generate_probe_kwargs is None.
"""
if generate_probe_kwargs is None:
assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`."
generate_probe_kwargs = _toy_probes[probe_name]
probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))

return probe


def fix_generate_templates_kwargs(generate_templates_kwargs: dict, num_units: int, seed: int) -> dict:
"""
Fix the generate_template_kwargs such that the same units are created
across calls to `generate_template`. We must explicitly pre-set
the parameters for each unit, done in `_ensure_unit_params()`.

Parameters
----------

generate_templates_kwargs : dict
These kwargs will have the "unit_params" entry edited such that the
parameters are explicitly set for each unit to create (rather than
generated randomly on the fly).
num_units : int
Number of units to fix the kwargs for
seed : int
Random seed.
"""
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
generate_templates_kwargs["unit_params"] = unit_params

return generate_templates_kwargs
Loading
Loading