From fb0e193baf4c64943ad846718b8d2a58593a6d97 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 1 May 2024 12:39:29 +0200
Subject: [PATCH 1/3] Add consolidate script

---
 python/consolidate_datasets.py                | 110 +++++++++
 ...o_calculate_templates_from_dandisets.ipynb |   0
 ...w_to_calculate_templates_from_dandisets.py | 208 ++++++++++++++++++
 3 files changed, 318 insertions(+)
 create mode 100644 python/consolidate_datasets.py
 rename {notebooks => python}/how_to_calculate_templates_from_dandisets.ipynb (100%)
 create mode 100644 python/how_to_calculate_templates_from_dandisets.py

diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
new file mode 100644
index 0000000..2e765b0
--- /dev/null
+++ b/python/consolidate_datasets.py
@@ -0,0 +1,110 @@
+import boto3
+import pandas as pd
+import zarr
+import numpy as np
+
+from spikeinterface.core import Templates
+
+REGION_NAME = 'us-east-2'
+HYBRID_BUCKET = "spikeinterface-template-database"
+
+
+def list_bucket_objects(
+    bucket : str,
+    boto_client : boto3.client,
+    prefix : str = "",
+    include_substrings : str | list[str] | None = None,
+    skip_substrings : str | list[str] | None = None
+):
+    # get all objects for session from s3
+    paginator = boto_client.get_paginator('list_objects_v2')
+    pages = paginator.paginate(Prefix=prefix, Bucket=bucket)
+    keys = []
+
+    if include_substrings is not None:
+        if isinstance(include_substrings, str):
+            include_substrings = [include_substrings]
+    if skip_substrings is not None:
+        if isinstance(skip_substrings, str):
+            skip_substrings = [skip_substrings]
+
+    for page in pages:
+        for item in page.get('Contents', []):
+            key = item['Key']
+            if include_substrings is None and skip_substrings is None:
+                keys.append(key)
+            else:
+                if skip_substrings is not None:
+                    if any([s in key for s in skip_substrings]):
+                        continue
+                if include_substrings is not None:
+                    if all([s in key for s in include_substrings]):
+                        keys.append(key)
+    return keys
+
+
+def consolidate_datasets():
+    ### Find datasets and create dataframe with consolidated data
+    bc = boto3.client('s3')
+
+    # Each dataset is stored in a zarr folder, so we look for the .zattrs files
+    keys = list_bucket_objects(HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs")
+    datasets = [k.split("/")[0] for k in keys]
+
+    templates_df = pd.DataFrame(
+        columns=["dataset", "template_index", "best_channel_id", "brain_area", "depth", "amplitude"]
+    )
+
+    # Loop over datasets and extract relevant information
+    for dataset in datasets:
+        print(f"Processing dataset {dataset}")
+        zarr_path = f"s3://{HYBRID_BUCKET}/{dataset}"
+        zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True))
+
+        templates = Templates.from_zarr_group(zarr_group)
+
+        num_units = templates.num_units
+        dataset_list = [dataset] * num_units
+        template_idxs = np.arange(num_units)
+        best_channel_idxs = zarr_group.get("best_channels", None)
+        brain_areas = zarr_group.get("brain_area", None)
+        channel_depths = templates.get_channel_locations()[:, 1]
+
+        depths = np.zeros(num_units)
+        amps = np.zeros(num_units)
+
+        if best_channels is not None:
+            best_channels = best_channels[:]
+            for i, best_channel_idx in enumerate(best_channels):
+                depths[i] = channel_depths[best_channel_idx]
+                amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
+        else:
+            depths = np.nan
+            amps = np.nan
+            best_channels = ["unknwown"] * num_units
+        if brain_areas is not None:
+            brain_areas = brain_areas[:]
+        else:
+            brain_areas = ["unknwown"] * num_units
+        new_entry = pd.DataFrame(
+            data={
+                "dataset": dataset_list,
+                "template_index": template_idxs,
+                "best_channel_id": best_channels,
+                "brain_area": brain_areas,
+                "depth": depths,
+                "amplitude": amps
+            }
+        )
+        templates_df = pd.concat(
+            [templates_df, new_entry]
+        )
+
+    templates_df.to_csv("templates.csv", index=False)
+
+    # Upload to S3
+    bc.upload_file("templates.csv", HYBRID_BUCKET, "templates.csv")
+
+
+if __name__ == "__main__":
+    consolidate_datasets()
\ No newline at end of file
diff --git a/notebooks/how_to_calculate_templates_from_dandisets.ipynb b/python/how_to_calculate_templates_from_dandisets.ipynb
similarity index 100%
rename from notebooks/how_to_calculate_templates_from_dandisets.ipynb
rename to python/how_to_calculate_templates_from_dandisets.ipynb
diff --git a/python/how_to_calculate_templates_from_dandisets.py b/python/how_to_calculate_templates_from_dandisets.py
new file mode 100644
index 0000000..e40ddfe
--- /dev/null
+++ b/python/how_to_calculate_templates_from_dandisets.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# # How to estimate templates from Dandisets
+# The purpose of this draft notebook is to showcase how templates can be calculated by means of the `SortingAnalyzer` object.
+#
+
+# In[1]:
+
+
+from dandi.dandiapi import DandiAPIClient
+from spikeinterface.extractors import NwbRecordingExtractor, IblSortingExtractor
+
+client = DandiAPIClient.for_dandi_instance("dandi")
+
+# We specify a dataset by its dandiset_id and its asset path
+dandiset_id = "000409"
+dandiset = client.get_dandiset(dandiset_id)
+
+asset_path = "sub-KS042/sub-KS042_ses-8c552ddc-813e-4035-81cc-3971b57efe65_behavior+ecephys+image.nwb"
+recording_asset = dandiset.get_asset_by_path(path=asset_path)
+url = recording_asset.get_content_url(follow_redirects=True, strip_query=True)
+file_path = url
+
+
+# Note that this ElectricalSeries corresponds to the data from probe 00
+electrical_series_path = "acquisition/ElectricalSeriesAp00"
+recording = NwbRecordingExtractor(file_path=file_path, stream_mode="remfile", electrical_series_path=electrical_series_path)
+session_id = recording._file["general"]["session_id"][()].decode()
+eid = session_id.split("-chunking")[0]  # eid : experiment id
+
+
+# We use the sorting extractor from the IBL spike sorting pipeline that matches with eid
+from one.api import ONE
+ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True)
+one_instance = ONE(password='international')
+
+
+# Then we match the available probes with the probe number in the electrical series path
+pids, probes = one_instance.eid2pid(eid)
+probe_number = electrical_series_path.split("Ap")[-1]
+
+sorting_pid = None
+for pid, probe in zip(pids, probes):
+    probe_number_in_pid = probe[-2:]
+    if probe_number_in_pid == probe_number:
+        sorting_pid = pid
+        break
+
+
+sorting = IblSortingExtractor(pid=sorting_pid, one=one_instance, good_clusters_only=True)
+
+
+# We now have our sorting and recording objects. We perform some preprocessing on our recording and slice our objects so we only estimate templates from the last minutes of the data.
+
+# In[2]:
+
+
+from spikeinterface.preprocessing import astype, phase_shift, common_reference, highpass_filter
+
+pre_processed_recording = common_reference(
+    highpass_filter(phase_shift(astype(recording=recording, dtype="float32")), freq_min=1.0)
+)
+
+
+# Sampling frequencies of both objects and the total number of samples of the recording
+sampling_frequency_recording = pre_processed_recording.sampling_frequency
+sorting_sampling_frequency = sorting.sampling_frequency
+num_samples = pre_processed_recording.get_num_samples()
+
+# Take the last 10 minutes of the recording
+minutes = 10
+seconds = minutes * 60
+samples_before_end = int(seconds * sampling_frequency_recording)
+
+start_frame_recording = num_samples - samples_before_end
+end_frame_recording = num_samples
+
+recording_end = pre_processed_recording.frame_slice(
+    start_frame=start_frame_recording,
+    end_frame=end_frame_recording
+)
+
+# NOTE(review): num_samples below is the recording's sample count; this assumes the sorting shares the recording's timebase - confirm
+samples_before_end = int(seconds * sorting_sampling_frequency)
+start_frame_sorting = num_samples - samples_before_end
+end_frame_sorting = num_samples
+
+sorting_end = sorting.frame_slice(
+    start_frame=start_frame_sorting,
+    end_frame=end_frame_sorting
+)
+
+
+# We now use the `SortingAnalyzer` object to estimate templates.
+
+# In[3]:
+
+
+from spikeinterface.core import create_sorting_analyzer
+
+analyzer = create_sorting_analyzer(sorting_end, recording_end, sparse=False, folder=f"analyzer_{eid}")
+
+
+random_spike_parameters = {
+    "method": "all",
+}
+
+
+template_extension_parameters = {
+    "ms_before": 3.0,
+    "ms_after": 5.0,
+    "operators": ["average", "std"],
+}
+
+extensions = {
+    "random_spikes": random_spike_parameters,
+    "templates": template_extension_parameters,
+}
+
+analyzer.compute_several_extensions(
+    extensions=extensions,
+    n_jobs=-1,
+    progress_bar=True,
+)
+
+
+# In[4]:
+
+
+templates_extension = analyzer.get_extension("templates")
+template_object = templates_extension.get_data(outputs="Templates")
+
+
+# That's it. We now have our data in a templates object (note the outputs keyword on `get_data`). As a visual test that the pipeline works we find the best channel (defined as the one with the maximum peak-to-peak amplitude) and plot some units' templates for that channel.
+
+# In[5]:
+
+
+import numpy as np
+
+
+
+def find_channels_with_max_peak_to_peak_vectorized(templates_array):
+    """
+    Find the channel indices with the maximum peak-to-peak value in each waveform template
+    using a vectorized operation for improved performance.
+
+    Parameters:
+        templates_array (numpy.ndarray): The waveform templates_array, typically a 3D array (units x time x channels).
+
+    Returns:
+        numpy.ndarray: An array of indices of the channel with the maximum peak-to-peak value for each unit.
+    """
+    # Compute the peak-to-peak values along the time axis (axis=1) for each channel of each unit
+    peak_to_peak_values = np.ptp(templates_array, axis=1)
+
+    # Find the indices of the channel with the maximum peak-to-peak value for each unit
+    best_channels = np.argmax(peak_to_peak_values, axis=1)
+
+    return best_channels
+
+
+
+
+# In[6]:
+
+
+import matplotlib.pyplot as plt
+
+# Adjust global font size
+plt.rcParams.update({"font.size": 18})
+
+unit_ids = template_object.unit_ids
+best_channels = find_channels_with_max_peak_to_peak_vectorized(template_object.templates_array)
+
+
+num_columns = 3
+num_rows = 3
+
+fig, axs = plt.subplots(num_rows, num_columns, figsize=(15, 20), sharex=True, sharey=True)
+
+center = template_object.nbefore
+
+for unit_index, unit_id in enumerate(unit_ids[: num_columns * num_rows]):
+    row, col = divmod(unit_index, num_columns)
+    ax = axs[row, col]
+    best_channel = best_channels[unit_index]
+
+    ax.plot(template_object.templates_array[unit_index, :, best_channel], linewidth=3, label="best channel", color="black")
+
+    ax.axvline(center, linestyle="--", color="red", linewidth=1)
+    ax.axhline(0, linestyle="--", color="gray", linewidth=1)
+    ax.set_title(f"Unit {unit_id}")
+
+    # Set the tick length to zero on both axes
+    ax.tick_params(axis="both", which="both", length=0)
+
+    # Hide all spines
+    for spine in ax.spines.values():
+        spine.set_visible(False)
+
+plt.tight_layout()
+
+# Create the legend with specified colors
+handles, labels = axs[0, 0].get_legend_handles_labels()
+fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.02), ncol=3, frameon=False)
+

From 7a90dfa1bd033ba789a2c0c5044272f90164be50 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 1 May 2024 12:43:05 +0200
Subject: [PATCH 2/3] skip test templates

---
 python/consolidate_datasets.py | 36 ++++++++++++++++------------------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
index 2e765b0..7c0d265 100644
--- a/python/consolidate_datasets.py
+++ b/python/consolidate_datasets.py
@@ -5,19 +5,19 @@
 
 from spikeinterface.core import Templates
 
-REGION_NAME = 'us-east-2'
+REGION_NAME = "us-east-2"
 HYBRID_BUCKET = "spikeinterface-template-database"
 
 
 def list_bucket_objects(
-    bucket : str,
-    boto_client : boto3.client,
-    prefix : str = "",
-    include_substrings : str | list[str] | None = None,
-    skip_substrings : str | list[str] | None = None
+    bucket: str,
+    boto_client: boto3.client,
+    prefix: str = "",
+    include_substrings: str | list[str] | None = None,
+    skip_substrings: str | list[str] | None = None,
 ):
     # get all objects for session from s3
-    paginator = boto_client.get_paginator('list_objects_v2')
+    paginator = boto_client.get_paginator("list_objects_v2")
     pages = paginator.paginate(Prefix=prefix, Bucket=bucket)
     keys = []
 
@@ -29,8 +29,8 @@ def list_bucket_objects(
             skip_substrings = [skip_substrings]
 
     for page in pages:
-        for item in page.get('Contents', []):
-            key = item['Key']
+        for item in page.get("Contents", []):
+            key = item["Key"]
             if include_substrings is None and skip_substrings is None:
                 keys.append(key)
             else:
@@ -45,7 +45,7 @@ def list_bucket_objects(
 
 def consolidate_datasets():
     ### Find datasets and create dataframe with consolidated data
-    bc = boto3.client('s3')
+    bc = boto3.client("s3")
 
     # Each dataset is stored in a zarr folder, so we look for the .zattrs files
     keys = list_bucket_objects(HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs")
@@ -88,17 +88,15 @@ def consolidate_datasets():
             brain_areas = ["unknwown"] * num_units
         new_entry = pd.DataFrame(
             data={
-                "dataset": dataset_list,
-                "template_index": template_idxs,
-                "best_channel_id": best_channels,
-                "brain_area": brain_areas,
+                "dataset": dataset_list,
+                "template_index": template_idxs,
+                "best_channel_id": best_channels,
+                "brain_area": brain_areas,
                 "depth": depths,
-                "amplitude": amps
+                "amplitude": amps,
             }
         )
-        templates_df = pd.concat(
-            [templates_df, new_entry]
-        )
+        templates_df = pd.concat([templates_df, new_entry])
 
     templates_df.to_csv("templates.csv", index=False)
 
@@ -107,4 +105,4 @@ def consolidate_datasets():
 
 
 if __name__ == "__main__":
-    consolidate_datasets()
\ No newline at end of file
+    consolidate_datasets()

From 26b15e23f037d3ad85c0d0524a43c27f5c1626ff Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 1 May 2024 12:46:21 +0200
Subject: [PATCH 3/3] skip test templates2

---
 python/consolidate_datasets.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
index 7c0d265..dfafe45 100644
--- a/python/consolidate_datasets.py
+++ b/python/consolidate_datasets.py
@@ -5,8 +5,8 @@
 
 from spikeinterface.core import Templates
 
-REGION_NAME = "us-east-2"
 HYBRID_BUCKET = "spikeinterface-template-database"
+SKIP_TEST = True
 
 
 def list_bucket_objects(
@@ -48,7 +48,10 @@ def consolidate_datasets():
     bc = boto3.client("s3")
 
     # Each dataset is stored in a zarr folder, so we look for the .zattrs files
-    keys = list_bucket_objects(HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs")
+    exclude_substrings = ["test_templates"] if SKIP_TEST else None
+    keys = list_bucket_objects(
+        HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs", skip_substrings=exclude_substrings
+    )
     datasets = [k.split("/")[0] for k in keys]
 
     templates_df = pd.DataFrame(
@@ -73,9 +76,9 @@ def consolidate_datasets():
         depths = np.zeros(num_units)
         amps = np.zeros(num_units)
 
-        if best_channels is not None:
-            best_channels = best_channels[:]
-            for i, best_channel_idx in enumerate(best_channels):
+        if best_channel_idxs is not None:
+            best_channels = best_channel_idxs[:]
+            for i, best_channel_idx in enumerate(best_channels):
                 depths[i] = channel_depths[best_channel_idx]
                 amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
         else: