Commit
Upload README and added smoothing for NP-Ultra templates
alejoe91 committed Sep 28, 2024
1 parent 670cc2a commit e629808
Showing 6 changed files with 141 additions and 30 deletions.
51 changes: 45 additions & 6 deletions README.md
@@ -1,10 +1,49 @@
# SpikeInterface Hybrid Template Library

[Click here to access the library](https://spikeinterface.github.io/hybrid_template_library/)

This repo contains a set of tools to construct and interact with a library of hybrid templates for spike sorting benchmarks.

The library consists of several datasets stored as zarr files, which can be accessed through the `spikeinterface` library.
The library is also accessible through a web-app that allows users to browse the templates and download them
for use in their spike sorting benchmarks.


## Template sources

The following datasets are available in the library:

- [IBL](https://dandiarchive.org/dandiset/000409?search=IBL&pos=3): Neuropixels 1.0 templates from the IBL Brain Wide Map dataset
- [Steinmetz and Ye, 2022](https://doi.org/10.6084/m9.figshare.19493588.v2): Neuropixels Ultra templates from Steinmetz and Ye, 2022

The templates have been processed with the `python` scripts in the `python/scripts` folder and are stored in `zarr`
format in the `s3://spikeinterface-template-database` bucket hosted on `AWS S3` by [CatalystNeuro](https://www.catalystneuro.com/).
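
For a quick check of what is available, the datasets can also be listed directly with `s3fs` (a minimal sketch; anonymous public read access to the bucket is an assumption):

```python
import s3fs

# assumption: the bucket is publicly readable without AWS credentials
fs = s3fs.S3FileSystem(anon=True)

# each entry is a zarr dataset containing templates and metadata
print(fs.ls("spikeinterface-template-database"))
```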


## Accessing the data through `SpikeInterface`

The library can be accessed through the `spikeinterface` library using the `generation` module.
The following code shows how to access the library to fetch a dataframe with the available templates
and download the templates corresponding to a specific user query:

```python
import spikeinterface.generation as sgen

templates_info = sgen.fetch_templates_database_info()

# select templates with amplitude between 200 and 250uV
templates_info_selected = templates_info.query('amplitude_uv > 200 and amplitude_uv < 250')
templates_selected = sgen.query_templates_from_database(templates_info_selected)
```

For a more comprehensive example on how to construct hybrid recordings from the template library and run spike sorting
benchmarks, please refer to the SpikeInterface tutorial on [Hybrid recordings](https://spikeinterface.readthedocs.io/en/latest/how_to/benchmark_with_hybrid_recordings.html).
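
As a minimal sketch of what that tutorial covers (assuming the `generate_hybrid_recording` function from the `generation` module and an existing preprocessed `recording`), the selected templates can be injected into a recording:

```python
import spikeinterface.generation as sgen

# a sketch, not the full tutorial: `recording` is assumed to be a preprocessed
# recording whose probe geometry is compatible with the selected templates
recording_hybrid, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording,
    templates=templates_selected,
    seed=42,
)
```

The returned pair gives the recording with injected spikes and the corresponding ground-truth sorting to benchmark against.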

## Live Web-App

The template library can be browsed through a web-app (source code included in this repo). The web-app is hosted on GitHub Pages and can be accessed through the following link: [https://spikeinterface.github.io/hybrid_template_library/](https://spikeinterface.github.io/hybrid_template_library/)


### Testing locally

How to run a Python server for testing zarr access:

@@ -38,10 +77,10 @@ python -c "from http.server import HTTPServer, SimpleHTTPRequestHandler; import
```
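
Since the exact server command is collapsed in the diff above, here is a hedged alternative (a hypothetical helper, not the repo's exact command): a small CORS-enabled server, which may be needed so the web-app can fetch zarr chunks from another port:

```python
# cors_server.py -- hypothetical helper, not the exact command used in this repo
from http.server import HTTPServer, SimpleHTTPRequestHandler


class CORSRequestHandler(SimpleHTTPRequestHandler):
    def end_headers(self):
        # allow the web-app (served from another origin) to fetch zarr chunks
        self.send_header("Access-Control-Allow-Origin", "*")
        super().end_headers()


HTTPServer(("localhost", 8000), CORSRequestHandler).serve_forever()
```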


Then run the `npm` script to start the server and open the browser:

```bash
export TEST_URL="http://localhost:8000/zarr_store.zarr"
export TEST_URL="http://localhost:8000/test_zarr.zarr"
npm run start
```

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [

dependencies = [
"spikeinterface >= 0.101",
"MEArec",
"tqdm",
"pynwb>=2.8",
"remfile==0.1",
8 changes: 6 additions & 2 deletions python/consolidate_datasets.py
@@ -101,8 +101,12 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):

depth_best_channel = channel_depths[best_channel_indices]
peak_to_peak_best_channel = zarr_group["peak_to_peak"][template_indices, best_channel_indices]
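# some zarr datasets do not store channel noise levels; fall back to NaNs
# for those entries so the consolidated dataframe keeps a consistent schema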
if "channel_noise_levels" not in zarr_group:
noise_best_channel = np.nan * np.zeros(num_units)
signal_to_noise_ratio_best_channel = np.nan * np.zeros(num_units)
else:
noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel

new_entry = pd.DataFrame(
{
26 changes: 18 additions & 8 deletions python/delete_templates.py
@@ -38,7 +38,7 @@ def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
"""
This function will delete templates associated with spike trains that have too few spikes.
The initial database was in fact created without a minimum number of spikes per unit,
so some units have very few spikes and possibly a noisy template.
"""
import spikeinterface.generation as sgen
@@ -49,7 +49,9 @@ def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):

if len(templates_to_remove) > 0:
if verbose:
print(f"Removing {len(templates_to_remove)}/{len(templates_info)} templates with less than {min_spikes} spikes")
print(
f"Removing {len(templates_to_remove)}/{len(templates_info)} templates with less than {min_spikes} spikes"
)
datasets = np.unique(templates_to_remove["dataset"])

for d_i, dataset in enumerate(datasets):
@@ -59,7 +61,7 @@ def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
template_indices_to_remove = templates_in_dataset.template_index.values
s3_path = templates_in_dataset.dataset_path.values[0]

# to filter from the zarr dataset:
datasets_to_filter = [
"templates_array",
"best_channel_index",
@@ -84,26 +86,32 @@ def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
print(f"\tRemoving {n_original_units - n_units_to_keep} templates from {n_original_units}")
for dset in datasets_to_filter:
dataset_original = zarr_root[dset]
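# skip datasets that were already filtered on a previous (possibly partial) run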
if len(dataset_original) == n_units_to_keep:
print(f"\t\tDataset: {dset} - shape: {dataset_original.shape} - already updated")
continue
dataset_filtered = dataset_original[unit_indices_to_keep]
if not dry_run:
if verbose:
print(f"\t\tUpdating: {dset} - shape: {dataset_filtered.shape}")
if dataset_filtered.dtype.kind == "O":
dataset_filtered = dataset_filtered.astype(str)
zarr_root[dset] = dataset_filtered
else:
if verbose:
print(f"\t\tDry run: {dset} - shape: {dataset_filtered.shape}")
if not dry_run:
zarr.consolidate_metadata(zarr_root.store)



def delete_templates_with_num_samples(dry_run=False):
"""
This function will delete templates with an incorrect number of samples,
which was not corrected for in the initial database.
"""
bucket = "spikeinterface-template-database"
boto_client = boto3.client("s3")
verbose = True

templates_to_erase_from_bucket = [
"000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5b0ac2865de1_behavior+ecephys+image_bbe6ebc1-d32f-42dd-a89c-211226737deb.zarr",
"000409_sub-KS086_ses-e45481fa-be22-4365-972c-e7404ed8ab5a_behavior+ecephys+image_f2a098e7-a67e-4125-92d8-36fc6b606c45.zarr",
@@ -118,7 +126,9 @@ def delete_templates_with_num_samples(dry_run=False):
"000409_sub-KS096_ses-f819d499-8bf7-4da0-a431-15377a8319d5_behavior+ecephys+image_4ea45238-55b1-4d54-ba92-efa47feb9f57.zarr",
]
existing_templates = list_zarr_directories(bucket, boto_client=boto_client)
templates_to_erase_from_bucket = [
template for template in templates_to_erase_from_bucket if template in existing_templates
]
if dry_run:
if verbose:
print(f"Would erase {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
17 changes: 12 additions & 5 deletions python/upload_ibl_templates.py
@@ -1,12 +1,17 @@
"""
This script constructs and uploads the templates from the International Brain Laboratory (IBL) datasets
available from DANDI (https://dandiarchive.org/dandiset/000409?search=IBL&pos=3).
Templates are extracted by combining the raw data from the NWB files on DANDI with the spike trains from
the Alyx ONE database. Only the units that passed the IBL quality control are used.
To minimize the amount of drift in the templates, only the last 30 minutes of the recording are used.
The raw recordings are pre-processed with a high-pass filter and a common median reference prior to
template extraction. Units with fewer than 50 spikes are excluded from the template database.
Once the templates are constructed, they are saved to a Zarr file which is then uploaded to the
"spikeinterface-template-database" bucket (hosted by CatalystNeuro).
"""

from pathlib import Path

import numpy as np
@@ -294,7 +299,9 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):

if upload_data:
# Create a S3 file system object with explicit credentials
s3_kwargs = dict(
anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs
)
s3 = s3fs.S3FileSystem(**s3_kwargs)

# Specify the S3 bucket and path
68 changes: 59 additions & 9 deletions python/upload_npultra_templates.py
@@ -1,12 +1,15 @@
"""
This script constructs and uploads the templates from the Neuropixels Ultra dataset
from Steinmetz and Ye, 2022. The dataset is hosted on Figshare at https://doi.org/10.6084/m9.figshare.19493588.v2
Since the templates in the dataset have rather short cutouts, which might negatively interfere with
hybrid spike injections, the templates are padded and smoothed using the `MEArec` package so that they
end up having 240 samples (90 before, 150 after the peak).
Once the templates are constructed, they are saved to a Zarr file which is then uploaded to the
"spikeinterface-template-database" bucket (hosted by CatalystNeuro).
"""

from pathlib import Path

import numpy as np
@@ -18,7 +21,27 @@
import probeinterface as pi
import spikeinterface as si

from MEArec.tools import pad_templates, sigmoid


def smooth_edges(templates, pad_samples, smooth_percent=0.5, smooth_strength=1):
    """Taper the padded template edges to zero with a sigmoid window."""
    # use an even number of samples spanning `smooth_percent` of the left padding
    sigmoid_samples = int(smooth_percent * pad_samples[0]) // 2 * 2
    sigmoid_x = np.arange(-sigmoid_samples // 2, sigmoid_samples // 2)
    b = smooth_strength
    # shift the sigmoid so the ramp goes from ~0 to ~1
    sig = sigmoid(sigmoid_x, b) + 0.5
    # build a window that is 1 in the center and ramps to ~0 at both edges
    window = np.ones(templates.shape[-1])
    window[:sigmoid_samples] = sig
    window[-sigmoid_samples:] = sig[::-1]

    templates *= window
    return templates


# parameters
min_spikes_per_unit = 50
target_nbefore = 90
target_nafter = 150
upload_data = False
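# set to True to upload the processed templates to the S3 bucket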

npultra_templates_path = Path("/home/alessio/Documents/Data/Templates/NPUltraWaveforms/")
@@ -54,23 +77,50 @@
unit_ids = unit_ids[unit_ids_enough_spikes]
spikes_per_unit = spikes_per_unit[unit_ids_enough_spikes]

# Sort the units by unit_id
sort_unit_indices = np.argsort(unit_ids)
unit_ids = unit_ids[sort_unit_indices].astype(int)
spikes_per_unit = spikes_per_unit[sort_unit_indices]
brain_area_acronym = brain_area_acronym[sort_unit_indices]

# Process the templates to make them smooth
nbefore = 40
sampling_frequency = 30000
num_samples = templates_array.shape[1]
nafter = num_samples - nbefore

pad_samples = [target_nbefore - nbefore, target_nafter - nafter]

# MEArec expects (num_units, num_channels, num_samples), so swap the template axes
print("Padding and smoothing templates")
templates_array_swap = templates_array.swapaxes(1, 2)
tmp_templates_file = "templates_padded.raw"
templates_padded_swap = pad_templates(
templates_array_swap,
pad_samples,
drifting=False,
dtype="float",
verbose=False,
n_jobs=-1,
tmp_file=tmp_templates_file,
parallel=True,
)
templates_padded = templates_padded_swap.swapaxes(1, 2)
Path(tmp_templates_file).unlink()

# smooth edges
templates_smoothed_swap = smooth_edges(templates_padded_swap, pad_samples)
templates_smoothed = templates_smoothed_swap.swapaxes(1, 2)

# Create Templates object
print("Creating Templates object")
templates_ultra = si.Templates(
templates_array=templates_smoothed,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
unit_ids=unit_ids,
probe=probe,
is_scaled=True,
)

best_channel_index = si.get_template_extremum_channel(templates_ultra, mode="peak_to_peak", outputs="index")
