diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py new file mode 100644 index 0000000..dfafe45 --- /dev/null +++ b/python/consolidate_datasets.py @@ -0,0 +1,111 @@ +import boto3 +import pandas as pd +import zarr +import numpy as np + +from spikeinterface.core import Templates + +HYBRID_BUCKET = "spikeinterface-template-database" +SKIP_TEST = True + + +def list_bucket_objects( + bucket: str, + boto_client: boto3.client, + prefix: str = "", + include_substrings: str | list[str] | None = None, + skip_substrings: str | list[str] | None = None, +): + # get all objects for session from s3 + paginator = boto_client.get_paginator("list_objects_v2") + pages = paginator.paginate(Prefix=prefix, Bucket=bucket) + keys = [] + + if include_substrings is not None: + if isinstance(include_substrings, str): + include_substrings = [include_substrings] + if skip_substrings is not None: + if isinstance(skip_substrings, str): + skip_substrings = [skip_substrings] + + for page in pages: + for item in page.get("Contents", []): + key = item["Key"] + if include_substrings is None and skip_substrings is None: + keys.append(key) + else: + if skip_substrings is not None: + if any([s in key for s in skip_substrings]): + continue + if include_substrings is not None: + if all([s in key for s in include_substrings]): + keys.append(key) + return keys + + +def consolidate_datasets(): + ### Find datasets and create dataframe with consolidated data + bc = boto3.client("s3") + + # Each dataset is stored in a zarr folder, so we look for the .zattrs files + exclude_substrings = ["test_templates"] if SKIP_TEST else None + keys = list_bucket_objects( + HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs", exclude_substrings=exclude_substrings + ) + datasets = [k.split("/")[0] for k in keys] + + templates_df = pd.DataFrame( + columns=["dataset", "template_index", "best_channel_id", "brain_area", "depth", "amplitude"] + ) + + # Loop over datasets and extract relevant information + for dataset in datasets: + print(f"Processing dataset {dataset}") + zarr_path = f"s3://{HYBRID_BUCKET}/{dataset}" + zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True)) + + templates = Templates.from_zarr_group(zarr_group) + + num_units = templates.num_units + dataset_list = [dataset] * num_units + template_idxs = np.arange(num_units) + best_channel_idxs = zarr_group.get("best_channels", None) + brain_areas = zarr_group.get("brain_area", None) + channel_depths = templates.get_channel_locations()[:, 1] + + depths = np.zeros(num_units) + amps = np.zeros(num_units) + + if best_channel_idxs is not None: + best_channel_idxs = best_channel_idxs[:] + for i, best_channel_idx in enumerate(best_channel_idxs): + depths[i] = channel_depths[best_channel_idx] + amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx]) + else: + depths = np.nan + amps = np.nan + best_channels = ["unknwown"] * num_units + if brain_areas is not None: + brain_areas = brain_areas[:] + else: + brain_areas = ["unknwown"] * num_units + new_entry = pd.DataFrame( + data={ + "dataset": dataset_list, + "template_index": template_idxs, + "best_channel_id": best_channels, + "brain_area": brain_areas, + "depth": depths, + "amplitude": amps, + } + ) + templates_df = pd.concat([templates_df, new_entry]) + + templates_df.to_csv("templates.csv", index=False) + + # Upload to S3 + bc.upload_file("templates.csv", HYBRID_BUCKET, "templates.csv") + + +if __name__ == "__main__": + consolidate_datasets() diff --git a/notebooks/how_to_calculate_templates_from_dandisets.ipynb b/python/how_to_calculate_templates_from_dandisets.ipynb similarity index 100% rename from notebooks/how_to_calculate_templates_from_dandisets.ipynb rename to python/how_to_calculate_templates_from_dandisets.ipynb diff --git a/python/how_to_calculate_templates_from_dandisets.py b/python/how_to_calculate_templates_from_dandisets.py new file mode 100644 index 0000000..e40ddfe --- /dev/null +++ b/python/how_to_calculate_templates_from_dandisets.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # How to estimate templates from Dandisets +# The purpose of this draft notebook is to showcase how templates can be calculated by means of the `SortingAnalyzer` object. +# + +# In[1]: + + +from dandi.dandiapi import DandiAPIClient +from spikeinterface.extractors import NwbRecordingExtractor, IblSortingExtractor + +client = DandiAPIClient.for_dandi_instance("dandi") + +# We specifiy a dataset by is dandiset_id and its asset path +dandiset_id = "000409" +dandiset = client.get_dandiset(dandiset_id) + +asset_path = "sub-KS042/sub-KS042_ses-8c552ddc-813e-4035-81cc-3971b57efe65_behavior+ecephys+image.nwb" +recording_asset = dandiset.get_asset_by_path(path=asset_path) +url = recording_asset.get_content_url(follow_redirects=True, strip_query=True) +file_path = url + + +# Note that this ElectricalSeries corresponds to the data from probe 00 +electrical_series_path = "acquisition/ElectricalSeriesAp00" +recording = NwbRecordingExtractor(file_path=file_path, stream_mode="remfile", electrical_series_path=electrical_series_path) +session_id = recording._file["general"]["session_id"][()].decode() +eid = session_id.split("-chunking")[0] # eid : experiment id + + +# We use the sorting extractor from the IBL spike sorting pipeline that matches with eid +from one.api import ONE +ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True) +one_instance = ONE(password='international') + + +# Then we match the available probes with the probe number in the electrical series path +pids, probes = one_instance.eid2pid(eid) +probe_number = electrical_series_path.split("Ap")[-1] + +sorting_pid = None +for pid, probe in zip(pids, probes): + probe_number_in_pid = probe[-2:] + if probe_number_in_pid == probe_number: + sorting_pid = pid + break + + +sorting = IblSortingExtractor(pid=sorting_pid, one=one_instance, good_clusters_only=True) + + +# We now have our sorting and recording objects. We perform some preprocessing on our recording and slice ouf objects so we only estimate templates from the last minutes of the data. + +# In[2]: + + +from spikeinterface.preprocessing import astype, phase_shift, common_reference, highpass_filter + +pre_processed_recording = common_reference( + highpass_filter(phase_shift(astype(recording=recording, dtype="float32")), freq_min=1.0) +) + + +# take first and last minute +sampling_frequency_recording = pre_processed_recording.sampling_frequency +sorting_sampling_frequency = sorting.sampling_frequency +num_samples = pre_processed_recording.get_num_samples() + +# Take the last 10 minutes of the recording +minutes = 10 +seconds = minutes * 60 +samples_before_end = int(seconds * sampling_frequency_recording) + +start_frame_recording = num_samples - samples_before_end +end_frame_recording = num_samples + +recording_end = pre_processed_recording.frame_slice( + start_frame=start_frame_recording, + end_frame=end_frame_recording +) + +# num_samples = sorting.get_num_frames() +samples_before_end = int(seconds * sorting_sampling_frequency) +start_frame_sorting = num_samples - samples_before_end +end_frame_sorting = num_samples + +sorting_end = sorting.frame_slice( + start_frame=start_frame_sorting, + end_frame=end_frame_sorting +) + + +# We now use the `SortingAnalyzer` object to estimate templates. + +# In[3]: + + +from spikeinterface.core import create_sorting_analyzer + +analyzer = create_sorting_analyzer(sorting_end, recording_end, sparse=False, folder=f"analyzer_{eid}") + + +random_spike_parameters = { + "method": "all", +} + + +template_extension_parameters = { + "ms_before": 3.0, + "ms_after": 5.0, + "operators": ["average", "std"], +} + +extensions = { + "random_spikes": random_spike_parameters, + "templates": template_extension_parameters, +} + +analyzer.compute_several_extensions( + extensions=extensions, + n_jobs=-1, + progress_bar=True, +) + + +# In[4]: + + +templates_extension = analyzer.get_extension("templates") +template_object = templates_extension.get_data(outputs="Templates") + + +# That's it. We now have our data in a templates object (note the outputs keyword on `get_data`). As a visual test that the pipeline works we show how the best chanenl (defined as the one with the maximum peak to peak amplitude) and plot some unit's templates for that channel. + +# In[5]: + + +import numpy as np + + + +def find_channels_with_max_peak_to_peak_vectorized(templates_array): + """ + Find the channel indices with the maximum peak-to-peak value in each waveform template + using a vectorized operation for improved performance. + + Parameters: + templates_array (numpy.ndarray): The waveform templates_array, typically a 3D array (units x time x channels). + + Returns: + numpy.ndarray: An array of indices of the channel with the maximum peak-to-peak value for each unit. + """ + # Compute the peak-to-peak values along the time axis (axis=1) for each channel of each unit + peak_to_peak_values = np.ptp(templates_array, axis=1) + + # Find the indices of the channel with the maximum peak-to-peak value for each unit + best_channels = np.argmax(peak_to_peak_values, axis=1) + + return best_channels + + + + +# In[6]: + + +import matplotlib.pyplot as plt + +# Adjust global font size +plt.rcParams.update({"font.size": 18}) + +unit_ids = template_object.unit_ids +best_channels = find_channels_with_max_peak_to_peak_vectorized(template_object.templates_array) + + +num_columns = 3 +num_rows = 3 + +fig, axs = plt.subplots(num_rows, num_columns, figsize=(15, 20), sharex=True, sharey=True) + +center = template_object.nbefore + +for unit_index, unit_id in enumerate(unit_ids[: num_columns * num_rows]): + row, col = divmod(unit_index, num_columns) + ax = axs[row, col] + best_channel = best_channels[unit_index] + + ax.plot(template_object.templates_array[unit_index, :, best_channel], linewidth=3, label="best channel", color="black") + + ax.axvline(center, linestyle="--", color="red", linewidth=1) + ax.axhline(0, linestyle="--", color="gray", linewidth=1) + ax.set_title(f"Unit {unit_id}") + + # Hide all spines and ticks + ax.tick_params(axis="both", which="both", length=0) + + # Hide all spines + for spine in ax.spines.values(): + spine.set_visible(False) + +plt.tight_layout() + +# Create the legend with specified colors +handles, labels = axs[0, 0].get_legend_handles_labels() +fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.02), ncol=3, frameon=False) +