diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..f953a8e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,17 @@
+[tool.black]
+line-length = 128
+target-version = ['py37', 'py38', 'py39']
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
\ No newline at end of file
diff --git a/python/__init__.py b/python/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
index ed2e94c..0710efb 100644
--- a/python/consolidate_datasets.py
+++ b/python/consolidate_datasets.py
@@ -13,7 +13,7 @@


-def list_zarr_directories(bucket_name, boto_client=None):
+def list_zarr_directories(bucket_name, boto_client=None) -> list[str]:
     """Lists top-level Zarr directory keys in an S3 bucket.

     Parameters
diff --git a/python/upload_templates.py b/python/upload_templates.py
index 3fca8f7..82529be 100644
--- a/python/upload_templates.py
+++ b/python/upload_templates.py
@@ -21,8 +21,10 @@
 import numcodecs
 from one.api import ONE

-import os
 import time
+import os
+
+from consolidate_datasets import list_zarr_directories


 def find_channels_with_max_peak_to_peak_vectorized(templates):
@@ -49,11 +51,12 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
 aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
 aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
 bucket_name = "spikeinterface-template-database"
-client_kwargs={"region_name": "us-east-2"}
+client_kwargs = {"region_name": "us-east-2"}

 # Parameters
 minutes_by_the_end = 30  # How many minutes in the end of the recording to use for templates
 upload_data = True
+overwrite = False
 verbose = True

 # Test data
@@ -75,9 +78,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
 if do_testing_data:
     dandiset_paths = [test_path]

-for asset_path in dandiset_paths[:]:
+# Load already processed datasets
+zarr_datasets = list_zarr_directories(bucket_name=bucket_name)
+if verbose:
+    print(f"Found {len(zarr_datasets)} datasets already processed")
+
+dandiset_paths = np.random.choice(dandiset_paths, size=len(dandiset_paths), replace=False)
+for asset_path in dandiset_paths:
     if verbose:
-        print("-------------------")
+        print("----------------------------------------------------------")
+        print("----------------------------------------------------------")
+        print("----------------------------------------------------------")
+
         print(asset_path)

     recording_asset = dandiset.get_asset_by_path(path=asset_path)
@@ -89,14 +101,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
     )
     electrical_series_paths_ap = [path for path in electrical_series_paths if "Ap" in path.split("/")[-1]]
     for electrical_series_path in electrical_series_paths_ap:
-        print(electrical_series_path)
+        print(f"{electrical_series_path=}")
         recording = NwbRecordingExtractor(
             file_path=file_path,
             stream_mode="remfile",
             electrical_series_path=electrical_series_path,
         )
-
+
+        if verbose:
+            print("Recording")
+            print(recording)
+
         session_id = recording._file["general"]["session_id"][()].decode()
         eid = session_id.split("-chunking")[0]  # eid : experiment id
         pids, probes = one_instance.eid2pid(eid)
@@ -116,6 +132,17 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
             sorting_pid = pids[0]
             probe_number = "00"

+        if do_testing_data:
+            dataset_name = "test_templates.zarr"
+        else:
+            dandi_name = asset_path.split("/")[-1].split(".")[0]
+            dataset_name = f"{dandiset_id}_{dandi_name}_{sorting_pid}.zarr"
f"{dandiset_id}_{dandi_name}_{sorting_pid}.zarr" + + if dataset_name in zarr_datasets and not overwite: + if verbose: + print(f"Dataset {dataset_name} already processed, skipping") + continue + sorting = IblSortingExtractor( pid=sorting_pid, one=one_instance, @@ -146,21 +173,20 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): end_frame_sorting = num_samples sorting_end = sorting.frame_slice(start_frame=start_frame_sorting, end_frame=end_frame_sorting) - # NWB Streaming is not working well with parallel pre=processing so we ave - folder_path = Path.cwd() / "local_copy" + folder_path = Path.cwd() / "build" / "local_copy" folder_path.mkdir(exist_ok=True, parents=True) - + if verbose: print("Saving Recording") print(recording) start_time = time.time() - + recording = recording.save_to_folder( folder=folder_path, overwrite=True, - n_jobs=6, + n_jobs=8, chunk_memory="1Gi", verbose=True, progress_bar=True, @@ -201,13 +227,15 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): if verbose: print("Computing extensions") start_time = time.time() + analyzer.compute_several_extensions( extensions=extensions, - n_jobs=4, + n_jobs=8, verbose=True, progress_bar=True, - chunk_memory="500Mi", + chunk_memory="250Mi", ) + if verbose: end_time = time.time() execution_time = end_time - start_time @@ -228,31 +256,21 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): templates_object.probe.manufacturer = probe_info[0]["manufacturer"] templates_object.probe.serial_number = probe_info[0]["serial_number"] - if do_testing_data: - dataset_name = "test_templates.zarr" - else: - path = asset_path.split("/")[-1] - dataset_name = f"{dandiset_id}_{path}_{sorting_pid}.zarr" - if verbose: print("Saving data to Zarr") print(f"{dataset_name=}") - + if upload_data: # Create a S3 file system object with explicit credentials - s3_kwargs = dict( - anon=False, - key=aws_access_key_id, - secret=aws_secret_access_key, - client_kwargs=client_kwargs - ) + s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs) s3 = s3fs.S3FileSystem(**s3_kwargs) # Specify the S3 bucket and path s3_path = f"{bucket_name}/{dataset_name}" store = s3fs.S3Map(root=s3_path, s3=s3) else: - folder_path = Path.cwd() / f"{dataset_name}" + folder_path = Path.cwd() / "build" / f"{dataset_name}" + folder_path.mkdir(exist_ok=True, parents=True) store = zarr.DirectoryStore(str(folder_path)) # Save results to Zarr