Skip to content

Commit

Permalink
Merge pull request #231 from khl02007/get_scaled_traces
Browse files Browse the repository at this point in the history
Get scaled traces for `SpikeSortingView`
  • Loading branch information
magland authored Jul 10, 2024
2 parents 3679638 + 7b22920 commit b05ecba
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 20 deletions.
8 changes: 7 additions & 1 deletion examples/example_average_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def extract_snippets(*, traces: np.ndarray, times: np.ndarray, snippet_len: Tupl


def compute_average_waveform(*, recording: si.BaseRecording, sorting: si.BaseSorting, unit_id: int):
traces = recording.get_traces(segment_index=0)
if hasattr(recording, 'has_scaleable_traces') and callable(getattr(recording, 'has_scaleable_traces')):
scalable = recording.has_scaleable_traces()
elif hasattr(recording, 'has_scaled') and callable(getattr(recording, 'has_scaled')):
scalable = recording.has_scaled()
else:
scalable = False
traces = recording.get_traces(segment_index=0, return_scaled=scalable)
times = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id)
snippets = extract_snippets(traces=traces, times=times, snippet_len=(20, 20))
waveform = np.mean(snippets, axis=0).T.astype(np.float32)
Expand Down
98 changes: 81 additions & 17 deletions sortingview/SpikeSortingView/prepare_spikesortingview_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,42 @@ def prepare_spikesortingview_data(
num_frames = recording.get_num_frames()
num_frames_per_segment = math.ceil(segment_duration_sec * sampling_frequency)
num_segments = math.ceil(num_frames / num_frames_per_segment)
if hasattr(recording, "has_scaleable_traces") and callable(getattr(recording, "has_scaleable_traces")):
scalable = recording.has_scaleable_traces()
elif hasattr(recording, "has_scaled") and callable(getattr(recording, "has_scaled")):
scalable = recording.has_scaled()
else:
scalable = False

with kcl.TemporaryDirectory() as tmpdir:
output_file_name = tmpdir + "/spikesortingview.h5"
with h5py.File(output_file_name, "w") as f:
f.create_dataset("unit_ids", data=unit_ids)
f.create_dataset("sampling_frequency", data=np.array([sampling_frequency]).astype(np.float32))
f.create_dataset(
"sampling_frequency",
data=np.array([sampling_frequency]).astype(np.float32),
)
f.create_dataset("channel_ids", data=channel_ids)
f.create_dataset("num_frames", data=np.array([num_frames]).astype(int_type))
channel_locations = recording.get_channel_locations()
f.create_dataset("channel_locations", data=np.array(channel_locations))
f.create_dataset("num_segments", data=np.array([num_segments]).astype(np.int32))
f.create_dataset("num_frames_per_segment", data=np.array([num_frames_per_segment]).astype(np.int32))
f.create_dataset("snippet_len", data=np.array([snippet_len[0], snippet_len[1]]).astype(np.int32))
f.create_dataset("max_num_snippets_per_segment", data=np.array([max_num_snippets_per_segment]).astype(np.int32))
f.create_dataset("channel_neighborhood_size", data=np.array([channel_neighborhood_size]).astype(np.int32))
f.create_dataset(
"num_frames_per_segment",
data=np.array([num_frames_per_segment]).astype(np.int32),
)
f.create_dataset(
"snippet_len",
data=np.array([snippet_len[0], snippet_len[1]]).astype(np.int32),
)
f.create_dataset(
"max_num_snippets_per_segment",
data=np.array([max_num_snippets_per_segment]).astype(np.int32),
)
f.create_dataset(
"channel_neighborhood_size",
data=np.array([channel_neighborhood_size]).astype(np.int32),
)

# first get peak channels and channel neighborhoods
unit_peak_channel_ids = {}
Expand All @@ -68,19 +89,30 @@ def prepare_spikesortingview_data(
end_frame = min(start_frame + num_frames_per_segment, num_frames)
start_frame_with_padding = max(start_frame - snippet_len[0], 0)
end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding)
traces_with_padding = recording.get_traces(
start_frame=start_frame_with_padding,
end_frame=end_frame_with_padding,
return_scaled=scalable,
)
assert isinstance(traces_with_padding, np.ndarray)
for unit_id in unit_ids:
if str(unit_id) not in unit_peak_channel_ids:
spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame)
spike_train = sorting.get_unit_spike_train(
unit_id=unit_id,
start_frame=start_frame,
end_frame=end_frame,
)
assert isinstance(spike_train, np.ndarray)
if len(spike_train) > 0:
values = traces_with_padding[spike_train - start_frame_with_padding, :].astype(np.int32)
avg_value = np.mean(values, axis=0)
peak_channel_ind = np.argmax(np.abs(avg_value))
peak_channel_id = channel_ids[peak_channel_ind]
channel_neighborhood = get_channel_neighborhood(
channel_ids=channel_ids, channel_locations=channel_locations, peak_channel_id=peak_channel_id, channel_neighborhood_size=channel_neighborhood_size
channel_ids=channel_ids,
channel_locations=channel_locations,
peak_channel_id=peak_channel_id,
channel_neighborhood_size=channel_neighborhood_size,
)
if len(spike_train) >= 10:
unit_peak_channel_ids[str(unit_id)] = peak_channel_id
Expand All @@ -94,17 +126,30 @@ def prepare_spikesortingview_data(
if peak_channel_id is None:
raise Exception(f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit.")
channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
f.create_dataset(f"unit/{unit_id}/peak_channel_id", data=np.array([peak_channel_id]).astype(np.int32))
f.create_dataset(f"unit/{unit_id}/channel_neighborhood", data=np.array(channel_neighborhood).astype(np.int32))
f.create_dataset(
f"unit/{unit_id}/peak_channel_id",
data=np.array([peak_channel_id]).astype(np.int32),
)
f.create_dataset(
f"unit/{unit_id}/channel_neighborhood",
data=np.array(channel_neighborhood).astype(np.int32),
)

for iseg in range(num_segments):
print(f"Segment {iseg} of {num_segments}")
start_frame = iseg * num_frames_per_segment
end_frame = min(start_frame + num_frames_per_segment, num_frames)
start_frame_with_padding = max(start_frame - snippet_len[0], 0)
end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding)
traces_sample = traces_with_padding[start_frame - start_frame_with_padding : start_frame - start_frame_with_padding + int(sampling_frequency * 1), :]
traces_with_padding = recording.get_traces(
start_frame=start_frame_with_padding,
end_frame=end_frame_with_padding,
return_scaled=scalable,
)
traces_sample = traces_with_padding[
start_frame - start_frame_with_padding : start_frame - start_frame_with_padding + int(sampling_frequency * 1),
:,
]
f.create_dataset(f"segment/{iseg}/traces_sample", data=traces_sample)
all_subsampled_spike_trains = []
for unit_id in unit_ids:
Expand All @@ -119,18 +164,28 @@ def prepare_spikesortingview_data(
peak_channel_ind = channel_ids.tolist().index(peak_channel_id)
if len(spike_train) > 0:
spike_amplitudes = traces_with_padding[spike_train - start_frame_with_padding, peak_channel_ind]
f.create_dataset(f"segment/{iseg}/unit/{unit_id}/spike_amplitudes", data=spike_amplitudes)
f.create_dataset(
f"segment/{iseg}/unit/{unit_id}/spike_amplitudes",
data=spike_amplitudes,
)
else:
spike_amplitudes = np.array([], dtype=np.int32)
if max_num_snippets_per_segment is not None and len(spike_train) > max_num_snippets_per_segment:
subsampled_spike_train = subsample(spike_train, max_num_snippets_per_segment)
else:
subsampled_spike_train = spike_train
f.create_dataset(f"segment/{iseg}/unit/{unit_id}/subsampled_spike_train", data=subsampled_spike_train)
f.create_dataset(
f"segment/{iseg}/unit/{unit_id}/subsampled_spike_train",
data=subsampled_spike_train,
)
all_subsampled_spike_trains.append(subsampled_spike_train)
subsampled_spike_trains_concat = np.concatenate(all_subsampled_spike_trains)
# print('Extracting spike snippets')
spike_snippets_concat = extract_spike_snippets(traces=traces_with_padding, times=subsampled_spike_trains_concat - start_frame_with_padding, snippet_len=snippet_len)
spike_snippets_concat = extract_spike_snippets(
traces=traces_with_padding,
times=subsampled_spike_trains_concat - start_frame_with_padding,
snippet_len=snippet_len,
)
# print('Collecting spike snippets')
index = 0
for ii, unit_id in enumerate(unit_ids):
Expand All @@ -139,12 +194,21 @@ def prepare_spikesortingview_data(
num = len(all_subsampled_spike_trains[ii])
spike_snippets = spike_snippets_concat[index : index + num, :, channel_neighborhood_indices]
index = index + num
f.create_dataset(f"segment/{iseg}/unit/{unit_id}/subsampled_spike_snippets", data=spike_snippets)
f.create_dataset(
f"segment/{iseg}/unit/{unit_id}/subsampled_spike_snippets",
data=spike_snippets,
)
uri = kcl.store_file_local(output_file_name)
return uri


def get_channel_neighborhood(*, channel_ids: np.ndarray, channel_locations: np.ndarray, peak_channel_id: int, channel_neighborhood_size: int):
def get_channel_neighborhood(
*,
channel_ids: np.ndarray,
channel_locations: np.ndarray,
peak_channel_id: int,
channel_neighborhood_size: int,
):
channel_locations_by_id = {}
for ii, channel_id in enumerate(channel_ids):
channel_locations_by_id[channel_id] = channel_locations[ii]
Expand Down
2 changes: 1 addition & 1 deletion sortingview/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# This file was automatically generated by jinjaroot. Do not edit directly.
__version__ = '0.13.4'
__version__ = "0.13.4"
8 changes: 7 additions & 1 deletion tests/test_average_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ def extract_snippets(*, traces: np.ndarray, times: np.ndarray, snippet_len: Tupl


def compute_average_waveform(*, recording: si.BaseRecording, sorting: si.BaseSorting, unit_id: int):
traces = recording.get_traces()
if hasattr(recording, 'has_scaleable_traces') and callable(getattr(recording, 'has_scaleable_traces')):
scalable = recording.has_scaleable_traces()
elif hasattr(recording, 'has_scaled') and callable(getattr(recording, 'has_scaled')):
scalable = recording.has_scaled()
else:
scalable = False
traces = recording.get_traces(return_scaled=scalable)
times = sorting.get_unit_spike_train(unit_id=unit_id)
snippets = extract_snippets(traces=traces, times=times, snippet_len=(20, 20))
waveform = np.mean(snippets, axis=0)
Expand Down

0 comments on commit b05ecba

Please sign in to comment.