add grad to seqnn, clean up utils
lruizcalico committed Nov 1, 2023
1 parent e84e220 commit c6df4ff
Showing 6 changed files with 646 additions and 31 deletions.
60 changes: 52 additions & 8 deletions src/baskerville/helpers/gcs_utils.py
@@ -64,6 +64,43 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
        storage_client.download_blob_to_file(gcs_path, o)


def download_folder_from_gcs(gcs_dir: str, local_dir: str, bytes=True) -> None:
    """
    Download an entire folder from GCS.
    Args:
        gcs_dir: string path to the GCS folder to download
        local_dir: string path of the local directory to download into
        bytes: boolean flag indicating if the GCS files contain bytes
    Returns: None
    """
    storage_client = _get_storage_client()
    write_mode = "wb" if bytes else "w"
    if not is_gcs_path(gcs_dir):
        raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")
    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
    # Get the bucket from the client.
    bucket = storage_client.bucket(bucket_name)

    # Ensure the local folder exists.
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)
    # List all blobs with the given prefix (i.e., folder path).
    blobs = bucket.list_blobs(prefix=gcs_object_prefix)
    # Download each blob.
    for blob in blobs:
        # Compute the full local path for this blob.
        blob_rel_path = os.path.relpath(blob.name, gcs_object_prefix)
        local_blob_path = os.path.join(local_dir, blob_rel_path)

        # Ensure the local directory structure exists.
        local_blob_dir = os.path.dirname(local_blob_path)
        if not os.path.exists(local_blob_dir):
            os.makedirs(local_blob_dir)
        download_from_gcs(join(gcs_dir, blob_rel_path), local_blob_path, bytes=bytes)
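
A minimal usage sketch of the new helper; the bucket and paths are illustrative, not from this commit:

# Hypothetical paths: pulls gs://my-bucket/models/run1/** into /tmp/run1.
download_folder_from_gcs("gs://my-bucket/models/run1", "/tmp/run1")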


def sync_dir_to_gcs(
    local_dir: str, gcs_dir: str, verbose=False, recursive=False
) -> None:
@@ -120,7 +157,7 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
    """
    storage_client = _get_storage_client()
    bucket_name = gcs_dir.split("//")[1].split("/")[0]
-    gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
+    gcs_object_prefix = "/".join(gcs_dir.split("//")[1].split("/")[1:])
    local_prefix = local_dir.split("/")[-1]
    bucket = storage_client.bucket(bucket_name)
    for filename in os.listdir(local_dir):
@@ -142,7 +179,7 @@ def upload_file_gcs(local_path: str, gcs_path: str, bytes=True) -> None:
    storage_client = _get_storage_client()
    bucket_name = gcs_path.split("//")[1].split("/")[0]
    bucket = storage_client.bucket(bucket_name)
-    gcs_object_prefix = gcs_path.split("//")[1].split("/")[1]
+    gcs_object_prefix = "/".join(gcs_path.split("//")[1].split("/")[1:])
    filename = local_path.split("/")[-1]
    blob = bucket.blob(f"{gcs_object_prefix}/{filename}")
    blob.upload_from_filename(local_path)
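
The two one-line changes above fix a prefix bug: split("/")[1] kept only the first path component after the bucket, truncating nested destinations. A quick sketch of the difference, using an illustrative URI:

uri = "gs://my-bucket/a/b/c"
parts = uri.split("//")[1].split("/")  # ["my-bucket", "a", "b", "c"]
old_prefix = parts[1]                  # "a" -- silently drops "b/c"
new_prefix = "/".join(parts[1:])       # "a/b/c" -- keeps the full prefix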
@@ -207,18 +244,25 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
    return files


-def download_rename_inputs(filepath: str, temp_dir: str) -> str:
+def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -> str:
    """
    Download a file or folder from GCS to a local directory.
    Args:
        filepath: GCS URI, following the format gs://$BUCKET_NAME/OBJECT_NAME
        temp_dir: local directory to download to
+        is_dir: boolean flag indicating if the filepath is a directory
    Returns: new filepath on the local machine
    """
-    _, filename = split_gcs_uri(filepath)
-    if "/" in filename:
-        filename = filename.split("/")[-1]
-    download_from_gcs(filepath, f"{temp_dir}/{filename}")
-    return f"{temp_dir}/{filename}"
+    if is_dir:
+        download_folder_from_gcs(filepath, temp_dir)
+        dir_name = filepath.split("/")[-1]
+        return temp_dir
+    else:
+        _, filename = split_gcs_uri(filepath)
+        if "/" in filename:
+            filename = filename.split("/")[-1]
+        download_from_gcs(filepath, f"{temp_dir}/{filename}")
+        return f"{temp_dir}/{filename}"
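
A sketch of both call paths, with hypothetical URIs:

# File: downloads the object and returns "/tmp/params.json".
params_local = download_rename_inputs("gs://my-bucket/params.json", "/tmp")
# Folder: downloads the folder's contents into temp_dir and returns it.
model_dir = download_rename_inputs("gs://my-bucket/models/run1", "/tmp", is_dir=True)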


def gcs_file_exist(gcs_path: str) -> bool:
114 changes: 114 additions & 0 deletions src/baskerville/helpers/h5_baskerville_utils.py
@@ -0,0 +1,114 @@
import h5py
import numpy as np
import argparse


def collect_h5(file_name, out_dir, num_procs) -> None:
    """
    Concatenate all output files together.
    Args:
        file_name: filename containing output (sad.h5)
        out_dir: directory containing output files
        num_procs: number of processes
    Returns: None
    """
    # count variants
    num_variants = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
        job_h5_open = h5py.File(job_h5_file, "r")
        num_variants += len(job_h5_open["snp"])
        job_h5_open.close()

    # initialize final h5
    final_h5_file = "%s/%s" % (out_dir, file_name)
    final_h5_open = h5py.File(final_h5_file, "w")

    # keep dict for string values
    final_strings = {}

    job0_h5_file = "%s/job0/%s" % (out_dir, file_name)
    job0_h5_open = h5py.File(job0_h5_file, "r")
    for key in job0_h5_open.keys():
        if key in ["percentiles", "target_ids", "target_labels"]:
            # copy
            final_h5_open.create_dataset(key, data=job0_h5_open[key])

        elif key[-4:] == "_pct":
            values = np.zeros(job0_h5_open[key].shape)
            final_h5_open.create_dataset(key, data=values)

        elif job0_h5_open[key].dtype.char == "S":
            final_strings[key] = []

        elif job0_h5_open[key].ndim == 1:
            final_h5_open.create_dataset(
                key, shape=(num_variants,), dtype=job0_h5_open[key].dtype
            )

        else:
            num_targets = job0_h5_open[key].shape[1]
            final_h5_open.create_dataset(
                key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype
            )

    job0_h5_open.close()

    # set values
    vi = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
        job_h5_open = h5py.File(job_h5_file, "r")

        # append to final
        for key in job_h5_open.keys():
            if key in ["percentiles", "target_ids", "target_labels"]:
                # once is enough
                pass

            elif key[-4:] == "_pct":
                # running average over jobs: u_k = u_{k-1} + (x_k - u_{k-1}) / k
                u_k1 = np.array(final_h5_open[key])
                x_k = np.array(job_h5_open[key])
                final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi + 1)

            else:
                if job_h5_open[key].dtype.char == "S":
                    final_strings[key] += list(job_h5_open[key])
                else:
                    job_variants = job_h5_open[key].shape[0]
                    try:
                        final_h5_open[key][vi : vi + job_variants] = job_h5_open[key]
                    except TypeError as e:
                        raise Exception(
                            f"{job_h5_file} {key} has the wrong shape. Remove this file and rerun. ({e})"
                        )

        vi += job_variants
        job_h5_open.close()

    # create final string datasets
    for key in final_strings:
        final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

    final_h5_open.close()
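
The _pct branch above maintains a running mean across jobs: starting from u_0 = 0, the update u_k = u_{k-1} + (x_k - u_{k-1}) / k leaves u_n equal to the plain average of x_1..x_n, with each job weighted equally regardless of its variant count. A tiny self-check with made-up values:

import numpy as np

xs = [np.array([1.0, 3.0]), np.array([2.0, 5.0]), np.array([6.0, 1.0])]
u = np.zeros(2)
for k, x in enumerate(xs, start=1):
    u += (x - u) / k  # same update as the _pct branch
assert np.allclose(u, np.mean(xs, axis=0))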


def main():
    parser = argparse.ArgumentParser(description="Process and collect h5 files.")

    parser.add_argument("file_name", type=str, help="Path to the input file.")
    parser.add_argument(
        "out_dir", type=str, help="Output directory for processed data."
    )
    parser.add_argument("num_procs", type=int, help="Number of processes to use.")

    args = parser.parse_args()

    collect_h5(args.file_name, args.out_dir, args.num_procs)


if __name__ == "__main__":
    main()
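
Invoked as a script, e.g. (paths illustrative):

python h5_baskerville_utils.py sad.h5 /path/to/out_dir 8

which merges /path/to/out_dir/job0..job7/sad.h5 into /path/to/out_dir/sad.h5.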
59 changes: 59 additions & 0 deletions src/baskerville/helpers/h5_utils.py
@@ -87,3 +87,62 @@ def collect_h5(file_name, out_dir, num_procs):
        final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

    final_h5_open.close()


def collect_h5_borzoi(out_dir, num_procs, sad_stat) -> None:
    """
    Concatenate per-process Borzoi score files into a single h5 file.
    Args:
        out_dir: directory containing jobN subdirectories of output files
        num_procs: number of processes
        sad_stat: h5 key of the score statistic used to count sequences
    Returns: None
    """
    h5_file = "scores_f0c0.h5"

    # count sequences
    num_seqs = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
        job_h5_open = h5py.File(job_h5_file, "r")
        num_seqs += job_h5_open[sad_stat].shape[0]
        seq_len = job_h5_open[sad_stat].shape[1]
        num_targets = job_h5_open[sad_stat].shape[-1]
        job_h5_open.close()

    # initialize final h5
    final_h5_file = "%s/%s" % (out_dir, h5_file)
    final_h5_open = h5py.File(final_h5_file, "w")

    # keep dict for string values
    final_strings = {}

    job0_h5_file = "%s/job0/%s" % (out_dir, h5_file)
    job0_h5_open = h5py.File(job0_h5_file, "r")
    for key in job0_h5_open.keys():
        key_shape = list(job0_h5_open[key].shape)
        key_shape[0] = num_seqs
        key_shape = tuple(key_shape)
        if job0_h5_open[key].dtype.char == "S":
            final_strings[key] = []
        else:
            final_h5_open.create_dataset(
                key, shape=key_shape, dtype=job0_h5_open[key].dtype
            )
    job0_h5_open.close()

    # set values
    si = 0
    for pi in range(num_procs):
        # open job
        job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
        job_h5_open = h5py.File(job_h5_file, "r")

        # append to final
        for key in job_h5_open.keys():
            job_seqs = job_h5_open[key].shape[0]
            if job_h5_open[key].dtype.char == "S":
                final_strings[key] += list(job_h5_open[key])
            else:
                final_h5_open[key][si : si + job_seqs] = job_h5_open[key]

        job_h5_open.close()
        si += job_seqs

    # create final string datasets
    for key in final_strings:
        final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))

    final_h5_open.close()
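
collect_h5_borzoi expects per-process files at out_dir/job{i}/scores_f0c0.h5 and writes the merged result to out_dir/scores_f0c0.h5. A hedged call sketch; the directory and statistic name are illustrative:

# Merge 8 workers' outputs, counting sequences via the "logSED" dataset.
collect_h5_borzoi("/path/to/out_dir", 8, "logSED")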
20 changes: 20 additions & 0 deletions src/baskerville/helpers/utils.py
@@ -0,0 +1,20 @@
import pickle


def load_extra_options(options_pkl_file, options):
    """
    Args:
        options_pkl_file: option file
        options: existing options from command line
    Returns:
        options: updated options
    """
    options_pkl = open(options_pkl_file, "rb")
    new_options = pickle.load(options_pkl)
    new_option_attrs = vars(new_options)
    # Update the existing options with the attributes from the pickle.
    for attr_name, attr_value in new_option_attrs.items():
        setattr(options, attr_name, attr_value)
    options_pkl.close()
    return options
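
Note that load_extra_options overwrites every matching attribute on options with the pickled values, so on conflicts the pickle wins over the command line. A usage sketch, with an illustrative path:

# 'options' was parsed earlier from the command line; the path is hypothetical.
options = load_extra_options("/tmp/options.pkl", options)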
20 changes: 1 addition & 19 deletions src/baskerville/scripts/hound_snp.py
@@ -25,6 +25,7 @@
    upload_folder_gcs,
    download_rename_inputs,
)
+from baskerville.helpers.utils import load_extra_options

"""
hound_snp.py
@@ -211,25 +212,6 @@ def main():
        shutil.rmtree(temp_dir)  # clean up temp dir


-def load_extra_options(options_pkl_file, options):
-    """
-    Args:
-        options_pkl_file: option file
-        options: existing options from command line
-    Returns:
-        options: updated options
-    """
-    options_pkl = open(options_pkl_file, "rb")
-    new_options = pickle.load(options_pkl)
-    new_option_attrs = vars(new_options)
-    # Assuming 'options' is the existing options object
-    # Update the existing options with the new attributes
-    for attr_name, attr_value in new_option_attrs.items():
-        setattr(options, attr_name, attr_value)
-    options_pkl.close()
-    return options


################################################################################
# __main__
################################################################################
