Commit d407a63 (1 parent: cf08b90): showing 3 changed files with 479 additions and 0 deletions.
@@ -0,0 +1,189 @@
import os
import logging

from base64 import b64decode
from json import loads
from os.path import exists, join, isfile
from re import match
from typing import List

from google.cloud.storage import Client
from google.auth.exceptions import DefaultCredentialsError

logger = logging.getLogger(__name__)
PREFIX = ""

try:
    # Attempt to infer credentials from the environment (application default credentials)
    storage_client = Client()
    logger.info("Inferred credentials from environment")
except DefaultCredentialsError:
    try:
        # Attempt to parse JSON credentials stored directly in GOOGLE_APPLICATION_CREDENTIALS
        storage_client = Client.from_service_account_info(
            loads(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
        )
        logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS")
    except ValueError:
        # Fall back to JSON credentials stored as a base64-encoded string
        storage_client = Client.from_service_account_info(
            loads(
                b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode("utf-8")
            )
        )
        logger.info("Loaded credentials from base64 encoded string")

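# Note: a rough illustration of the final fallback above (an assumed workflow, not
# part of this commit). A base64-encoded value for GOOGLE_APPLICATION_CREDENTIALS
# could be produced like this, where "service-account.json" is a hypothetical key file:
#
#     from base64 import b64encode
#     with open("service-account.json", "rb") as f:
#         os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = b64encode(f.read()).decode("utf-8")
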

def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
    """
    Downloads a file from GCS
    Args:
        gcs_path: string path to GCS file to download
        local_path: string path to download to
        bytes: boolean flag indicating if gcs file contains bytes
    Returns: None
    """
    write_mode = "wb" if bytes else "w"
    with open(local_path, write_mode) as o:
        storage_client.download_blob_to_file(gcs_path, o)


def sync_dir_to_gcs(
    local_dir: str, gcs_dir: str, verbose=False, recursive=False
) -> None:
    """
    Copies all files in a local directory to the gcs directory
    Args:
        local_dir: string local directory path to upload from
        gcs_dir: string GCS destination path. Will create folders that do not exist.
        verbose: boolean flag to print logging statements
        recursive: boolean flag to recursively upload files in subdirectories
    Returns: None
    """
    if not is_gcs_path(gcs_dir):
        raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")

    if not exists(local_dir):
        raise FileNotFoundError(f"local_dir does not exist: {local_dir}")

    local_files = os.listdir(local_dir)
    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
    bucket = storage_client.bucket(bucket_name)

    for filename in local_files:
        gcs_object_name = join(gcs_object_prefix, filename)
        local_file = join(local_dir, filename)
        if recursive and not isfile(local_file):
            sync_dir_to_gcs(
                local_file,
                f"gs://{join(bucket_name, gcs_object_name)}",
                verbose=verbose,
                recursive=recursive,
            )
        elif not isfile(local_file):
            pass
        else:
            blob = bucket.blob(gcs_object_name)
            if verbose:
                print(
                    f"Uploading {local_file} to gs://{join(bucket_name, gcs_object_name)}"
                )
            blob.upload_from_filename(local_file)


def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
    """
    Copies all files in a local directory to the gcs directory,
    under a subfolder named after the local directory.
    Args:
        local_dir: string local directory path to upload from
        gcs_dir: string GCS destination path. Will create folders that do not exist.
    Returns: None
    """
    bucket_name = gcs_dir.split("//")[1].split("/")[0]
    gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
    local_prefix = local_dir.split("/")[-1]
    bucket = storage_client.bucket(bucket_name)
    for filename in os.listdir(local_dir):
        gcs_object_name = f"{gcs_object_prefix}/{local_prefix}/{filename}"
        local_file = join(local_dir, filename)
        blob = bucket.blob(gcs_object_name)
        blob.upload_from_filename(local_file)


def gcs_join(*args):
    """
    Joins path components into a single GCS URI, normalizing any gs:// prefixes.
    """
    args = [arg.replace("gs://", "").strip("/") for arg in args]
    return "gs://" + join(*args)


def split_gcs_uri(gcs_uri: str) -> tuple:
    """
    Splits a GCS bucket and object_name from a GCS URI
    Args:
        gcs_uri: string GCS URI following the format gs://$BUCKET_NAME/OBJECT_NAME
    Returns: bucket_name, object_name
    """
    matches = match("gs://(.*?)/(.*)", gcs_uri)
    if matches:
        return matches.groups()
    else:
        raise ValueError(
            f"{gcs_uri} does not match expected format: gs://BUCKET_NAME/OBJECT_NAME"
        )


def is_gcs_path(gcs_path: str) -> bool:
    """
    Returns True if the string passed starts with gs://
    Args:
        gcs_path: string path to check
    Returns: Boolean flag indicating the gcs_path starts with gs://
    """
    return gcs_path.startswith("gs://")

def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
    """
    Returns list of filenames inside a directory.
    """
    # currently only Niftidataset receives gs bucket paths, so this isn't necessary
    # commenting out for now even though it is functional (lots of files)
    if is_gcs_path(files_dir):
        bucket_name, object_name = split_gcs_uri(files_dir)
        blob_iterator = storage_client.list_blobs(bucket_name, prefix=object_name)
        # return [str(blob) for blob in blob_iterator if "/" not in blob.name]
        return [str(blob) for blob in blob_iterator]
    else:
        dir_contents = os.listdir(files_dir)
        files = []
        for entry in dir_contents:
            entry_path = join(files_dir, entry)
            if isfile(entry_path):
                files.append(entry_path)
            elif recursive:
                files.extend(get_filename_in_dir(entry_path, recursive=recursive))
            else:
                pass
        return files

def download_rename_inputs(filepath: str, temp_dir: str) -> str:
    """
    Download file from gcs to local dir
    Args:
        filepath: GCS URI following the format gs://$BUCKET_NAME/OBJECT_NAME
        temp_dir: local directory to download the file into
    Returns: new filepath in the local machine
    """
    _, filename = split_gcs_uri(filepath)
    if "/" in filename:
        filename = filename.split("/")[-1]
    download_from_gcs(filepath, f"{temp_dir}/{filename}")
    return f"{temp_dir}/{filename}"
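
Below is a minimal usage sketch of the helpers above; it is not part of the commit, "example-bucket", "data/input.h5", and "/tmp" are hypothetical names, and the functions are assumed to be in scope:

uri = gcs_join("gs://example-bucket", "data", "input.h5")  # -> "gs://example-bucket/data/input.h5"
bucket_name, object_name = split_gcs_uri(uri)              # -> ("example-bucket", "data/input.h5")
local_copy = download_rename_inputs(uri, "/tmp")           # downloads to /tmp/input.h5
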
@@ -0,0 +1,89 @@
import h5py
import numpy as np

def collect_h5(file_name, out_dir, num_procs):
    # count variants
    num_variants = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
        job_h5_open = h5py.File(job_h5_file, "r")
        num_variants += len(job_h5_open["snp"])
        job_h5_open.close()

    # initialize final h5
    final_h5_file = "%s/%s" % (out_dir, file_name)
    final_h5_open = h5py.File(final_h5_file, "w")

    # keep dict for string values
    final_strings = {}

    job0_h5_file = "%s/job0/%s" % (out_dir, file_name)
    job0_h5_open = h5py.File(job0_h5_file, "r")
    for key in job0_h5_open.keys():
        if key in ["percentiles", "target_ids", "target_labels"]:
            # copy
            final_h5_open.create_dataset(key, data=job0_h5_open[key])

        elif key[-4:] == "_pct":
            values = np.zeros(job0_h5_open[key].shape)
            final_h5_open.create_dataset(key, data=values)

        elif job0_h5_open[key].dtype.char == "S":
            final_strings[key] = []

        elif job0_h5_open[key].ndim == 1:
            final_h5_open.create_dataset(
                key, shape=(num_variants,), dtype=job0_h5_open[key].dtype
            )

        else:
            num_targets = job0_h5_open[key].shape[1]
            final_h5_open.create_dataset(
                key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype
            )

    job0_h5_open.close()

    # set values
    vi = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
        job_h5_open = h5py.File(job_h5_file, "r")

        # append to final
        for key in job_h5_open.keys():
            if key in ["percentiles", "target_ids", "target_labels"]:
                # once is enough
                pass

            elif key[-4:] == "_pct":
                # average
                u_k1 = np.array(final_h5_open[key])
                x_k = np.array(job_h5_open[key])
                final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi + 1)

            else:
                if job_h5_open[key].dtype.char == "S":
                    final_strings[key] += list(job_h5_open[key])
                else:
                    job_variants = job_h5_open[key].shape[0]
                    try:
                        final_h5_open[key][vi : vi + job_variants] = job_h5_open[key]
                    except TypeError as e:
                        print(e)
                        print(
                            f"{job_h5_file} {key} has the wrong shape. Remove this file and rerun"
                        )
                        exit()

        vi += job_variants
        job_h5_open.close()

    # create final string datasets
    for key in final_strings:
        final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

    final_h5_open.close()
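
Below is a minimal usage sketch of collect_h5; it is not part of the commit, "scores.h5" and "out_dir" are hypothetical names, and it assumes each worker wrote its output to out_dir/job<i>/scores.h5:

collect_h5("scores.h5", "out_dir", num_procs=8)  # merges the per-job files into out_dir/scores.h5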