Skip to content

Commit

Permalink
add new seqnn + fix gcs util
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Oct 26, 2023
1 parent e84e220 commit 1e30e64
Show file tree
Hide file tree
Showing 5 changed files with 531 additions and 30 deletions.
58 changes: 51 additions & 7 deletions src/baskerville/helpers/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,43 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
storage_client.download_blob_to_file(gcs_path, o)


def download_folder_from_gcs(gcs_dir: str, local_dir: str, bytes=True) -> None:
"""
Downloads a whole folder from GCS
Args:
gcs_dir: string path to GCS folder to download
local_dir: string path to download to
bytes: boolean flag indicating if gcs file contains bytes
Returns: None
"""
storage_client = _get_storage_client()
write_mode = "wb" if bytes else "w"
if not is_gcs_path(gcs_dir):
raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")
bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
# Get the bucket from the client.
bucket = storage_client.bucket(bucket_name)

# Ensure local folder exists
if not os.path.exists(local_dir):
os.makedirs(local_dir)
# List all blobs with the given prefix (i.e., folder path).
blobs = bucket.list_blobs(prefix=gcs_object_prefix)
# Download each blob.
for blob in blobs:
# Compute the full path to which we'll download the blob.
blob_rel_path = os.path.relpath(blob.name, gcs_object_prefix)
local_blob_path = os.path.join(local_dir, blob_rel_path)

# Ensure the local directory structure exists
local_blob_dir = os.path.dirname(local_blob_path)
if not os.path.exists(local_blob_dir):
os.makedirs(local_blob_dir)
download_from_gcs(join(gcs_dir, blob_rel_path), local_blob_path, bytes=bytes)


def sync_dir_to_gcs(
local_dir: str, gcs_dir: str, verbose=False, recursive=False
) -> None:
Expand Down Expand Up @@ -120,7 +157,7 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
"""
storage_client = _get_storage_client()
bucket_name = gcs_dir.split("//")[1].split("/")[0]
gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
gcs_object_prefix = "/".join(gcs_dir.split("//")[1].split("/")[1:])
local_prefix = local_dir.split("/")[-1]
bucket = storage_client.bucket(bucket_name)
for filename in os.listdir(local_dir):
Expand Down Expand Up @@ -207,18 +244,25 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
return files


def download_rename_inputs(filepath: str, temp_dir: str) -> str:
def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -> str:
"""
Download file from gcs to local dir
Args:
filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME
temp_dir: local dir to download to
is_dir: boolean flag indicating if the filepath is a directory
Returns: new filepath in the local machine
"""
_, filename = split_gcs_uri(filepath)
if "/" in filename:
filename = filename.split("/")[-1]
download_from_gcs(filepath, f"{temp_dir}/{filename}")
return f"{temp_dir}/{filename}"
if is_dir:
download_folder_from_gcs(filepath, temp_dir)
dir_name = filepath.split("/")[-1]
return temp_dir
else:
_, filename = split_gcs_uri(filepath)
if "/" in filename:
filename = filename.split("/")[-1]
download_from_gcs(filepath, f"{temp_dir}/{filename}")
return f"{temp_dir}/{filename}"


def gcs_file_exist(gcs_path: str) -> bool:
Expand Down
59 changes: 59 additions & 0 deletions src/baskerville/helpers/h5_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,62 @@ def collect_h5(file_name, out_dir, num_procs):
final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

final_h5_open.close()


def collect_h5_borzoi(out_dir, num_procs, sad_stat) -> None:
h5_file = "scores.h5"

# count sequences
num_seqs = 0
for pi in range(num_procs):
# open job
job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
job_h5_open = h5py.File(job_h5_file, "r")
num_seqs += job_h5_open[sad_stat].shape[0]
seq_len = job_h5_open[sad_stat].shape[1]
num_targets = job_h5_open[sad_stat].shape[-1]
job_h5_open.close()

# initialize final h5
final_h5_file = "%s/%s" % (out_dir, h5_file)
final_h5_open = h5py.File(final_h5_file, "w")

# keep dict for string values
final_strings = {}

job0_h5_file = "%s/job0/%s" % (out_dir, h5_file)
job0_h5_open = h5py.File(job0_h5_file, "r")
for key in job0_h5_open.keys():
key_shape = list(job0_h5_open[key].shape)
key_shape[0] = num_seqs
key_shape = tuple(key_shape)
if job0_h5_open[key].dtype.char == "S":
final_strings[key] = []
else:
final_h5_open.create_dataset(
key, shape=key_shape, dtype=job0_h5_open[key].dtype
)

# set values
si = 0
for pi in range(num_procs):
# open job
job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
job_h5_open = h5py.File(job_h5_file, "r")

# append to final
for key in job_h5_open.keys():
job_seqs = job_h5_open[key].shape[0]
if job_h5_open[key].dtype.char == "S":
final_strings[key] += list(job_h5_open[key])
else:
final_h5_open[key][si : si + job_seqs] = job_h5_open[key]

job_h5_open.close()
si += job_seqs

# create final string datasets
for key in final_strings:
final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

final_h5_open.close()
20 changes: 20 additions & 0 deletions src/baskerville/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pickle


def load_extra_options(options_pkl_file, options):
"""
Args:
options_pkl_file: option file
options: existing options from command line
Returns:
options: updated options
"""
options_pkl = open(options_pkl_file, "rb")
new_options = pickle.load(options_pkl)
new_option_attrs = vars(new_options)
# Assuming 'options' is the existing options object
# Update the existing options with the new attributes
for attr_name, attr_value in new_option_attrs.items():
setattr(options, attr_name, attr_value)
options_pkl.close()
return options
20 changes: 1 addition & 19 deletions src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
upload_folder_gcs,
download_rename_inputs,
)
from baskerville.helpers.utils import load_extra_options

"""
hound_snp.py
Expand Down Expand Up @@ -211,25 +212,6 @@ def main():
shutil.rmtree(temp_dir) # clean up temp dir


def load_extra_options(options_pkl_file, options):
"""
Args:
options_pkl_file: option file
options: existing options from command line
Returns:
options: updated options
"""
options_pkl = open(options_pkl_file, "rb")
new_options = pickle.load(options_pkl)
new_option_attrs = vars(new_options)
# Assuming 'options' is the existing options object
# Update the existing options with the new attributes
for attr_name, attr_value in new_option_attrs.items():
setattr(options, attr_name, attr_value)
options_pkl.close()
return options


################################################################################
# __main__
################################################################################
Expand Down
Loading

0 comments on commit 1e30e64

Please sign in to comment.