diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py
index 72a853e..af398d0 100644
--- a/src/baskerville/helpers/gcs_utils.py
+++ b/src/baskerville/helpers/gcs_utils.py
@@ -64,6 +64,43 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
         storage_client.download_blob_to_file(gcs_path, o)
 
 
+def download_folder_from_gcs(gcs_dir: str, local_dir: str, bytes=True) -> None:
+    """
+    Downloads a whole folder from GCS
+    Args:
+        gcs_dir: string path to GCS folder to download
+        local_dir: string path to download to
+        bytes: boolean flag indicating if gcs file contains bytes
+
+    Returns: None
+
+    """
+    storage_client = _get_storage_client()
+    write_mode = "wb" if bytes else "w"
+    if not is_gcs_path(gcs_dir):
+        raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")
+    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
+    # Get the bucket from the client.
+    bucket = storage_client.bucket(bucket_name)
+
+    # Ensure local folder exists
+    if not os.path.exists(local_dir):
+        os.makedirs(local_dir)
+    # List all blobs with the given prefix (i.e., folder path).
+    blobs = bucket.list_blobs(prefix=gcs_object_prefix)
+    # Download each blob.
+    for blob in blobs:
+        # Compute the full path to which we'll download the blob.
+        blob_rel_path = os.path.relpath(blob.name, gcs_object_prefix)
+        local_blob_path = os.path.join(local_dir, blob_rel_path)
+
+        # Ensure the local directory structure exists
+        local_blob_dir = os.path.dirname(local_blob_path)
+        if not os.path.exists(local_blob_dir):
+            os.makedirs(local_blob_dir)
+        download_from_gcs(join(gcs_dir, blob_rel_path), local_blob_path, bytes=bytes)
+
+
 def sync_dir_to_gcs(
     local_dir: str, gcs_dir: str, verbose=False, recursive=False
 ) -> None:
@@ -120,7 +157,7 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
     """
     storage_client = _get_storage_client()
     bucket_name = gcs_dir.split("//")[1].split("/")[0]
-    gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
+    gcs_object_prefix = "/".join(gcs_dir.split("//")[1].split("/")[1:])
     local_prefix = local_dir.split("/")[-1]
     bucket = storage_client.bucket(bucket_name)
     for filename in os.listdir(local_dir):
@@ -207,18 +244,25 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
     return files
 
 
-def download_rename_inputs(filepath: str, temp_dir: str) -> str:
+def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -> str:
     """
     Download file from gcs to local dir
     Args:
        filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME
+        temp_dir: local dir to download to
+        is_dir: boolean flag indicating if the filepath is a directory
     Returns: new filepath in the local machine
     """
-    _, filename = split_gcs_uri(filepath)
-    if "/" in filename:
-        filename = filename.split("/")[-1]
-    download_from_gcs(filepath, f"{temp_dir}/{filename}")
-    return f"{temp_dir}/{filename}"
+    if is_dir:
+        download_folder_from_gcs(filepath, temp_dir)
+        dir_name = filepath.split("/")[-1]
+        return temp_dir
+    else:
+        _, filename = split_gcs_uri(filepath)
+        if "/" in filename:
+            filename = filename.split("/")[-1]
+        download_from_gcs(filepath, f"{temp_dir}/{filename}")
+        return f"{temp_dir}/{filename}"
 
 
 def gcs_file_exist(gcs_path: str) -> bool:
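A hedged usage sketch for the folder-download path added above; the bucket, prefix, and local directory names below are hypothetical, not taken from the PR:

```python
from baskerville.helpers.gcs_utils import download_rename_inputs

# With is_dir=True, the whole gs:// prefix is mirrored into temp_dir via
# download_folder_from_gcs and temp_dir itself is returned; with the default
# is_dir=False a single object is downloaded and its local file path returned.
local_dir = download_rename_inputs(
    "gs://my-bucket/experiments/model_dir",  # hypothetical GCS folder
    "/tmp/baskerville_inputs",               # hypothetical local temp dir
    is_dir=True,
)
```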
diff --git a/src/baskerville/helpers/h5_utils.py b/src/baskerville/helpers/h5_utils.py
index 11b91b5..3b37d7a 100644
--- a/src/baskerville/helpers/h5_utils.py
+++ b/src/baskerville/helpers/h5_utils.py
@@ -87,3 +87,62 @@ def collect_h5(file_name, out_dir, num_procs):
         final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))
 
     final_h5_open.close()
+
+
+def collect_h5_borzoi(out_dir, num_procs, sad_stat) -> None:
+    h5_file = "scores.h5"
+
+    # count sequences
+    num_seqs = 0
+    for pi in range(num_procs):
+        # open job
+        job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
+        job_h5_open = h5py.File(job_h5_file, "r")
+        num_seqs += job_h5_open[sad_stat].shape[0]
+        seq_len = job_h5_open[sad_stat].shape[1]
+        num_targets = job_h5_open[sad_stat].shape[-1]
+        job_h5_open.close()
+
+    # initialize final h5
+    final_h5_file = "%s/%s" % (out_dir, h5_file)
+    final_h5_open = h5py.File(final_h5_file, "w")
+
+    # keep dict for string values
+    final_strings = {}
+
+    job0_h5_file = "%s/job0/%s" % (out_dir, h5_file)
+    job0_h5_open = h5py.File(job0_h5_file, "r")
+    for key in job0_h5_open.keys():
+        key_shape = list(job0_h5_open[key].shape)
+        key_shape[0] = num_seqs
+        key_shape = tuple(key_shape)
+        if job0_h5_open[key].dtype.char == "S":
+            final_strings[key] = []
+        else:
+            final_h5_open.create_dataset(
+                key, shape=key_shape, dtype=job0_h5_open[key].dtype
+            )
+
+    # set values
+    si = 0
+    for pi in range(num_procs):
+        # open job
+        job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5_file)
+        job_h5_open = h5py.File(job_h5_file, "r")
+
+        # append to final
+        for key in job_h5_open.keys():
+            job_seqs = job_h5_open[key].shape[0]
+            if job_h5_open[key].dtype.char == "S":
+                final_strings[key] += list(job_h5_open[key])
+            else:
+                final_h5_open[key][si : si + job_seqs] = job_h5_open[key]
+
+        job_h5_open.close()
+        si += job_seqs
+
+    # create final string datasets
+    for key in final_strings:
+        final_h5_open.create_dataset(key, data=np.array(final_strings[key], dtype="S"))
+
+    final_h5_open.close()
diff --git a/src/baskerville/helpers/utils.py b/src/baskerville/helpers/utils.py
new file mode 100644
index 0000000..d1549dc
--- /dev/null
+++ b/src/baskerville/helpers/utils.py
@@ -0,0 +1,20 @@
+import pickle
+
+
+def load_extra_options(options_pkl_file, options):
+    """
+    Args:
+        options_pkl_file: option file
+        options: existing options from command line
+    Returns:
+        options: updated options
+    """
+    options_pkl = open(options_pkl_file, "rb")
+    new_options = pickle.load(options_pkl)
+    new_option_attrs = vars(new_options)
+    # Assuming 'options' is the existing options object
+    # Update the existing options with the new attributes
+    for attr_name, attr_value in new_option_attrs.items():
+        setattr(options, attr_name, attr_value)
+    options_pkl.close()
+    return options
diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py
index f913060..cf31a5d 100755
--- a/src/baskerville/scripts/hound_snp.py
+++ b/src/baskerville/scripts/hound_snp.py
@@ -25,6 +25,7 @@
     upload_folder_gcs,
     download_rename_inputs,
 )
+from baskerville.helpers.utils import load_extra_options
 
 """
 hound_snp.py
@@ -211,25 +212,6 @@ def main():
         shutil.rmtree(temp_dir)  # clean up temp dir
 
-
-def load_extra_options(options_pkl_file, options):
-    """
-    Args:
-        options_pkl_file: option file
-        options: existing options from command line
-    Returns:
-        options: updated options
-    """
-    options_pkl = open(options_pkl_file, "rb")
-    new_options = pickle.load(options_pkl)
-    new_option_attrs = vars(new_options)
-    # Assuming 'options' is the existing options object
-    # Update the existing options with the new attributes
-    for attr_name, attr_value in new_option_attrs.items():
-        setattr(options, attr_name, attr_value)
-    options_pkl.close()
-    return options
-
-
 ################################################################################
 # __main__
 ################################################################################
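The load_extra_options helper moves out of hound_snp.py into baskerville.helpers.utils so other scripts can share it. A minimal sketch of how it is meant to be called; the optparse setup and the "options.pkl" filename are illustrative assumptions, not values from the PR:

```python
from optparse import OptionParser

from baskerville.helpers.utils import load_extra_options

# Hypothetical command-line options for a worker job.
parser = OptionParser()
parser.add_option("-o", dest="out_dir", default="snp_out")
options, args = parser.parse_args()

# Overlay attributes pickled by a controller process onto the parsed options;
# every attribute found in the pickle overwrites the matching command-line value.
options = load_extra_options("options.pkl", options)
```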
diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py
index a00c843..36377fd 100644
--- a/src/baskerville/seqnn.py
+++ b/src/baskerville/seqnn.py
@@ -19,7 +19,7 @@
 from natsort import natsorted
 import numpy as np
 import tensorflow as tf
-
+import gc
 from baskerville import blocks
 from baskerville import layers
 from baskerville import metrics
@@ -219,12 +219,13 @@ def build_embed(self, conv_layer_i: int, batch_norm: bool = True):
 
     def build_ensemble(self, ensemble_rc: bool = False, ensemble_shifts=[0]):
         """Build ensemble of models computing on augmented input sequences."""
-        if ensemble_rc or len(ensemble_shifts) > 1:
+        shift_bool = len(ensemble_shifts) > 1 or ensemble_shifts[0] != 0
+        if ensemble_rc or shift_bool:
             # sequence input
             sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence")
             sequences = [sequence]
 
-            if len(ensemble_shifts) > 1:
+            if shift_bool:
                 # generate shifted sequences
                 sequences = layers.EnsembleShift(ensemble_shifts)(sequences)
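The large hunk that follows adds a chunked, batched gradients() implementation (with optional SmoothGrad-style input perturbation) and keeps the previous implementation under gradients_orig / gradients_func_orig. A hedged call sketch, where the model object, shapes, and index values are placeholders rather than values from the PR:

```python
import numpy as np

# seqnn_model is assumed to be an already-restored baskerville.seqnn.SeqNN instance.
seq_1hot = np.zeros((1, seqnn_model.seq_length, 4), dtype="float32")  # placeholder one-hot
target_slice = np.array([[0, 1, 2]])  # tracks to average, one row per sequence
pos_slice = np.array([[4096, 4097, 4098]])  # output bins whose summed signal is differentiated

grads = seqnn_model.gradients(
    seq_1hot,
    target_slice=target_slice,
    pos_slice=pos_slice,
    chunk_size=8,   # sequences processed per chunk before garbage collection
    batch_size=1,   # sequences per gradient-tape call
    dtype="float16",
)
# grads has shape (num_seqs, seq_length, 4); with the default input_gate=True the
# per-position gradient sum is broadcast back onto the one-hot input.
```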
@@ -397,6 +398,401 @@ def get_conv_weights(self, conv_layer_i=0):
         return weights
 
     def gradients(
+        self,
+        seq_1hot,
+        head_i=None,
+        target_slice=None,
+        pos_slice=None,
+        pos_mask=None,
+        pos_slice_denom=None,
+        pos_mask_denom=None,
+        chunk_size=None,
+        batch_size=1,
+        track_scale=1.0,
+        track_transform=1.0,
+        clip_soft=None,
+        pseudo_count=0.0,
+        no_transform=False,
+        use_mean=False,
+        use_ratio=False,
+        use_logodds=False,
+        subtract_avg=True,
+        input_gate=True,
+        smooth_grad=False,
+        n_samples=5,
+        sample_prob=0.875,
+        dtype="float16",
+    ):
+        """Compute input gradients for sequences (GPU-friendly)."""
+
+        # start time
+        t0 = time.time()
+
+        # choose model
+        if self.ensemble is not None:
+            model = self.ensemble
+        elif head_i is not None:
+            model = self.models[head_i]
+        else:
+            model = self.model
+
+        # verify tensor shape(s)
+        seq_1hot = seq_1hot.astype("float32")
+        target_slice = np.array(target_slice).astype("int32")
+        pos_slice = np.array(pos_slice).astype("int32")
+
+        # convert constants to tf tensors
+        track_scale = tf.constant(track_scale, dtype=tf.float32)
+        track_transform = tf.constant(track_transform, dtype=tf.float32)
+        if clip_soft is not None:
+            clip_soft = tf.constant(clip_soft, dtype=tf.float32)
+        pseudo_count = tf.constant(pseudo_count, dtype=tf.float32)
+
+        if pos_mask is not None:
+            pos_mask = np.array(pos_mask).astype("float32")
+
+        if use_ratio and pos_slice_denom is not None:
+            pos_slice_denom = np.array(pos_slice_denom).astype("int32")
+
+            if pos_mask_denom is not None:
+                pos_mask_denom = np.array(pos_mask_denom).astype("float32")
+
+        if len(seq_1hot.shape) < 3:
+            seq_1hot = seq_1hot[None, ...]
+
+        if len(target_slice.shape) < 2:
+            target_slice = target_slice[None, ...]
+
+        if len(pos_slice.shape) < 2:
+            pos_slice = pos_slice[None, ...]
+
+        if pos_mask is not None and len(pos_mask.shape) < 2:
+            pos_mask = pos_mask[None, ...]
+
+        if use_ratio and pos_slice_denom is not None and len(pos_slice_denom.shape) < 2:
+            pos_slice_denom = pos_slice_denom[None, ...]
+
+            if pos_mask_denom is not None and len(pos_mask_denom.shape) < 2:
+                pos_mask_denom = pos_mask_denom[None, ...]
+
+        # chunk parameters
+        num_chunks = 1
+        if chunk_size is None:
+            chunk_size = seq_1hot.shape[0]
+        else:
+            num_chunks = int(np.ceil(seq_1hot.shape[0] / chunk_size))
+
+        # loop over chunks
+        grad_chunks = []
+        for ci in range(num_chunks):
+            # collect chunk
+            seq_1hot_chunk = seq_1hot[ci * chunk_size : (ci + 1) * chunk_size, ...]
+            target_slice_chunk = target_slice[
+                ci * chunk_size : (ci + 1) * chunk_size, ...
+            ]
+            pos_slice_chunk = pos_slice[ci * chunk_size : (ci + 1) * chunk_size, ...]
+
+            pos_mask_chunk = None
+            if pos_mask is not None:
+                pos_mask_chunk = pos_mask[ci * chunk_size : (ci + 1) * chunk_size, ...]
+
+            pos_slice_denom_chunk = None
+            pos_mask_denom_chunk = None
+            if use_ratio and pos_slice_denom is not None:
+                pos_slice_denom_chunk = pos_slice_denom[
+                    ci * chunk_size : (ci + 1) * chunk_size, ...
+                ]
+
+                if pos_mask_denom is not None:
+                    pos_mask_denom_chunk = pos_mask_denom[
+                        ci * chunk_size : (ci + 1) * chunk_size, ...
+                    ]
+
+            actual_chunk_size = seq_1hot_chunk.shape[0]
+
+            # sample noisy (discrete) perturbations of the input pattern chunk
+            if smooth_grad:
+                seq_1hot_chunk_corrupted = np.repeat(
+                    np.copy(seq_1hot_chunk), n_samples, axis=0
+                )
+
+                for example_ix in range(seq_1hot_chunk.shape[0]):
+                    for sample_ix in range(n_samples):
+                        corrupt_index = np.nonzero(
+                            np.random.rand(seq_1hot_chunk.shape[1]) >= sample_prob
+                        )[0]
+
+                        rand_nt_index = np.random.choice(
+                            [0, 1, 2, 3], size=(corrupt_index.shape[0],)
+                        )
+
+                        seq_1hot_chunk_corrupted[
+                            example_ix * n_samples + sample_ix, corrupt_index, :
+                        ] = 0.0
+                        seq_1hot_chunk_corrupted[
+                            example_ix * n_samples + sample_ix,
+                            corrupt_index,
+                            rand_nt_index,
+                        ] = 1.0
+
+                seq_1hot_chunk = seq_1hot_chunk_corrupted
+                target_slice_chunk = np.repeat(
+                    np.copy(target_slice_chunk), n_samples, axis=0
+                )
+                pos_slice_chunk = np.repeat(np.copy(pos_slice_chunk), n_samples, axis=0)
+
+                if pos_mask is not None:
+                    pos_mask_chunk = np.repeat(
+                        np.copy(pos_mask_chunk), n_samples, axis=0
+                    )
+
+                if use_ratio and pos_slice_denom is not None:
+                    pos_slice_denom_chunk = np.repeat(
+                        np.copy(pos_slice_denom_chunk), n_samples, axis=0
+                    )
+
+                    if pos_mask_denom is not None:
+                        pos_mask_denom_chunk = np.repeat(
+                            np.copy(pos_mask_denom_chunk), n_samples, axis=0
+                        )
+
+            # convert to tf tensors
+            seq_1hot_chunk = tf.convert_to_tensor(seq_1hot_chunk, dtype=tf.float32)
+            target_slice_chunk = tf.convert_to_tensor(
+                target_slice_chunk, dtype=tf.int32
+            )
+            pos_slice_chunk = tf.convert_to_tensor(pos_slice_chunk, dtype=tf.int32)
+
+            if pos_mask is not None:
+                pos_mask_chunk = tf.convert_to_tensor(pos_mask_chunk, dtype=tf.float32)
+
+            if use_ratio and pos_slice_denom is not None:
+                pos_slice_denom_chunk = tf.convert_to_tensor(
+                    pos_slice_denom_chunk, dtype=tf.int32
+                )
+
+                if pos_mask_denom is not None:
+                    pos_mask_denom_chunk = tf.convert_to_tensor(
+                        pos_mask_denom_chunk, dtype=tf.float32
+                    )
+
+            # batching parameters
+            num_batches = int(
+                np.ceil(
+                    actual_chunk_size * (n_samples if smooth_grad else 1) / batch_size
+                )
+            )
+
+            # loop over batches
+            grad_batches = []
+            for bi in range(num_batches):
+                # collect batch
+                seq_1hot_batch = seq_1hot_chunk[
+                    bi * batch_size : (bi + 1) * batch_size, ...
+                ]
+                target_slice_batch = target_slice_chunk[
+                    bi * batch_size : (bi + 1) * batch_size, ...
+                ]
+                pos_slice_batch = pos_slice_chunk[
+                    bi * batch_size : (bi + 1) * batch_size, ...
+                ]
+
+                pos_mask_batch = None
+                if pos_mask is not None:
+                    pos_mask_batch = pos_mask_chunk[
+                        bi * batch_size : (bi + 1) * batch_size, ...
+                    ]
+
+                pos_slice_denom_batch = None
+                pos_mask_denom_batch = None
+                if use_ratio and pos_slice_denom is not None:
+                    pos_slice_denom_batch = pos_slice_denom_chunk[
+                        bi * batch_size : (bi + 1) * batch_size, ...
+                    ]
+
+                    if pos_mask_denom is not None:
+                        pos_mask_denom_batch = pos_mask_denom_chunk[
+                            bi * batch_size : (bi + 1) * batch_size, ...
+                        ]
+
+                grad_batch = (
+                    self.gradients_func(
+                        model,
+                        seq_1hot_batch,
+                        target_slice_batch,
+                        pos_slice_batch,
+                        pos_mask_batch,
+                        pos_slice_denom_batch,
+                        pos_mask_denom_batch,
+                        track_scale,
+                        track_transform,
+                        clip_soft,
+                        pseudo_count,
+                        no_transform,
+                        use_mean,
+                        use_ratio,
+                        use_logodds,
+                        subtract_avg,
+                        input_gate,
+                    )
+                    .numpy()
+                    .astype(dtype)
+                )
+
+                grad_batches.append(grad_batch)
+
+            # concat gradient batches
+            grads = np.concatenate(grad_batches, axis=0)
+
+            # aggregate noisy gradient perturbations
+            if smooth_grad:
+                grads_smoothed = np.zeros(
+                    (grads.shape[0] // n_samples, grads.shape[1], grads.shape[2]),
+                    dtype="float32",
+                )
+
+                for example_ix in range(grads_smoothed.shape[0]):
+                    for sample_ix in range(n_samples):
+                        grads_smoothed[example_ix, ...] += grads[
+                            example_ix * n_samples + sample_ix, ...
+                        ]
+
+                grads = grads_smoothed / float(n_samples)
+                grads = grads.astype(dtype)
+
+            grad_chunks.append(grads)
+
+            # collect garbage
+            gc.collect()
+
+        # concat gradient chunks
+        grads = np.concatenate(grad_chunks, axis=0)
+
+        # aggregate and broadcast to original input pattern
+        if input_gate:
+            grads = np.sum(grads, axis=-1, keepdims=True) * seq_1hot
+
+        print("Completed gradient computation in %ds" % (time.time() - t0))
+
+        return grads
+
+    @tf.function
+    def gradients_func(
+        self,
+        model,
+        seq_1hot,
+        target_slice,
+        pos_slice,
+        pos_mask=None,
+        pos_slice_denom=None,
+        pos_mask_denom=True,
+        track_scale=1.0,
+        track_transform=1.0,
+        clip_soft=None,
+        pseudo_count=0.0,
+        no_transform=False,
+        use_mean=False,
+        use_ratio=False,
+        use_logodds=False,
+        subtract_avg=True,
+        input_gate=True,
+    ):
+        with tf.GradientTape() as tape:
+            tape.watch(seq_1hot)
+
+            # predict
+            preds = tf.gather(
+                model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1
+            )
+
+            if not no_transform:
+                # undo scale
+                preds = preds / track_scale
+
+                # undo soft_clip
+                if clip_soft is not None:
+                    preds = tf.where(
+                        preds > clip_soft, (preds - clip_soft) ** 2 + clip_soft, preds
+                    )
+
+                # undo sqrt
+                preds = preds ** (1.0 / track_transform)
+
+            # aggregate over tracks (average)
+            preds = tf.reduce_mean(preds, axis=-1)
+
+            # slice specified positions
+            preds_slice = tf.gather(preds, pos_slice, axis=-1, batch_dims=1)
+            if pos_mask is not None:
+                preds_slice = preds_slice * pos_mask
+
+            # slice denominator positions
+            if use_ratio and pos_slice_denom is not None:
+                preds_slice_denom = tf.gather(
+                    preds, pos_slice_denom, axis=-1, batch_dims=1
+                )
+                if pos_mask_denom is not None:
+                    preds_slice_denom = preds_slice_denom * pos_mask_denom
+
+            # aggregate over positions
+            if not use_mean:
+                preds_agg = tf.reduce_sum(preds_slice, axis=-1)
+                if use_ratio and pos_slice_denom is not None:
+                    preds_agg_denom = tf.reduce_sum(preds_slice_denom, axis=-1)
+            else:
+                if pos_mask is not None:
+                    preds_agg = tf.reduce_sum(preds_slice, axis=-1) / tf.reduce_sum(
+                        pos_mask, axis=-1
+                    )
+                else:
+                    preds_agg = tf.reduce_mean(preds_slice, axis=-1)
+
+                if use_ratio and pos_slice_denom is not None:
+                    if pos_mask_denom is not None:
+                        preds_agg_denom = tf.reduce_sum(
+                            preds_slice_denom, axis=-1
+                        ) / tf.reduce_sum(pos_mask_denom, axis=-1)
+                    else:
+                        preds_agg_denom = tf.reduce_mean(preds_slice_denom, axis=-1)
+
+            # compute final statistic to take gradient of
+            if no_transform:
+                score_ratios = preds_agg
+            elif not use_ratio:
+                score_ratios = tf.math.log(preds_agg + pseudo_count + 1e-6)
+            else:
+                if not use_logodds:
+                    score_ratios = tf.math.log(
+                        (preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count)
+                        + 1e-6
+                    )
+                else:
+                    score_ratios = tf.math.log(
+                        ((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count))
+                        / (
+                            1.0
+                            - (
+                                (preds_agg + pseudo_count)
+                                / (preds_agg_denom + pseudo_count)
+                            )
+                        )
+                        + 1e-6
+                    )
+
+        # compute gradient
+        grads = tape.gradient(score_ratios, seq_1hot)
+
+        # zero mean each position
+        if subtract_avg:
+            grads = grads - tf.reduce_mean(grads, axis=-1, keepdims=True)
+
+        # multiply by input
+        if input_gate:
+            grads = grads * seq_1hot
+
+        return grads
+
+    def gradients_orig(
         self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype="float16"
     ):
         """Compute input gradients for each task.
@@ -463,7 +859,7 @@ def gradients(
         return grads
 
     @tf.function
-    def gradients_func(self, model, seq_1hot, pos_slice):
+    def gradients_func_orig(self, model, seq_1hot, pos_slice):
         """Compute input gradients for each task.
 
         Args: