diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py
new file mode 100644
index 0000000..ef276d4
--- /dev/null
+++ b/src/baskerville/helpers/gcs_utils.py
@@ -0,0 +1,189 @@
+import os
+import logging
+
+from base64 import b64decode
+from json import loads
+from os.path import exists, join, isfile
+from re import match
+from typing import List
+
+from google.cloud.storage import Client
+from google.auth.exceptions import DefaultCredentialsError
+
+logger = logging.getLogger(__name__)
+PREFIX = ""
+
+try:
+    # Attempt to infer credentials from the environment
+    storage_client = Client()
+    logger.info("Inferred credentials from environment")
+except DefaultCredentialsError:
+    try:
+        # Attempt to load credentials from the JSON key file pointed to by
+        # GOOGLE_APPLICATION_CREDENTIALS
+        storage_client = Client.from_service_account_json(
+            os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
+        )
+        logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS")
+    except (KeyError, OSError):
+        # Fall back to GOOGLE_APPLICATION_CREDENTIALS holding a base64-encoded
+        # service account JSON string rather than a file path
+        storage_client = Client.from_service_account_info(
+            loads(
+                b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode("utf-8")
+            )
+        )
+        logger.info("Loaded credentials from base64-encoded string")
+
+
+def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
+    """
+    Downloads a file from GCS.
+    Args:
+        gcs_path: string path to GCS file to download
+        local_path: string path to download to
+        bytes: boolean flag indicating if the GCS file contains bytes
+
+    Returns: None
+    """
+    write_mode = "wb" if bytes else "w"
+    with open(local_path, write_mode) as o:
+        storage_client.download_blob_to_file(gcs_path, o)
+
+
+def sync_dir_to_gcs(
+    local_dir: str, gcs_dir: str, verbose=False, recursive=False
+) -> None:
+    """
+    Copies all files in a local directory to the GCS directory.
+    Args:
+        local_dir: string local directory path to upload from
+        gcs_dir: string GCS destination path; will create folders that do not exist
+        verbose: boolean flag to print logging statements
+        recursive: boolean flag to recursively upload files in subdirectories
+
+    Returns: None
+    """
+    if not is_gcs_path(gcs_dir):
+        raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")
+
+    if not exists(local_dir):
+        raise FileNotFoundError(f"local_dir does not exist: {local_dir}")
+
+    local_files = os.listdir(local_dir)
+    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
+    bucket = storage_client.bucket(bucket_name)
+
+    for filename in local_files:
+        gcs_object_name = join(gcs_object_prefix, filename)
+        local_file = join(local_dir, filename)
+        if recursive and not isfile(local_file):
+            # recurse into subdirectories
+            sync_dir_to_gcs(
+                local_file,
+                f"gs://{join(bucket_name, gcs_object_name)}",
+                verbose=verbose,
+                recursive=recursive,
+            )
+        elif not isfile(local_file):
+            pass
+        else:
+            blob = bucket.blob(gcs_object_name)
+            if verbose:
+                print(
+                    f"Uploading {local_file} to gs://{join(bucket_name, gcs_object_name)}"
+                )
+            blob.upload_from_filename(local_file)
+
+
+def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
+    """
+    Copies all files in a local directory to the GCS directory.
+    Args:
+        local_dir: string local directory path to upload from
+        gcs_dir: string GCS destination path; will create folders that do not exist
+
+    Returns: None
+    """
+    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
+    local_prefix = local_dir.split("/")[-1]
+    bucket = storage_client.bucket(bucket_name)
+    for filename in os.listdir(local_dir):
+        gcs_object_name = f"{gcs_object_prefix}/{local_prefix}/{filename}"
+        local_file = join(local_dir, filename)
+        blob = bucket.blob(gcs_object_name)
+        blob.upload_from_filename(local_file)
+
+
+def gcs_join(*args):
+    """Joins GCS path components, normalizing any leading gs:// prefixes."""
+    args = [arg.replace("gs://", "").strip("/") for arg in args]
+    return "gs://" + join(*args)
+
+
+def split_gcs_uri(gcs_uri: str) -> tuple:
+    """
+    Splits a GCS bucket and object_name from a GCS URI.
+    Args:
+        gcs_uri: string GCS URI in the format gs://$BUCKET_NAME/OBJECT_NAME
+
+    Returns: bucket_name, object_name
+    """
+    matches = match("gs://(.*?)/(.*)", gcs_uri)
+    if matches:
+        return matches.groups()
+    else:
+        raise ValueError(
+            f"{gcs_uri} does not match expected format: gs://BUCKET_NAME/OBJECT_NAME"
+        )
+
+
+def is_gcs_path(gcs_path: str) -> bool:
+    """
+    Returns True if the string passed starts with gs://.
+    Args:
+        gcs_path: string path to check
+
+    Returns: boolean flag indicating the gcs_path starts with gs://
+    """
+    return gcs_path.startswith("gs://")
+
+
+def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
+    """
+    Returns a list of filenames inside a directory, which may be local or on GCS.
+    """
+    # currently only Niftidataset receives gs bucket paths, so this isn't necessary
+    # commenting out for now even though it is functional (lots of files)
+    if is_gcs_path(files_dir):
+        bucket_name, object_name = split_gcs_uri(files_dir)
+        blob_iterator = storage_client.list_blobs(bucket_name, prefix=object_name)
+        # return [str(blob) for blob in blob_iterator if "/" not in blob.name]
+        return [str(blob) for blob in blob_iterator]
+    else:
+        dir_contents = os.listdir(files_dir)
+        files = []
+        for entry in dir_contents:
+            entry_path = join(files_dir, entry)
+            if isfile(entry_path):
+                files.append(entry_path)
+            elif recursive:
+                files.extend(get_filename_in_dir(entry_path, recursive=recursive))
+            else:
+                pass
+        return files
+
+
+def download_rename_inputs(filepath: str, temp_dir: str) -> str:
+    """
+    Downloads a file from GCS to a local directory.
+    Args:
+        filepath: GCS URI in the format gs://$BUCKET_NAME/OBJECT_NAME
+        temp_dir: local directory to download the file into
+
+    Returns: new filepath on the local machine
+    """
+    _, filename = split_gcs_uri(filepath)
+    if "/" in filename:
+        filename = filename.split("/")[-1]
+    download_from_gcs(filepath, f"{temp_dir}/{filename}")
+    return f"{temp_dir}/{filename}"
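For reference, a minimal usage sketch of these helpers (the bucket and object names below are hypothetical):

    from baskerville.helpers.gcs_utils import (
        download_from_gcs, gcs_join, split_gcs_uri, sync_dir_to_gcs
    )

    run_uri = gcs_join("gs://my-bucket", "experiments", "run1")  # gs://my-bucket/experiments/run1
    bucket, prefix = split_gcs_uri(run_uri)                      # ("my-bucket", "experiments/run1")
    download_from_gcs(gcs_join(run_uri, "targets.txt"), "/tmp/targets.txt")
    sync_dir_to_gcs("/tmp/output_dir", run_uri, verbose=True, recursive=True)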
diff --git a/src/baskerville/helpers/h5_utils.py b/src/baskerville/helpers/h5_utils.py
new file mode 100644
index 0000000..11b91b5
--- /dev/null
+++ b/src/baskerville/helpers/h5_utils.py
@@ -0,0 +1,89 @@
+import h5py
+import numpy as np
+
+
+def collect_h5(file_name, out_dir, num_procs):
+    """Merge per-job HDF5 outputs in out_dir/job*/file_name into out_dir/file_name."""
+    # count variants
+    num_variants = 0
+    for pi in range(num_procs):
+        # open job
+        job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name)
+        job_h5_open = h5py.File(job_h5_file, "r")
+        num_variants += len(job_h5_open["snp"])
+        job_h5_open.close()
+
+    # initialize final h5
+    final_h5_file = "%s/%s" % (out_dir, file_name)
+    final_h5_open = h5py.File(final_h5_file, "w")
+
+    # keep dict for string values
+    final_strings = {}
+
+    job0_h5_file = "%s/job0/%s" % (out_dir, file_name)
+    job0_h5_open = h5py.File(job0_h5_file, "r")
+    for key in job0_h5_open.keys():
+        if key in ["percentiles", "target_ids", "target_labels"]:
+            # copy
+            final_h5_open.create_dataset(key, data=job0_h5_open[key])
+
"_pct": + values = np.zeros(job0_h5_open[key].shape) + final_h5_open.create_dataset(key, data=values) + + elif job0_h5_open[key].dtype.char == "S": + final_strings[key] = [] + + elif job0_h5_open[key].ndim == 1: + final_h5_open.create_dataset( + key, shape=(num_variants,), dtype=job0_h5_open[key].dtype + ) + + else: + num_targets = job0_h5_open[key].shape[1] + final_h5_open.create_dataset( + key, shape=(num_variants, num_targets), dtype=job0_h5_open[key].dtype + ) + + job0_h5_open.close() + + # set values + vi = 0 + for pi in range(num_procs): + # open job + job_h5_file = "%s/job%d/%s" % (out_dir, pi, file_name) + job_h5_open = h5py.File(job_h5_file, "r") + + # append to final + for key in job_h5_open.keys(): + if key in ["percentiles", "target_ids", "target_labels"]: + # once is enough + pass + + elif key[-4:] == "_pct": + # average + u_k1 = np.array(final_h5_open[key]) + x_k = np.array(job_h5_open[key]) + final_h5_open[key][:] = u_k1 + (x_k - u_k1) / (pi + 1) + + else: + if job_h5_open[key].dtype.char == "S": + final_strings[key] += list(job_h5_open[key]) + else: + job_variants = job_h5_open[key].shape[0] + try: + final_h5_open[key][vi : vi + job_variants] = job_h5_open[key] + except TypeError as e: + print(e) + print( + f"{job_h5_file} ${key} has the wrong shape. Remove this file and rerun" + ) + exit() + + vi += job_variants + job_h5_open.close() + + # create final string datasets + for key in final_strings: + final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S")) + + final_h5_open.close() diff --git a/src/baskerville/scripts/hound_snp_gcs.py b/src/baskerville/scripts/hound_snp_gcs.py new file mode 100755 index 0000000..29024ea --- /dev/null +++ b/src/baskerville/scripts/hound_snp_gcs.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# Copyright 2023 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import pickle +import os +import tempfile +import shutil +import tensorflow as tf +from baskerville.helpers.gcs_utils import ( + upload_folder_gcs, + is_gcs_path, + download_rename_inputs, +) +from baskerville.snps import calculate_sad + +""" +hound_snp_gcs.py + +Compute variant effect predictions for SNPs in a VCF file. 
diff --git a/src/baskerville/scripts/hound_snp_gcs.py b/src/baskerville/scripts/hound_snp_gcs.py
new file mode 100755
index 0000000..29024ea
--- /dev/null
+++ b/src/baskerville/scripts/hound_snp_gcs.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+# Copyright 2023 Calico LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+from optparse import OptionParser
+import pickle
+import os
+import tempfile
+import shutil
+import tensorflow as tf
+from baskerville.helpers.gcs_utils import (
+    upload_folder_gcs,
+    is_gcs_path,
+    download_rename_inputs,
+)
+from baskerville.snps import calculate_sad
+
+"""
+hound_snp_gcs.py
+
+Compute variant effect predictions for SNPs in a VCF file.
+Input files are read from GCS and output is written back to a GCS location.
+"""
+
+
+################################################################################
+# main
+################################################################################
+def main():
+    usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>"
+    parser = OptionParser(usage)
+    parser.add_option(
+        "-f",
+        dest="genome_fasta",
+        default=None,
+        help="Genome FASTA for sequences [Default: %default]",
+    )
+    parser.add_option(
+        "-o",
+        dest="out_dir",
+        default="snp_out",
+        help="Output directory for tables and plots [Default: %default]",
+    )
+    parser.add_option(
+        "-p",
+        dest="processes",
+        default=None,
+        type="int",
+        help="Number of processes, passed by multi script",
+    )
+    parser.add_option(
+        "--rc",
+        dest="rc",
+        default=False,
+        action="store_true",
+        help="Average forward and reverse complement predictions [Default: %default]",
+    )
+    parser.add_option(
+        "--shifts",
+        dest="shifts",
+        default="0",
+        type="str",
+        help="Ensemble prediction shifts [Default: %default]",
+    )
+    parser.add_option(
+        "--stats",
+        dest="sad_stats",
+        default="SAD",
+        help="Comma-separated list of stats to save. [Default: %default]",
+    )
+    parser.add_option(
+        "-t",
+        dest="targets_file",
+        default=None,
+        type="str",
+        help="File specifying target indexes and labels in table format",
+    )
+    parser.add_option(
+        "-u",
+        dest="untransform_old",
+        default=False,
+        action="store_true",
+        help="Untransform old models [Default: %default]",
+    )
+    (options, args) = parser.parse_args()
+
+    genome_fasta = options.genome_fasta  # gcs fasta file
+    targets_file = options.targets_file  # gcs targets file
+    processes = options.processes
+    # assume that the output directory is a GCS path
+    gcs_output_dir = options.out_dir
+    temp_dir = tempfile.mkdtemp()  # create a temp dir for output
+    out_dir = temp_dir + "/output_dir"
+    if not os.path.isdir(out_dir):
+        os.mkdir(out_dir)
+
+    if len(args) == 3:
+        # single worker
+        params_file = args[0]
+        model_file = args[1]
+        vcf_file = args[2]
+        options.out_dir = out_dir
+    elif len(args) == 4:
+        # multi separate
+        options_pkl_file = args[0]
+        params_file = args[1]
+        model_file = args[2]
+        vcf_file = args[3]
+
+        # load options
+        if is_gcs_path(options_pkl_file):
+            options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir)
+        options_pkl = open(options_pkl_file, "rb")
+        options = pickle.load(options_pkl)
+        options_pkl.close()
+
+        # update output directory
+        options.out_dir = out_dir
+
+    elif len(args) == 5:
+        # multi worker
+        options_pkl_file = args[0]
+        params_file = args[1]
+        model_file = args[2]
+        vcf_file = args[3]
+        worker_index = int(args[4])
+
+        # load options
+        if is_gcs_path(options_pkl_file):
+            options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir)
+        options_pkl = open(options_pkl_file, "rb")
+        options = pickle.load(options_pkl)
+        options_pkl.close()
+
+        # update output directory
+        options.out_dir = out_dir
+        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)
+
+    else:
+        parser.error("Must provide parameters file, model file, and VCF file")
+
+    if not os.path.isdir(options.out_dir):
+        os.mkdir(options.out_dir)
+    options.shifts = [int(shift) for shift in options.shifts.split(",")]
+    options.sad_stats = options.sad_stats.split(",")
+    options.genome_fasta = genome_fasta
+    options.targets_file = targets_file
+    options.processes = processes
+
+    #################################################################
+    # check that the program is running on a GPU, else quit
+    physical_devices = tf.config.list_physical_devices()
+    # Check if a GPU is available
+    gpu_available = any(device.device_type == "GPU" for device in physical_devices)
+
+    if gpu_available:
+        print("Running on GPU")
+    else:
+        print("Running on CPU")
+        raise SystemExit("Job terminated because it is running on CPU")
+
+    #################################################################
+    # download input files from GCS to local files
+    if is_gcs_path(params_file):
+        params_file = download_rename_inputs(params_file, temp_dir)
+        vcf_file = download_rename_inputs(vcf_file, temp_dir)
+        model_file = download_rename_inputs(model_file, temp_dir)
+        if options.genome_fasta is not None:
+            options.genome_fasta = download_rename_inputs(
+                options.genome_fasta, temp_dir
+            )
+        if options.targets_file is not None:
+            options.targets_file = download_rename_inputs(
+                options.targets_file, temp_dir
+            )
+
+    # calculate SAD scores (calculate_sad is shared with hound_snp.py)
+    if options.processes is not None:
+        calculate_sad(params_file, model_file, vcf_file, worker_index, options)
+    else:
+        calculate_sad(params_file, model_file, vcf_file, 0, options)
+
+    # if the output dir is on GCS, sync it up
+    if is_gcs_path(gcs_output_dir):
+        upload_folder_gcs(options.out_dir, gcs_output_dir)
+    if os.path.isdir(temp_dir):
+        shutil.rmtree(temp_dir)  # clean up temp dir
+
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == "__main__":
+    main()
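An illustrative single-worker invocation, assuming all inputs and the output directory live in GCS (bucket and file names below are hypothetical); note the script exits unless TensorFlow can see a GPU:

    python hound_snp_gcs.py \
        -f gs://my-bucket/hg38.fa \
        -t gs://my-bucket/targets.txt \
        -o gs://my-bucket/snp_out \
        --rc --shifts "1,0,-1" \
        gs://my-bucket/params.json gs://my-bucket/model.h5 gs://my-bucket/snps.vcf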