add-gcs
lruizcalico committed Sep 15, 2023
1 parent cf08b90 commit d407a63
Showing 3 changed files with 479 additions and 0 deletions.
189 changes: 189 additions & 0 deletions src/baskerville/helpers/gcs_utils.py
@@ -0,0 +1,189 @@
import os
import logging

from base64 import b64decode
from json import loads
from os.path import exists, join, isfile
from re import match
from typing import List

from google.cloud.storage import Client
from google.auth.exceptions import DefaultCredentialsError

logger = logging.getLogger(__name__)
PREFIX = ""

try:
    # Attempt to infer credentials from the environment
    storage_client = Client()
    logger.info("Inferred credentials from environment")
except DefaultCredentialsError:
    try:
        # Attempt to load a JSON key file pointed to by GOOGLE_APPLICATION_CREDENTIALS
        storage_client = Client.from_service_account_json(
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
        )
        logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS")
    except (OSError, ValueError):
        # Fall back to base64-encoded JSON credentials in the same variable
        storage_client = Client.from_service_account_info(
            loads(
                b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode("utf-8")
            )
        )
        logger.info("Loaded credentials from base64-encoded string")


def download_from_gcs(gcs_path: str, local_path: str, bytes: bool = True) -> None:
    """
    Downloads a file from GCS
    Args:
        gcs_path: string path to GCS file to download
        local_path: string path to download to
        bytes: boolean flag indicating whether to write the file in binary mode
    Returns: None
    """
    write_mode = "wb" if bytes else "w"
with open(local_path, write_mode) as o:
storage_client.download_blob_to_file(gcs_path, o)


def sync_dir_to_gcs(
local_dir: str, gcs_dir: str, verbose=False, recursive=False
) -> None:
"""
Copies all files in a local directory to the gcs directory
Args:
local_dir: string local directory path to upload from
gcs_dir: string GCS destination path. Will create folders that do not exist.
verbose: boolean flag to print logging statements
recursive: boolean flag to recursively upload files in subdirectories
Returns: None
"""
if not is_gcs_path(gcs_dir):
raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")

if not exists(local_dir):
raise FileNotFoundError(f"local_dir does not exist: {local_dir}")

local_files = os.listdir(local_dir)
bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
bucket = storage_client.bucket(bucket_name)

for filename in local_files:
gcs_object_name = join(gcs_object_prefix, filename)
local_file = join(local_dir, filename)
if recursive and not isfile(local_file):
sync_dir_to_gcs(
local_file,
f"gs://{join(bucket_name, gcs_object_name)}",
verbose=verbose,
recursive=recursive,
)
elif not isfile(local_file):
pass
else:
blob = bucket.blob(gcs_object_name)
if verbose:
print(
f"Uploading {local_file} to gs://{join(bucket_name, gcs_object_name)}"
)
blob.upload_from_filename(local_file)


def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
    """
    Copies all files in a local directory into a same-named folder under the GCS directory
    Args:
        local_dir: string local directory path to upload from
        gcs_dir: string GCS destination path. Will create folders that do not exist.
    Returns: None
    """
    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
    local_prefix = local_dir.split("/")[-1]
    bucket = storage_client.bucket(bucket_name)
for filename in os.listdir(local_dir):
gcs_object_name = f"{gcs_object_prefix}/{local_prefix}/{filename}"
local_file = join(local_dir, filename)
blob = bucket.blob(gcs_object_name)
blob.upload_from_filename(local_file)


def gcs_join(*args) -> str:
    """
    Joins path components into a single GCS URI, e.g.
    gcs_join("gs://bucket", "a", "b") -> "gs://bucket/a/b"
    Args:
        args: path components; gs:// prefixes and surrounding slashes are normalized
    Returns: joined GCS URI
    """
    args = [arg.replace("gs://", "").strip("/") for arg in args]
    return "gs://" + join(*args)


def split_gcs_uri(gcs_uri: str) -> tuple:
"""
Splits a GCS bucket and object_name from a GCS URI
Args:
gcs_uri: string GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME
Returns: bucket_name, object_name
"""
matches = match("gs://(.*?)/(.*)", gcs_uri)
if matches:
return matches.groups()
else:
raise ValueError(
f"{gcs_uri} does not match expected format: gs://BUCKET_NAME/OBJECT_NAME"
)


def is_gcs_path(gcs_path: str) -> bool:
"""
Returns True if the string passed starts with gs://
Args:
gcs_path: string path to check
Returns: Boolean flag indicating the gcs_path starts with gs://
"""
return gcs_path.startswith("gs://")


def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
    """
    Returns a list of filenames inside a local or GCS directory.
    Args:
        files_dir: string path to a local directory or a gs:// directory
        recursive: boolean flag to include files in subdirectories (local only)
    Returns: list of file paths
    """
    if is_gcs_path(files_dir):
        bucket_name, object_name = split_gcs_uri(files_dir)
        blob_iterator = storage_client.list_blobs(bucket_name, prefix=object_name)
        # currently only Niftidataset receives GCS paths, so the top-level filter
        # below stays disabled even though it is functional (lots of files)
        # return [str(blob) for blob in blob_iterator if "/" not in blob.name]
        return [str(blob) for blob in blob_iterator]
else:
dir_contents = os.listdir(files_dir)
files = []
for entry in dir_contents:
entry_path = join(files_dir, entry)
if isfile(entry_path):
files.append(entry_path)
elif recursive:
files.extend(get_filename_in_dir(entry_path, recursive=recursive))
else:
pass
return files


def download_rename_inputs(filepath: str, temp_dir: str) -> str:
    """
    Downloads a file from GCS to a local directory
    Args:
        filepath: GCS URI following the format gs://$BUCKET_NAME/OBJECT_NAME
        temp_dir: string local directory to download the file into
    Returns: new filepath on the local machine
    """
    _, filename = split_gcs_uri(filepath)
    if "/" in filename:
        filename = filename.split("/")[-1]
    download_from_gcs(filepath, f"{temp_dir}/{filename}")
    return f"{temp_dir}/{filename}"
89 changes: 89 additions & 0 deletions src/baskerville/helpers/h5_utils.py
@@ -0,0 +1,89 @@
import sys

import h5py
import numpy as np


def collect_h5(file_name, out_dir, num_procs):
    """
    Merges the per-job HDF5 files out_dir/job*/file_name into a single
    out_dir/file_name: shared datasets are copied once, *_pct datasets are
    averaged across jobs, and variant-indexed datasets are concatenated.
    """
    # count variants
num_variants = 0
for pi in range(num_procs):
# open job
job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
job_h5_open = h5py.File(job_h5_file, "r")
num_variants += len(job_h5_open["snp"])
job_h5_open.close()

# initialize final h5
final_h5_file = "%s/%s" % (out_dir, file_name)
final_h5_open = h5py.File(final_h5_file, "w")

# keep dict for string values
final_strings = {}

job0_h5_file = "%s/job0/%s" % (out_dir, file_name)
job0_h5_open = h5py.File(job0_h5_file, "r")
for key in job0_h5_open.keys():
if key in ["percentiles", "target_ids", "target_labels"]:
# copy
final_h5_open.create_dataset(key, data=job0_h5_open[key])

elif key[-4:] == "_pct":
values = np.zeros(job0_h5_open[key].shape)
final_h5_open.create_dataset(key, data=values)

elif job0_h5_open[key].dtype.char == "S":
final_strings[key] = []

elif job0_h5_open[key].ndim == 1:
final_h5_open.create_dataset(
key, shape=(num_variants,), dtype=job0_h5_open[key].dtype
)

else:
num_targets = job0_h5_open[key].shape[1]
final_h5_open.create_dataset(
key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype
)

job0_h5_open.close()

# set values
vi = 0
for pi in range(num_procs):
# open job
job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
job_h5_open = h5py.File(job_h5_file, "r")

# append to final
for key in job_h5_open.keys():
if key in ["percentiles", "target_ids", "target_labels"]:
# once is enough
pass

            elif key[-4:] == "_pct":
                # running mean across jobs: u_k = u_{k-1} + (x_k - u_{k-1}) / k
                u_k1 = np.array(final_h5_open[key])
                x_k = np.array(job_h5_open[key])
                final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi + 1)

else:
if job_h5_open[key].dtype.char == "S":
final_strings[key] += list(job_h5_open[key])
else:
job_variants = job_h5_open[key].shape[0]
                    try:
                        final_h5_open[key][vi : vi + job_variants] = job_h5_open[key]
                    except TypeError as e:
                        print(e)
                        print(
                            f"{job_h5_file} {key} has the wrong shape. Remove this file and rerun"
                        )
                        sys.exit(1)

vi += job_variants
job_h5_open.close()

# create final string datasets
for key in final_strings:
final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

final_h5_open.close()
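
A minimal usage sketch (hypothetical file name and job count; assumes each worker wrote out_dir/jobN/file_name with identical keys):

from baskerville.helpers.h5_utils import collect_h5

# merge results/job0/sad.h5 ... results/job7/sad.h5 into results/sad.h5
collect_h5("sad.h5", "results", num_procs=8)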