diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
index fd74d70..ed2e94c 100644
--- a/python/consolidate_datasets.py
+++ b/python/consolidate_datasets.py
@@ -9,134 +9,132 @@
 parser = ArgumentParser(description="Consolidate datasets from spikeinterface template database")
 parser.add_argument("--dry-run", action="store_true", help="Dry run (no upload)")
-parser.add_argument("--no-skip-test", action="store_true", help="Skip test datasets")
 parser.add_argument("--bucket", type=str, help="S3 bucket name", default="spikeinterface-template-database")
 
-def list_bucket_objects(
-    bucket: str,
-    boto_client: boto3.client,
-    prefix: str = "",
-    include_substrings: str | list[str] | None = None,
-    skip_substrings: str | list[str] | None = None,
-):
-    # get all objects for session from s3
-    paginator = boto_client.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Prefix=prefix, Bucket=bucket)
-    keys = []
-
-    if include_substrings is not None:
-        if isinstance(include_substrings, str):
-            include_substrings = [include_substrings]
-    if skip_substrings is not None:
-        if isinstance(skip_substrings, str):
-            skip_substrings = [skip_substrings]
-
-    for page in pages:
-        for item in page.get("Contents", []):
-            key = item["Key"]
-            if include_substrings is None and skip_substrings is None:
-                keys.append(key)
-            else:
-                if skip_substrings is not None:
-                    if any([s in key for s in skip_substrings]):
-                        continue
-                if include_substrings is not None:
-                    if all([s in key for s in include_substrings]):
-                        keys.append(key)
-    return keys
-
-
-def consolidate_datasets(
-    dry_run: bool = False, skip_test_folder: bool = True, bucket="spikeinterface-template-database"
-):
-    ### Find datasets and create dataframe with consolidated data
-    bc = boto3.client("s3")
-
-    # Each dataset is stored in a zarr folder, so we look for the .zattrs files
-    skip_substrings = ["test_templates"] if skip_test_folder else None
-    keys = list_bucket_objects(bucket, boto_client=bc, include_substrings=".zattrs", skip_substrings=skip_substrings)
-    datasets = [k.split("/")[0] for k in keys]
-    print(f"Found {len(datasets)} datasets to consolidate\n")
-
-    templates_df = None
-
-    # Loop over datasets and extract relevant information
-    for dataset in datasets:
+
+def list_zarr_directories(bucket_name, boto_client=None):
+    """Lists top-level Zarr directory keys in an S3 bucket.
+
+    Parameters
+    ----------
+    bucket_name : str
+        The name of the S3 bucket to search.
+    boto_client : boto3.client, optional
+        An existing Boto3 S3 client. If not provided, a new client will be created.
+
+    Returns
+    -------
+    zarr_directories : list
+        A list of strings representing the full S3 keys (paths) of top-level Zarr directories
+        found in the bucket.
+    """
+
+    boto_client = boto_client or boto3.client('s3')
+    zarr_directories = set()
+
+    paginator = boto_client.get_paginator('list_objects_v2')
+    for page in paginator.paginate(Bucket=bucket_name, Delimiter='/'):
+        for prefix in page.get('CommonPrefixes', []):
+            key = prefix['Prefix']
+            if key.endswith('.zarr/'):
+                zarr_directories.add(key.rstrip('/'))
+
+    return list(zarr_directories)
+
+
+def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
+    """Consolidates data from Zarr datasets within an S3 bucket.
+
+    Parameters
+    ----------
+    dry_run : bool, optional
+        If True, do not upload the consolidated data to S3. Defaults to False.
+    verbose : bool, optional
+        If True, print additional information during processing. Defaults to False.
+
+    Returns
+    -------
+    pandas.DataFrame
+        A DataFrame containing the consolidated data from all Zarr datasets.
+
+    Raises
+    ------
+    FileNotFoundError
+        If no Zarr datasets are found in the specified bucket.
+    """
+
+    bucket = "spikeinterface-template-database"
+    boto_client = boto3.client("s3")
+
+    # Get list of Zarr directories, excluding test datasets
+    zarr_datasets = list_zarr_directories(bucket_name=bucket, boto_client=boto_client)
+    datasets_to_avoid = ["test_templates.zarr"]
+    zarr_datasets = [d for d in zarr_datasets if d not in datasets_to_avoid]
+
+    if not zarr_datasets:
+        raise FileNotFoundError(f"No Zarr datasets found in bucket: {bucket}")
+    if verbose:
+        print(f"Found {len(zarr_datasets)} datasets to consolidate\n")
+
+    # Initialize list to collect DataFrames for each dataset
+    all_dataframes = []
+
+    for dataset in zarr_datasets:
         print(f"Processing dataset {dataset}")
         zarr_path = f"s3://{bucket}/{dataset}"
         zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True))
-
         templates = Templates.from_zarr_group(zarr_group)
+        # Extract data efficiently using NumPy arrays
         num_units = templates.num_units
-        dataset_list = [dataset] * num_units
-        dataset_path = [zarr_path] * num_units
-        template_idxs = np.arange(num_units)
-        best_channel_idxs = zarr_group.get("best_channels", None)
-        brain_areas = zarr_group.get("brain_area", None)
-        peak_to_peaks = zarr_group.get("peak_to_peak", None)
-        spikes_per_units = zarr_group.get("spikes_per_unit", None)
-
-        # TODO: get probe name from probe metadata
-
+        probe_attributes = zarr_group["probe"]["annotations"].attrs.asdict()
+        template_indices = np.arange(num_units)
+        default_brain_area = ["unknown"] * num_units
+        brain_areas = zarr_group.get("brain_area", default_brain_area)
         channel_depths = templates.get_channel_locations()[:, 1]
+        spikes_per_unit = zarr_group["spikes_per_unit"][:]
+        best_channel_indices = zarr_group["best_channel_index"][:]
 
-        depths = np.zeros(num_units)
-        amps = np.zeros(num_units)
-
-        if best_channel_idxs is not None:
-            best_channel_idxs = best_channel_idxs[:]
-            for i, best_channel_idx in enumerate(best_channel_idxs):
-                depths[i] = channel_depths[best_channel_idx]
-                if peak_to_peaks is None:
-                    amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
-                else:
-                    amps[i] = peak_to_peaks[i, best_channel_idx]
-        else:
-            depths = np.nan
-            amps = np.nan
-            best_channel_idxs = [-1] * num_units
-            spikes_per_units = [-1] * num_units
-        if brain_areas is not None:
-            brain_areas = brain_areas[:]
-        else:
-            brain_areas = ["unknwown"] * num_units
+        depth_best_channel = channel_depths[best_channel_indices]
+        peak_to_peak_best_channel = zarr_group["peak_to_peak"][template_indices, best_channel_indices]
+        noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
+        signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel
 
         new_entry = pd.DataFrame(
-            data={
-                "dataset": dataset_list,
-                "dataset_path": dataset_path,
-                "probe": ["Neuropixels1.0"] * num_units,
-                "template_index": template_idxs,
-                "best_channel_id": best_channel_idxs,
+            {
+                "probe": [probe_attributes["model_name"]] * num_units,
+                "probe_manufacturer": [probe_attributes["manufacturer"]] * num_units,
                 "brain_area": brain_areas,
-                "depth_along_probe": depths,
-                "amplitude_uv": amps,
-                "spikes_per_unit": spikes_per_units,
+                "depth_along_probe": depth_best_channel,
+                "amplitude_uv": peak_to_peak_best_channel,
+                "noise_level_uv": noise_best_channel,
+                "signal_to_noise_ratio": signal_to_noise_ratio_best_channel,
+                "template_index": template_indices,
+                "best_channel_index": best_channel_indices,
+                "spikes_per_unit": spikes_per_unit,
+                "dataset": [dataset] * num_units,
+                "dataset_path": [zarr_path] * num_units,
             }
         )
 
-        if templates_df is None:
-            templates_df = new_entry
-        else:
-            templates_df = pd.concat([templates_df, new_entry])
-        print(f"Added {num_units} units from dataset {dataset}")
+        all_dataframes.append(new_entry)
 
-    templates_df.reset_index(inplace=True, drop=True)
+    # Concatenate all DataFrames into a single DataFrame
+    templates_df = pd.concat(all_dataframes, ignore_index=True)
+
     templates_df.to_csv("templates.csv", index=False)
 
     # Upload to S3
     if not dry_run:
-        bc.upload_file("templates.csv", bucket, "templates.csv")
-    else:
+        boto_client.upload_file("templates.csv", bucket, "templates.csv")
+    elif verbose:
         print("Dry run, not uploading")
         print(templates_df)
 
     return templates_df
 
-
 if __name__ == "__main__":
     params = parser.parse_args()
     DRY_RUN = params.dry_run
-    SKIP_TEST = not params.no_skip_test
-    templates_df = consolidate_datasets(dry_run=DRY_RUN, skip_test_folder=SKIP_TEST)
+    templates_df = consolidate_datasets(dry_run=DRY_RUN)
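A quick, credential-free way to exercise the refactored consolidation path above is sketched below. This is only an illustration and not part of the diff: it assumes the bucket allows anonymous reads (as the anon=True Zarr access in the script implies) and that s3fs is installed so pandas can read from S3.

# Illustrative check of the consolidation output; assumes anonymous read access
# to the public bucket and an installed s3fs (not part of the diff above).
import boto3
import pandas as pd
from botocore import UNSIGNED
from botocore.config import Config

bucket = "spikeinterface-template-database"

# Same CommonPrefixes listing as list_zarr_directories(), but with an unsigned client
anon_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
paginator = anon_client.get_paginator("list_objects_v2")
zarr_datasets = [
    prefix["Prefix"].rstrip("/")
    for page in paginator.paginate(Bucket=bucket, Delimiter="/")
    for prefix in page.get("CommonPrefixes", [])
    if prefix["Prefix"].endswith(".zarr/")
]
print(f"Found {len(zarr_datasets)} Zarr datasets")

# Read back the consolidated CSV and inspect the new per-unit columns
templates_df = pd.read_csv(f"s3://{bucket}/templates.csv", storage_options={"anon": True})
print(templates_df[["probe", "amplitude_uv", "noise_level_uv", "signal_to_noise_ratio"]].head())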
diff --git a/python/how_to_calculate_templates_from_dandisets.py b/python/how_to_calculate_templates_from_dandisets.py
deleted file mode 100644
index e40ddfe..0000000
--- a/python/how_to_calculate_templates_from_dandisets.py
+++ /dev/null
@@ -1,208 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# # How to estimate templates from Dandisets
-# The purpose of this draft notebook is to showcase how templates can be calculated by means of the `SortingAnalyzer` object.
-#
-
-# In[1]:
-
-
-from dandi.dandiapi import DandiAPIClient
-from spikeinterface.extractors import NwbRecordingExtractor, IblSortingExtractor
-
-client = DandiAPIClient.for_dandi_instance("dandi")
-
-# We specifiy a dataset by is dandiset_id and its asset path
-dandiset_id = "000409"
-dandiset = client.get_dandiset(dandiset_id)
-
-asset_path = "sub-KS042/sub-KS042_ses-8c552ddc-813e-4035-81cc-3971b57efe65_behavior+ecephys+image.nwb"
-recording_asset = dandiset.get_asset_by_path(path=asset_path)
-url = recording_asset.get_content_url(follow_redirects=True, strip_query=True)
-file_path = url
-
-
-# Note that this ElectricalSeries corresponds to the data from probe 00
-electrical_series_path = "acquisition/ElectricalSeriesAp00"
-recording = NwbRecordingExtractor(file_path=file_path, stream_mode="remfile", electrical_series_path=electrical_series_path)
-session_id = recording._file["general"]["session_id"][()].decode()
-eid = session_id.split("-chunking")[0]  # eid : experiment id
-
-
-# We use the sorting extractor from the IBL spike sorting pipeline that matches with eid
-from one.api import ONE
-ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True)
-one_instance = ONE(password='international')
-
-
-# Then we match the available probes with the probe number in the electrical series path
-pids, probes = one_instance.eid2pid(eid)
-probe_number = electrical_series_path.split("Ap")[-1]
-
-sorting_pid = None
-for pid, probe in zip(pids, probes):
-    probe_number_in_pid = probe[-2:]
-    if probe_number_in_pid == probe_number:
-        sorting_pid = pid
-        break
-
-
-sorting = IblSortingExtractor(pid=sorting_pid, one=one_instance, good_clusters_only=True)
-
-
-# We now have our sorting and recording objects. We perform some preprocessing on our recording and slice ouf objects so we only estimate templates from the last minutes of the data.
-
-# In[2]:
-
-
-from spikeinterface.preprocessing import astype, phase_shift, common_reference, highpass_filter
-
-pre_processed_recording = common_reference(
-    highpass_filter(phase_shift(astype(recording=recording, dtype="float32")), freq_min=1.0)
-)
-
-
-# take first and last minute
-sampling_frequency_recording = pre_processed_recording.sampling_frequency
-sorting_sampling_frequency = sorting.sampling_frequency
-num_samples = pre_processed_recording.get_num_samples()
-
-# Take the last 10 minutes of the recording
-minutes = 10
-seconds = minutes * 60
-samples_before_end = int(seconds * sampling_frequency_recording)
-
-start_frame_recording = num_samples - samples_before_end
-end_frame_recording = num_samples
-
-recording_end = pre_processed_recording.frame_slice(
-    start_frame=start_frame_recording,
-    end_frame=end_frame_recording
-)
-
-# num_samples = sorting.get_num_frames()
-samples_before_end = int(seconds * sorting_sampling_frequency)
-start_frame_sorting = num_samples - samples_before_end
-end_frame_sorting = num_samples
-
-sorting_end = sorting.frame_slice(
-    start_frame=start_frame_sorting,
-    end_frame=end_frame_sorting
-)
-
-
-# We now use the `SortingAnalyzer` object to estimate templates.
-
-# In[3]:
-
-
-from spikeinterface.core import create_sorting_analyzer
-
-analyzer = create_sorting_analyzer(sorting_end, recording_end, sparse=False, folder=f"analyzer_{eid}")
-
-
-random_spike_parameters = {
-    "method": "all",
-}
-
-
-template_extension_parameters = {
-    "ms_before": 3.0,
-    "ms_after": 5.0,
-    "operators": ["average", "std"],
-}
-
-extensions = {
-    "random_spikes": random_spike_parameters,
-    "templates": template_extension_parameters,
-}
-
-analyzer.compute_several_extensions(
-    extensions=extensions,
-    n_jobs=-1,
-    progress_bar=True,
-)
-
-
-# In[4]:
-
-
-templates_extension = analyzer.get_extension("templates")
-template_object = templates_extension.get_data(outputs="Templates")
-
-
-# That's it. We now have our data in a templates object (note the outputs keyword on `get_data`). As a visual test that the pipeline works we show how the best chanenl (defined as the one with the maximum peak to peak amplitude) and plot some unit's templates for that channel.
-
-# In[5]:
-
-
-import numpy as np
-
-
-
-def find_channels_with_max_peak_to_peak_vectorized(templates_array):
-    """
-    Find the channel indices with the maximum peak-to-peak value in each waveform template
-    using a vectorized operation for improved performance.
-
-    Parameters:
-    templates_array (numpy.ndarray): The waveform templates_array, typically a 3D array (units x time x channels).
-
-    Returns:
-    numpy.ndarray: An array of indices of the channel with the maximum peak-to-peak value for each unit.
-    """
-    # Compute the peak-to-peak values along the time axis (axis=1) for each channel of each unit
-    peak_to_peak_values = np.ptp(templates_array, axis=1)
-
-    # Find the indices of the channel with the maximum peak-to-peak value for each unit
-    best_channels = np.argmax(peak_to_peak_values, axis=1)
-
-    return best_channels
-
-
-
-
-# In[6]:
-
-
-import matplotlib.pyplot as plt
-
-# Adjust global font size
-plt.rcParams.update({"font.size": 18})
-
-unit_ids = template_object.unit_ids
-best_channels = find_channels_with_max_peak_to_peak_vectorized(template_object.templates_array)
-
-
-num_columns = 3
-num_rows = 3
-
-fig, axs = plt.subplots(num_rows, num_columns, figsize=(15, 20), sharex=True, sharey=True)
-
-center = template_object.nbefore
-
-for unit_index, unit_id in enumerate(unit_ids[: num_columns * num_rows]):
-    row, col = divmod(unit_index, num_columns)
-    ax = axs[row, col]
-    best_channel = best_channels[unit_index]
-
-    ax.plot(template_object.templates_array[unit_index, :, best_channel], linewidth=3, label="best channel", color="black")
-
-    ax.axvline(center, linestyle="--", color="red", linewidth=1)
-    ax.axhline(0, linestyle="--", color="gray", linewidth=1)
-    ax.set_title(f"Unit {unit_id}")
-
-    # Hide all spines and ticks
-    ax.tick_params(axis="both", which="both", length=0)
-
-    # Hide all spines
-    for spine in ax.spines.values():
-        spine.set_visible(False)
-
-plt.tight_layout()
-
-# Create the legend with specified colors
-handles, labels = axs[0, 0].get_legend_handles_labels()
-fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.02), ncol=3, frameon=False)
-
diff --git a/python/upload_templates.py b/python/upload_templates.py
index 1bb9c60..3fca8f7 100644
--- a/python/upload_templates.py
+++ b/python/upload_templates.py
@@ -1,7 +1,8 @@
+from pathlib import Path
+
 import numpy as np
 
 from dandi.dandiapi import DandiAPIClient
-
 from spikeinterface.extractors import (
     NwbRecordingExtractor,
     IblSortingExtractor,
@@ -21,6 +22,7 @@
 from one.api import ONE
 
 import os
+import time
 
 
 def find_channels_with_max_peak_to_peak_vectorized(templates):
@@ -43,41 +45,49 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
     return best_channels
 
 
-client = DandiAPIClient.for_dandi_instance("dandi")
-
-dandiset_id = "000409"
-dandiset = client.get_dandiset(dandiset_id)
+# AWS credentials
+aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
+aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+bucket_name = "spikeinterface-template-database"
+client_kwargs = {"region_name": "us-east-2"}
 
-valid_dandiset_path = lambda path: path.endswith(".nwb") and "ecephys" in path
-dandiset_paths_with_ecephys = [
-    asset.path for asset in dandiset.get_assets() if valid_dandiset_path(asset.path)
-]
-dandiset_paths_with_ecephys.sort()
-dandiset_paths_with_ecephys = [
-    path for path in dandiset_paths_with_ecephys if "KS" in path
-]
+# Parameters
+minutes_by_the_end = 30  # How many minutes in the end of the recording to use for templates
+upload_data = True
+verbose = True
+# Test data
+do_testing_data = False
+test_path = "sub-KS051/sub-KS051_ses-0a018f12-ee06-4b11-97aa-bbbff5448e9f_behavior+ecephys+image.nwb"
 
 ONE.setup(base_url="https://openalyx.internationalbrainlab.org", silent=True)
 one_instance = ONE(password="international")
 
-for asset_path in dandiset_paths_with_ecephys[1:]:
-    # asset_path = "sub-KS051/sub-KS051_ses-0a018f12-ee06-4b11-97aa-bbbff5448e9f_behavior+ecephys+image.nwb"
-    print("-------------------")
-    print(asset_path)
+client = DandiAPIClient.for_dandi_instance("dandi")
+dandiset_id = "000409"
+dandiset = client.get_dandiset(dandiset_id)
+
+has_ecephy_data = lambda path: path.endswith(".nwb") and "ecephys" in path
+dandiset_paths = [asset.path for asset in dandiset.get_assets() if has_ecephy_data(asset.path)]
+dandiset_paths.sort()
+dandiset_paths = [path for path in dandiset_paths if "KS" in path]
+
+if do_testing_data:
+    dandiset_paths = [test_path]
+
+for asset_path in dandiset_paths[:]:
+    if verbose:
+        print("-------------------")
+        print(asset_path)
     recording_asset = dandiset.get_asset_by_path(path=asset_path)
     url = recording_asset.get_content_url(follow_redirects=True, strip_query=True)
     file_path = url
 
-    electrical_series_paths = (
-        NwbRecordingExtractor.fetch_available_electrical_series_paths(
-            file_path=file_path, stream_mode="remfile"
-        )
+    electrical_series_paths = NwbRecordingExtractor.fetch_available_electrical_series_paths(
+        file_path=file_path, stream_mode="remfile"
     )
-    electrical_series_paths_ap = [
-        path for path in electrical_series_paths if "Ap" in path.split("/")[-1]
-    ]
+    electrical_series_paths_ap = [path for path in electrical_series_paths if "Ap" in path.split("/")[-1]]
     for electrical_series_path in electrical_series_paths_ap:
         print(electrical_series_path)
@@ -86,11 +96,13 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
             stream_mode="remfile",
             electrical_series_path=electrical_series_path,
         )
+
         session_id = recording._file["general"]["session_id"][()].decode()
         eid = session_id.split("-chunking")[0]  # eid : experiment id
         pids, probes = one_instance.eid2pid(eid)
-        print("pids", pids)
-        print("probes", probes)
+        if verbose:
+            print("pids", pids)
+            print("probes", probes)
 
         if len(probes) > 1:
             probe_number = electrical_series_path.split("Ap")[-1]
@@ -123,44 +135,48 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
         sorting_sampling_frequency = sorting.sampling_frequency
         num_samples = recording.get_num_samples()
 
-        minutes = 60
-        samples_before_end = int(minutes * 60.0 * sampling_frequency_recording)
+        samples_before_end = int(minutes_by_the_end * 60.0 * sampling_frequency_recording)
         start_frame_recording = num_samples - samples_before_end
         end_frame_recording = num_samples
-        recording = recording.frame_slice(
-            start_frame=start_frame_recording, end_frame=end_frame_recording
-        )
-
-        samples_before_end = int(minutes * 60.0 * sorting_sampling_frequency)
+        recording = recording.frame_slice(start_frame=start_frame_recording, end_frame=end_frame_recording)
+        samples_before_end = int(minutes_by_the_end * 60.0 * sorting_sampling_frequency)
         start_frame_sorting = num_samples - samples_before_end
        end_frame_sorting = num_samples
-        sorting_end = sorting.frame_slice(
-            start_frame=start_frame_sorting, end_frame=end_frame_sorting
-        )
-
-        pre_processed_recording = common_reference(
-            highpass_filter(
-                phase_shift(astype(recording=recording, dtype="float32")), freq_min=1.0
-            )
-        )
-
-        folder_path = "./pre_processed_recording"
-        pre_processed_recording = pre_processed_recording.save_to_folder(
+        sorting_end = sorting.frame_slice(start_frame=start_frame_sorting, end_frame=end_frame_sorting)
+
+
+        # NWB streaming is not working well with parallel pre-processing, so we save a local copy first
+        folder_path = Path.cwd() / "local_copy"
+        folder_path.mkdir(exist_ok=True, parents=True)
+
+        if verbose:
+            print("Saving Recording")
+            print(recording)
+            start_time = time.time()
+
+        recording = recording.save_to_folder(
             folder=folder_path,
             overwrite=True,
-            n_jobs=1,
-            chunk_memory="1G",
+            n_jobs=6,
+            chunk_memory="1Gi",
             verbose=True,
             progress_bar=True,
         )
 
-        analyzer = create_sorting_analyzer(
-            sorting_end, pre_processed_recording, sparse=False, folder=f"analyzer_{eid}"
+        if verbose:
+            end_time = time.time()
+            execution_time = end_time - start_time
+            print(f"Execution time: {execution_time/60.0: 2.2f} minutes")
+
+        pre_processed_recording = common_reference(
+            highpass_filter(phase_shift(astype(recording=recording, dtype="float32")), freq_min=1.0)
        )
 
+        analyzer = create_sorting_analyzer(sorting_end, pre_processed_recording, sparse=False, folder=f"analyzer_{eid}")
+
         random_spike_parameters = {
             "method": "all",
         }
@@ -182,13 +198,20 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
             "noise_levels": noise_level_parameters,
         }
 
+        if verbose:
+            print("Computing extensions")
+            start_time = time.time()
         analyzer.compute_several_extensions(
             extensions=extensions,
-            n_jobs=3,
+            n_jobs=4,
             verbose=True,
             progress_bar=True,
-            chunk_memory="1G",
+            chunk_memory="500Mi",
         )
+        if verbose:
+            end_time = time.time()
+            execution_time = end_time - start_time
+            print(f"Execution time: {execution_time/60.0: 2.2f} minutes")
 
         noise_level_extension = analyzer.get_extension("noise_levels")
         noise_level_data = noise_level_extension.get_data()
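The hunk below relies on find_channels_with_max_peak_to_peak_vectorized to pick one channel per unit before writing best_channel_index to Zarr. A tiny self-contained check of that selection logic, on synthetic data only (not part of the diff), looks like this:

# Synthetic sanity check of the peak-to-peak best-channel selection (illustrative only).
import numpy as np

rng = np.random.default_rng(seed=0)
num_units, num_samples, num_channels = 4, 90, 16
templates_array = rng.normal(scale=0.5, size=(num_units, num_samples, num_channels))

# Plant a large deflection on a known channel for each unit
expected_best = np.array([3, 7, 11, 15])
for unit_index, channel_index in enumerate(expected_best):
    templates_array[unit_index, 40:50, channel_index] -= 50.0

peak_to_peak_values = np.ptp(templates_array, axis=1)   # shape (num_units, num_channels)
best_channel_index = np.argmax(peak_to_peak_values, axis=1)

assert np.array_equal(best_channel_index, expected_best)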
@@ -199,59 +222,60 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
         templates_extension = analyzer.get_extension("templates")
         templates_object = templates_extension.get_data(outputs="Templates")
         unit_ids = templates_object.unit_ids
-        best_channels = find_channels_with_max_peak_to_peak_vectorized(
-            templates_object.templates_array
-        )
+        best_channel_index = find_channels_with_max_peak_to_peak_vectorized(templates_object.templates_array)
 
         templates_object.probe.model_name = probe_info[0]["model_name"]
         templates_object.probe.manufacturer = probe_info[0]["manufacturer"]
         templates_object.probe.serial_number = probe_info[0]["serial_number"]
 
-        # AWS credentials
-        aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
-        aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
-
-        # Create a S3 file system object with explicit credentials
-        s3_kwargs = dict(
-            anon=False,
-            key=aws_access_key_id,
-            secret=aws_secret_access_key,
-            client_kwargs={"region_name": "us-east-2"},
-        )
-        s3 = s3fs.S3FileSystem(**s3_kwargs)
+        if do_testing_data:
+            dataset_name = "test_templates.zarr"
+        else:
+            path = asset_path.split("/")[-1]
+            dataset_name = f"{dandiset_id}_{path}_{sorting_pid}.zarr"
+
+        if verbose:
+            print("Saving data to Zarr")
+            print(f"{dataset_name=}")
+
+        if upload_data:
+            # Create a S3 file system object with explicit credentials
+            s3_kwargs = dict(
+                anon=False,
+                key=aws_access_key_id,
+                secret=aws_secret_access_key,
+                client_kwargs=client_kwargs
+            )
+            s3 = s3fs.S3FileSystem(**s3_kwargs)
 
-        # Specify the S3 bucket and path
-        path = asset_path.split("/")[-1]
-        dataset_name = f"{dandiset_id}_{path}_{sorting_pid}"
-        store = s3fs.S3Map(
-            root=f"spikeinterface-template-database/{dataset_name}.zarr", s3=s3
-        )
+            # Specify the S3 bucket and path
+            s3_path = f"{bucket_name}/{dataset_name}"
+            store = s3fs.S3Map(root=s3_path, s3=s3)
+        else:
+            folder_path = Path.cwd() / f"{dataset_name}"
+            store = zarr.DirectoryStore(str(folder_path))
 
+        # Save results to Zarr
         zarr_group = zarr.group(store=store, overwrite=True)
-
         brain_area = sorting_end.get_property("brain_area")
-        zarr_group.create_dataset(
-            name="brain_area", data=brain_area, object_codec=numcodecs.VLenUTF8()
-        )
+        zarr_group.create_dataset(name="brain_area", data=brain_area, object_codec=numcodecs.VLenUTF8())
         spikes_per_unit = sorting_end.count_num_spikes_per_unit(outputs="array")
+        zarr_group.create_dataset(name="spikes_per_unit", data=spikes_per_unit, chunks=None, dtype="uint32")
         zarr_group.create_dataset(
-            name="spikes_per_unit", data=spikes_per_unit, chunks=None, dtype="int32"
-        )
-        zarr_group.create_dataset(
-            name="best_channels", data=best_channels, chunks=None, dtype="int32"
-        )
-        peak_to_peak = peak_to_peak_values = np.ptp(
-            templates_extension_data.templates_array, axis=1
+            name="best_channel_index",
+            data=best_channel_index,
+            chunks=None,
+            dtype="uint32",
         )
+        peak_to_peak = np.ptp(templates_extension_data.templates_array, axis=1)
         zarr_group.create_dataset(name="peak_to_peak", data=peak_to_peak)
         zarr_group.create_dataset(
-            name="channe_noise_levels",
+            name="channel_noise_levels",
             data=noise_level_data,
             chunks=None,
             dtype="float32",
         )
 
-        # Now you can create a Zarr array using this setore
+        # Now you can create a Zarr array using this store
         templates_extension_data.add_templates_to_zarr_group(zarr_group=zarr_group)
         zarr_group_s3 = zarr_group
-
         zarr.consolidate_metadata(zarr_group_s3.store)
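Once a dataset has been written and its metadata consolidated by the script above, it can be opened the same way consolidate_datasets.py does. A hedged read-back sketch follows; the dataset name is a placeholder, not a real key, and anonymous access is assumed as elsewhere in this PR.

# Read back one uploaded dataset anonymously (dataset name is a placeholder).
import zarr
from spikeinterface.core import Templates

bucket = "spikeinterface-template-database"
dataset_name = "<dandiset-id>_<asset-name>_<sorting-pid>.zarr"  # see dataset_name in upload_templates.py

zarr_group = zarr.open_consolidated(f"s3://{bucket}/{dataset_name}", storage_options=dict(anon=True))
templates = Templates.from_zarr_group(zarr_group)

print(templates.num_units, "units")
print(zarr_group["best_channel_index"][:10])
print(zarr_group["probe"]["annotations"].attrs.asdict())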