
Commit

general upload script
h-mayorquin committed Jun 10, 2024
1 parent efe60ea commit 4495da3
Showing 3 changed files with 210 additions and 396 deletions.
202 changes: 100 additions & 102 deletions python/consolidate_datasets.py
@@ -9,134 +9,132 @@
parser = ArgumentParser(description="Consolidate datasets from spikeinterface template database")

parser.add_argument("--dry-run", action="store_true", help="Dry run (no upload)")
parser.add_argument("--no-skip-test", action="store_true", help="Skip test datasets")
parser.add_argument("--bucket", type=str, help="S3 bucket name", default="spikeinterface-template-database")


def list_bucket_objects(
bucket: str,
boto_client: boto3.client,
prefix: str = "",
include_substrings: str | list[str] | None = None,
skip_substrings: str | list[str] | None = None,
):
# get all objects for session from s3
paginator = boto_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Prefix=prefix, Bucket=bucket)
keys = []

if include_substrings is not None:
if isinstance(include_substrings, str):
include_substrings = [include_substrings]
if skip_substrings is not None:
if isinstance(skip_substrings, str):
skip_substrings = [skip_substrings]

for page in pages:
for item in page.get("Contents", []):
key = item["Key"]
if include_substrings is None and skip_substrings is None:
keys.append(key)
else:
if skip_substrings is not None:
if any([s in key for s in skip_substrings]):
continue
if include_substrings is not None:
if all([s in key for s in include_substrings]):
keys.append(key)
return keys


def consolidate_datasets(
dry_run: bool = False, skip_test_folder: bool = True, bucket="spikeinterface-template-database"
):
### Find datasets and create dataframe with consolidated data
bc = boto3.client("s3")

# Each dataset is stored in a zarr folder, so we look for the .zattrs files
skip_substrings = ["test_templates"] if skip_test_folder else None
keys = list_bucket_objects(bucket, boto_client=bc, include_substrings=".zattrs", skip_substrings=skip_substrings)
datasets = [k.split("/")[0] for k in keys]
print(f"Found {len(datasets)} datasets to consolidate\n")

templates_df = None

# Loop over datasets and extract relevant information
for dataset in datasets:

def list_zarr_directories(bucket_name, boto_client=None):
"""Lists top-level Zarr directory keys in an S3 bucket.
Parameters
----------
bucket_name : str
The name of the S3 bucket to search.
boto_client : boto3.client, optional
An existing Boto3 S3 client. If not provided, a new client will be created.
Returns
-------
zarr_directories : list
A list of strings representing the full S3 keys (paths) of top-level Zarr directories
found in the bucket.
"""

boto_client = boto_client or boto3.client('s3')
zarr_directories = set()

paginator = boto_client.get_paginator('list_objects_v2')
for page in paginator.paginate(Bucket=bucket_name, Delimiter='/'):
for prefix in page.get('CommonPrefixes', []):
key = prefix['Prefix']
if key.endswith('.zarr/'):
zarr_directories.add(key.rstrip('/'))

return list(zarr_directories)

def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
"""Consolidates data from Zarr datasets within an S3 bucket.
Parameters
----------
dry_run : bool, optional
If True, do not upload the consolidated data to S3. Defaults to False.
verbose : bool, optional
If True, print additional information during processing. Defaults to False.
Returns
-------
pandas.DataFrame
A DataFrame containing the consolidated data from all Zarr datasets.
Raises
------
FileNotFoundError
If no Zarr datasets are found in the specified bucket.
"""

bucket="spikeinterface-template-database"
boto_client = boto3.client("s3")

# Get list of Zarr directories, excluding test datasets
zarr_datasets = list_zarr_directories(bucket_name=bucket, boto_client=boto_client)
datasets_to_avoid = ["test_templates.zarr"]
zarr_datasets = [d for d in zarr_datasets if d not in datasets_to_avoid]

if not zarr_datasets:
raise FileNotFoundError(f"No Zarr datasets found in bucket: {bucket}")
if verbose:
print(f"Found {len(zarr_datasets)} datasets to consolidate\n")

# Initialize list to collect DataFrames for each dataset
all_dataframes = []

for dataset in zarr_datasets:
print(f"Processing dataset {dataset}")
zarr_path = f"s3://{bucket}/{dataset}"
zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True))

templates = Templates.from_zarr_group(zarr_group)

# Extract data efficiently using NumPy arrays
num_units = templates.num_units
dataset_list = [dataset] * num_units
dataset_path = [zarr_path] * num_units
template_idxs = np.arange(num_units)
best_channel_idxs = zarr_group.get("best_channels", None)
brain_areas = zarr_group.get("brain_area", None)
peak_to_peaks = zarr_group.get("peak_to_peak", None)
spikes_per_units = zarr_group.get("spikes_per_unit", None)

# TODO: get probe name from probe metadata

probe_attributes = zarr_group["probe"]["annotations"].attrs.asdict()
template_indices = np.arange(num_units)
default_brain_area = ["unknown"] * num_units
brain_areas = zarr_group.get("brain_area", default_brain_area)
channel_depths = templates.get_channel_locations()[:, 1]
spikes_per_unit = zarr_group["spikes_per_unit"][:]
best_channel_indices = zarr_group["best_channel_index"][:]

depths = np.zeros(num_units)
amps = np.zeros(num_units)

if best_channel_idxs is not None:
best_channel_idxs = best_channel_idxs[:]
for i, best_channel_idx in enumerate(best_channel_idxs):
depths[i] = channel_depths[best_channel_idx]
if peak_to_peaks is None:
amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
else:
amps[i] = peak_to_peaks[i, best_channel_idx]
else:
depths = np.nan
amps = np.nan
best_channel_idxs = [-1] * num_units
spikes_per_units = [-1] * num_units
if brain_areas is not None:
brain_areas = brain_areas[:]
else:
brain_areas = ["unknwown"] * num_units
depth_best_channel = channel_depths[best_channel_indices]
peak_to_peak_best_channel = zarr_group["peak_to_peak"][template_indices, best_channel_indices]
noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel

new_entry = pd.DataFrame(
data={
"dataset": dataset_list,
"dataset_path": dataset_path,
"probe": ["Neuropixels1.0"] * num_units,
"template_index": template_idxs,
"best_channel_id": best_channel_idxs,
{
"probe": [probe_attributes["model_name"]] * num_units,
"probe_manufacturer": [probe_attributes["manufacturer"]] * num_units,
"brain_area": brain_areas,
"depth_along_probe": depths,
"amplitude_uv": amps,
"spikes_per_unit": spikes_per_units,
"depth_along_probe": depth_best_channel,
"amplitude_uv": peak_to_peak_best_channel,
"noise_level_uv": noise_best_channel,
"signal_to_noise_ratio": signal_to_noise_ratio_best_channel,
"template_index": template_indices,
"best_channel_index": best_channel_indices,
"spikes_per_unit": spikes_per_unit,
"dataset": [dataset] * num_units,
"dataset_path": [zarr_path] * num_units,
}
)
if templates_df is None:
templates_df = new_entry
else:
templates_df = pd.concat([templates_df, new_entry])
print(f"Added {num_units} units from dataset {dataset}")

templates_df.reset_index(inplace=True, drop=True)
all_dataframes.append(new_entry)

# Concatenate all DataFrames into a single DataFrame
templates_df = pd.concat(all_dataframes, ignore_index=True)

templates_df.to_csv("templates.csv", index=False)

# Upload to S3
if not dry_run:
bc.upload_file("templates.csv", bucket, "templates.csv")
else:
boto_client.upload_file("templates.csv", bucket, "templates.csv")

if verbose:
print("Dry run, not uploading")
print(templates_df)

return templates_df


if __name__ == "__main__":
params = parser.parse_args()
DRY_RUN = params.dry_run
SKIP_TEST = not params.no_skip_test
templates_df = consolidate_datasets(dry_run=DRY_RUN, skip_test_folder=SKIP_TEST)
templates_df = consolidate_datasets(dry_run=DRY_RUN)
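
For context, here is a minimal, self-contained sketch of the listing technique the new list_zarr_directories helper relies on: paginating list_objects_v2 with Delimiter="/" and collecting CommonPrefixes that end in ".zarr/". The bucket name is the one hard-coded in the diff; unsigned (anonymous) access is an assumption and may need to be replaced with real credentials.

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def list_zarr_prefixes(bucket_name: str) -> list[str]:
    """Return the top-level .zarr "directories" in a bucket via CommonPrefixes."""
    # Unsigned access is an assumption; drop the config to use your own credentials.
    client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    paginator = client.get_paginator("list_objects_v2")
    prefixes = set()
    for page in paginator.paginate(Bucket=bucket_name, Delimiter="/"):
        for common_prefix in page.get("CommonPrefixes", []):
            key = common_prefix["Prefix"]  # e.g. "some_dataset.zarr/"
            if key.endswith(".zarr/"):
                prefixes.add(key.rstrip("/"))
    return sorted(prefixes)


if __name__ == "__main__":
    print(list_zarr_prefixes("spikeinterface-template-database"))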
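
Downstream of this script, the consolidated index is just the templates.csv written by consolidate_datasets() above. A hedged sketch of how a consumer might query it, assuming the bucket allows public HTTPS reads at the standard S3 URL (an assumption, not confirmed by the diff); the "Neuropixels 1.0" probe value is purely illustrative, since the new code takes the model name from the probe metadata.

import pandas as pd

# Column names come from the DataFrame built in consolidate_datasets() above.
# The public HTTPS URL layout for the bucket is an assumption.
CSV_URL = "https://spikeinterface-template-database.s3.amazonaws.com/templates.csv"

templates_df = pd.read_csv(CSV_URL)

# Example query: high signal-to-noise units from a given probe model
# ("Neuropixels 1.0" is an illustrative value, not taken from the diff).
high_snr = templates_df[
    (templates_df["signal_to_noise_ratio"] > 5.0)
    & (templates_df["probe"] == "Neuropixels 1.0")
]
print(high_snr[["dataset", "template_index", "brain_area", "amplitude_uv"]].head())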