diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index d71bb3e..82e9066 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -93,8 +93,8 @@ def conv_block( # normalize if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum, gamma_initializer=norm_gamma + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization( @@ -221,8 +221,8 @@ def conv_dna( else: # normalize if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current) @@ -303,8 +303,8 @@ def conv_nac( # normalize if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current) @@ -479,11 +479,11 @@ def fpn_unet( # normalize if norm_type == "batch-sync": - current1 = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current1 = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current1) - current2 = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current2 = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current2) elif norm_type == "batch": current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1) @@ -570,8 +570,8 @@ def fpn1_unet( # normalize if norm_type == "batch-sync": - current1 = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current1 = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current1) elif norm_type == "batch": current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1) @@ -648,11 +648,11 @@ def upsample_unet( # normalize if norm_type == "batch-sync": - current1 = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current1 = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current1) - current2 = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current2 = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current2) elif norm_type == "batch": current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1) @@ -745,8 +745,8 @@ def tconv_nac( # normalize if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current) @@ -824,8 +824,8 @@ def conv_block_2d( # normalize if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum, gamma_initializer=norm_gamma + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization( @@ -1880,8 +1880,8 @@ def dense_block( if norm_gamma is None: norm_gamma = "zeros" if residual else "ones" if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum, gamma_initializer=norm_gamma + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization( @@ -1950,8 +1950,8 @@ def dense_nac( if norm_gamma is None: norm_gamma = "zeros" if residual else "ones" if norm_type == "batch-sync": - current = tf.keras.layers.experimental.SyncBatchNormalization( - momentum=bn_momentum, gamma_initializer=norm_gamma + current = tf.keras.layers.BatchNormalization( + momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True )(current) elif norm_type == "batch": current = tf.keras.layers.BatchNormalization( diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py new file mode 100644 index 0000000..72a853e --- /dev/null +++ b/src/baskerville/helpers/gcs_utils.py @@ -0,0 +1,234 @@ +# taken and modified from calico-ukbb-mri-ml repo +# https://github.com/calico/calicolabs-ukbb-mri-ml/tree/main/src/ukbb_mri_ml/helpers +# ========================================================================= + +import os +import logging +import pdb +from base64 import b64decode +from json import loads +from os.path import exists, join, isfile +from re import match +from typing import List + +from google.cloud.storage import Client +from google.auth.exceptions import DefaultCredentialsError + +logger = logging.getLogger(__name__) + +logger = logging.getLogger(__name__) + + +def _get_storage_client() -> Client: + """ + Returns: Google Cloud Storage Client + """ + try: + # Attempt to infer credentials from environment + storage_client = Client() + logger.info("Inferred credentials from environment") + except DefaultCredentialsError: + try: + # Attempt to load JSON credentials from GOOGLE_APPLICATION_CREDENTIALS + storage_client = Client.from_service_account_info( + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] + ) + logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS") + except AttributeError: + # Attempt to load JSON credentials from base64 encoded string + storage_client = Client.from_service_account_info( + loads( + b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode( + "utf-8" + ) + ) + ) + logger.info("Loaded credentials from base64 encoded string") + return storage_client + + +def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None: + """ + Downloads a file from GCS + Args: + gcs_path: string path to GCS file to download + local_path: string path to download to + bytes: boolean flag indicating if gcs file contains bytes + + Returns: None + + """ + storage_client = _get_storage_client() + write_mode = "wb" if bytes else "w" + with open(local_path, write_mode) as o: + storage_client.download_blob_to_file(gcs_path, o) + + +def sync_dir_to_gcs( + local_dir: str, gcs_dir: str, verbose=False, recursive=False +) -> None: + """ + Copies all files in a local directory to the gcs directory + Args: + local_dir: string local directory path to upload from + gcs_dir: string GCS destination path. Will create folders that do not exist. + verbose: boolean flag to print logging statements + recursive: boolean flag to recursively upload files in subdirectories + + Returns: None + + """ + storage_client = _get_storage_client() + if not is_gcs_path(gcs_dir): + raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}") + + if not exists(local_dir): + raise FileNotFoundError(f"local_dir does not exist: {local_dir}") + + local_files = os.listdir(local_dir) + bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir) + bucket = storage_client.bucket(bucket_name) + + for filename in local_files: + gcs_object_name = join(gcs_object_prefix, filename) + local_file = join(local_dir, filename) + if recursive and not isfile(local_file): + sync_dir_to_gcs( + local_file, + f"gs://{join(bucket_name, gcs_object_name)}", + verbose=verbose, + recursive=recursive, + ) + elif not isfile(local_file): + pass + else: + blob = bucket.blob(gcs_object_name) + if verbose: + print( + f"Uploading {local_file} to gs://{join(bucket_name, gcs_object_name)}" + ) + blob.upload_from_filename(local_file) + + +def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None: + """ + Copies all files in a local directory to the gcs directory + Args: + local_dir: string local directory path to upload from + gcs_dir: string GCS destination path. Will create folders that do not exist. + Returns: None + """ + storage_client = _get_storage_client() + bucket_name = gcs_dir.split("//")[1].split("/")[0] + gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1] + local_prefix = local_dir.split("/")[-1] + bucket = storage_client.bucket(bucket_name) + for filename in os.listdir(local_dir): + gcs_object_name = f"{gcs_object_prefix}/{local_prefix}/{filename}" + local_file = join(local_dir, filename) + blob = bucket.blob(gcs_object_name) + blob.upload_from_filename(local_file) + + +def upload_file_gcs(local_path: str, gcs_path: str, bytes=True) -> None: + """ + Upload a file to gcs + Args: + local_path: local path to file + gcs_path: string GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME + + Returns: None + """ + storage_client = _get_storage_client() + bucket_name = gcs_path.split("//")[1].split("/")[0] + bucket = storage_client.bucket(bucket_name) + gcs_object_prefix = gcs_path.split("//")[1].split("/")[1] + filename = local_path.split("/")[-1] + blob = bucket.blob(f"{gcs_object_prefix}/{filename}") + blob.upload_from_filename(local_path) + + +def gcs_join(*args): + args = [arg.replace("gs://", "").strip("/") for arg in args] + return "gs://" + join(*args) + + +def split_gcs_uri(gcs_uri: str) -> tuple: + """ + Splits a GCS bucket and object_name from a GCS URI + Args: + gcs_uri: string GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME + + Returns: bucket_name, object_name + """ + matches = match("gs://(.*?)/(.*)", gcs_uri) + if matches: + return matches.groups() + else: + raise ValueError( + f"{gcs_uri} does not match expected format: gs://BUCKET_NAME/OBJECT_NAME" + ) + + +def is_gcs_path(gcs_path: str) -> bool: + """ + Returns True if the string passed starts with gs:// + Args: + gcs_path: string path to check + + Returns: Boolean flag indicating the gcs_path starts with gs:// + + """ + return gcs_path.startswith("gs://") + + +def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]: + """ + Returns list of filenames inside a directory. + """ + # currently only Niftidataset receives gs bucket paths, so this isn't necessary + # commenting out for now even though it is functional (lots of files) + storage_client = _get_storage_client() + if is_gcs_path(files_dir): + bucket_name, object_name = split_gcs_uri(files_dir) + blob_iterator = storage_client.list_blobs(bucket_name, prefix=object_name) + return [str(blob) for blob in blob_iterator] + dir_contents = os.listdir(files_dir) + files = [] + for entry in dir_contents: + entry_path = join(files_dir, entry) + if isfile(entry_path): + files.append(entry_path) + elif recursive: + files.extend(get_filename_in_dir(entry_path, recursive=recursive)) + else: + print("Nothing happened here") + pass + return files + + +def download_rename_inputs(filepath: str, temp_dir: str) -> str: + """ + Download file from gcs to local dir + Args: + filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME + Returns: new filepath in the local machine + """ + _, filename = split_gcs_uri(filepath) + if "/" in filename: + filename = filename.split("/")[-1] + download_from_gcs(filepath, f"{temp_dir}/{filename}") + return f"{temp_dir}/{filename}" + + +def gcs_file_exist(gcs_path: str) -> bool: + """ + check if a file exist in gcs + params: gcs_path + returns: true/false + """ + storage_client = _get_storage_client() + bucket, filename = split_gcs_uri(gcs_path) + bucket = storage_client.bucket(bucket) + blob = bucket.blob(filename) + return blob.exists() diff --git a/src/baskerville/helpers/h5_utils.py b/src/baskerville/helpers/h5_utils.py new file mode 100644 index 0000000..11b91b5 --- /dev/null +++ b/src/baskerville/helpers/h5_utils.py @@ -0,0 +1,89 @@ +import h5py +import numpy as np + + +def collect_h5(file_name, out_dir, num_procs): + # count variants + num_variants = 0 + for pi in range(num_procs): + # open job + job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name) + job_h5_open = h5py.File(job_h5_file, "r") + num_variants += len(job_h5_open["snp"]) + job_h5_open.close() + + # initialize final h5 + final_h5_file = "%s/%s" % (out_dir, file_name) + final_h5_open = h5py.File(final_h5_file, "w") + + # keep dict for string values + final_strings = {} + + job0_h5_file = "%s/job0/%s" % (out_dir, file_name) + job0_h5_open = h5py.File(job0_h5_file, "r") + for key in job0_h5_open.keys(): + if key in ["percentiles", "target_ids", "target_labels"]: + # copy + final_h5_open.create_dataset(key, data=job0_h5_open[key]) + + elif key[-4:] == "_pct": + values = np.zeros(job0_h5_open[key].shape) + final_h5_open.create_dataset(key, data=values) + + elif job0_h5_open[key].dtype.char == "S": + final_strings[key] = [] + + elif job0_h5_open[key].ndim == 1: + final_h5_open.create_dataset( + key, shape=(num_variants,), dtype=job0_h5_open[key].dtype + ) + + else: + num_targets = job0_h5_open[key].shape[1] + final_h5_open.create_dataset( + key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype + ) + + job0_h5_open.close() + + # set values + vi = 0 + for pi in range(num_procs): + # open job + job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name) + job_h5_open = h5py.File(job_h5_file, "r") + + # append to final + for key in job_h5_open.keys(): + if key in ["percentiles", "target_ids", "target_labels"]: + # once is enough + pass + + elif key[-4:] == "_pct": + # average + u_k1 = np.array(final_h5_open[key]) + x_k = np.array(job_h5_open[key]) + final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi + 1) + + else: + if job_h5_open[key].dtype.char == "S": + final_strings[key] += list(job_h5_open[key]) + else: + job_variants = job_h5_open[key].shape[0] + try: + final_h5_open[key][vi : vi + job_variants] = job_h5_open[key] + except TypeError as e: + print(e) + print( + f"{job_h5_file} ${key} has the wrong shape. Remove this file and rerun" + ) + exit() + + vi += job_variants + job_h5_open.close() + + # create final string datasets + for key in final_strings: + final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S")) + + final_h5_open.close() diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index 0eb8a7d..f913060 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -18,6 +18,13 @@ import pickle import os from baskerville.snps import score_snps +import tempfile +import shutil +import tensorflow as tf +from baskerville.helpers.gcs_utils import ( + upload_folder_gcs, + download_rename_inputs, +) """ hound_snp.py @@ -92,8 +99,31 @@ def main(): action="store_true", help="Untransform old models [Default: %default]", ) + parser.add_option( + "--gcs", + dest="gcs", + default=False, + action="store_true", + help="Input and output are in gcs", + ) + parser.add_option( + "--require_gpu", + dest="require_gpu", + default=False, + action="store_true", + help="Only run on GPU", + ) (options, args) = parser.parse_args() + if options.gcs: + """Assume that output_dir will be gcs""" + gcs_output_dir = options.out_dir + temp_dir = tempfile.mkdtemp() # create a temp dir for output + out_dir = temp_dir + "/output_dir" + if not os.path.isdir(out_dir): + os.mkdir(out_dir) + options.out_dir = out_dir + if len(args) == 3: # single worker params_file = args[0] @@ -111,10 +141,9 @@ def main(): out_dir = options.out_dir # load options - options_pkl = open(options_pkl_file, "rb") - options = pickle.load(options_pkl) - options_pkl.close() - + if options.gcs: + options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir) + options = load_extra_options(options_pkl_file, options) # update output directory options.out_dir = out_dir @@ -127,30 +156,78 @@ def main(): worker_index = int(args[4]) # load options - options_pkl = open(options_pkl_file, "rb") - options = pickle.load(options_pkl) - options_pkl.close() - + if options.gcs: + options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir) + options = load_extra_options(options_pkl_file, options) # update output directory options.out_dir = "%s/job%d" % (options.out_dir, worker_index) else: parser.error("Must provide parameters and model files and QTL VCF file") - if options.targets_file is None: - parser.error("Must provide targets file") - if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] options.snp_stats = options.snp_stats.split(",") - + if options.targets_file is None: + parser.error("Must provide targets file") + ################################################################# + # check if the program is run on GPU, else quit + physical_devices = tf.config.list_physical_devices() + # Check if a GPU is available + gpu_available = any(device.device_type == "GPU" for device in physical_devices) + + if gpu_available: + print("Running on GPU") + else: + print("Running on CPU") + if options.require_gpu: + raise SystemExit("Job terminated because it's running on CPU") + ################################################################# + # download input files from gcs to a local file + if options.gcs: + params_file = download_rename_inputs(params_file, temp_dir) + vcf_file = download_rename_inputs(vcf_file, temp_dir) + model_file = download_rename_inputs(model_file, temp_dir) + if options.genome_fasta is not None: + options.genome_fasta = download_rename_inputs( + options.genome_fasta, temp_dir + ) + if options.targets_file is not None: + options.targets_file = download_rename_inputs( + options.targets_file, temp_dir + ) + ################################################################# # calculate SAD scores: if options.processes is not None: score_snps(params_file, model_file, vcf_file, worker_index, options) else: score_snps(params_file, model_file, vcf_file, 0, options) + # if the output dir is in gcs, sync it up + if options.gcs: + upload_folder_gcs(options.out_dir, gcs_output_dir) + if os.path.isdir(temp_dir): + shutil.rmtree(temp_dir) # clean up temp dir + + +def load_extra_options(options_pkl_file, options): + """ + Args: + options_pkl_file: option file + options: existing options from command line + Returns: + options: updated options + """ + options_pkl = open(options_pkl_file, "rb") + new_options = pickle.load(options_pkl) + new_option_attrs = vars(new_options) + # Assuming 'options' is the existing options object + # Update the existing options with the new attributes + for attr_name, attr_value in new_option_attrs.items(): + setattr(options, attr_name, attr_value) + options_pkl.close() + return options ################################################################################ diff --git a/src/baskerville/vcf.py b/src/baskerville/vcf.py index 802b0b5..520be53 100644 --- a/src/baskerville/vcf.py +++ b/src/baskerville/vcf.py @@ -599,8 +599,9 @@ def vcf_snps( if validate_ref_fasta is not None: ref_n = len(snps[-1].ref_allele) snp_pos = snps[-1].pos - 1 - ref_snp = genome_open.fetch(snps[-1].chr, snp_pos, snp_pos + ref_n) - + ref_snp = genome_open.fetch( + snps[-1].chr, snp_pos, snp_pos + ref_n + ).upper() if snps[-1].ref_allele != ref_snp: if not flip_ref: # bail @@ -615,7 +616,7 @@ def vcf_snps( alt_n = len(snps[-1].alt_alleles[0]) ref_snp = genome_open.fetch( snps[-1].chr, snp_pos, snp_pos + alt_n - ) + ).upper() # if alt matches fasta reference if snps[-1].alt_alleles[0] == ref_snp: