From 6573c0c84e10ce1cd311b2bd3ef67dcd2136cf49 Mon Sep 17 00:00:00 2001 From: Johannes Linder Date: Fri, 20 Sep 2024 14:04:40 -0700 Subject: [PATCH 1/7] Revision updates (untransform_old flag, removed smoothgrad) --- src/baskerville/scripts/hound_data.py | 19 +- src/baskerville/scripts/hound_data_read.py | 136 ++++++++---- src/baskerville/seqnn.py | 230 ++++----------------- 3 files changed, 151 insertions(+), 234 deletions(-) diff --git a/src/baskerville/scripts/hound_data.py b/src/baskerville/scripts/hound_data.py index a95284a..8f1d720 100755 --- a/src/baskerville/scripts/hound_data.py +++ b/src/baskerville/scripts/hound_data.py @@ -81,7 +81,9 @@ def main(): help="Generate cross fold split [Default: %default]", ) parser.add_option( - "-g", dest="gaps_file", help="Genome assembly gaps BED [Default: %default]" + "-g", + dest="gaps_file", + help="Genome assembly gaps BED [Default: %default]" ) parser.add_option( "-i", @@ -194,7 +196,11 @@ def main(): type="str", help="Proportion of the data for testing [Default: %default]", ) - parser.add_option("-u", dest="umap_bed", help="Unmappable regions in BED format") + parser.add_option( + "-u", + dest="umap_bed", + help="Unmappable regions in BED format" + ) parser.add_option( "--umap_t", dest="umap_t", @@ -230,6 +236,13 @@ def main(): type="str", help="Proportion of the data for validation [Default: %default]", ) + parser.add_option( + "--transform_old", + dest="transform_old", + default=False, + action="store_true", + help="Apply old target transforms [Default: %default]", + ) (options, args) = parser.parse_args() if len(args) != 2: @@ -493,6 +506,8 @@ def main(): cmd += " -b %s" % options.blacklist_bed if options.interp_nan: cmd += " -i" + if options.transform_old: + cmd += " --transform_old" cmd += " %s" % genome_cov_file cmd += " %s" % seqs_bed_file cmd += " %s" % seqs_cov_file diff --git a/src/baskerville/scripts/hound_data_read.py b/src/baskerville/scripts/hound_data_read.py index 5b6ec35..aad44b9 100755 --- a/src/baskerville/scripts/hound_data_read.py +++ b/src/baskerville/scripts/hound_data_read.py @@ -108,6 +108,13 @@ def main(): type="int", help="Average pooling width [Default: %default]", ) + parser.add_option( + "--transform_old", + dest="transform_old", + default=False, + action="store_true", + help="Apply old target transforms [Default: %default]", + ) (options, args) = parser.parse_args() if len(args) != 3: @@ -180,49 +187,92 @@ def main(): # crop if options.crop_bp > 0: seq_cov_nt = seq_cov_nt[options.crop_bp : -options.crop_bp] - - # scale - seq_cov_nt = options.scale * seq_cov_nt - - # sum pool - seq_cov = seq_cov_nt.reshape(target_length, options.pool_width) - if options.sum_stat == "sum": - seq_cov = seq_cov.sum(axis=1, dtype="float32") - elif options.sum_stat == "sum_sqrt": - seq_cov = seq_cov.sum(axis=1, dtype="float32") - seq_cov = -1 + np.sqrt(1 + seq_cov) - elif options.sum_stat == "sum_exp75": - seq_cov = seq_cov.sum(axis=1, dtype="float32") - seq_cov = -1 + (1 + seq_cov) ** 0.75 - elif options.sum_stat in ["mean", "avg"]: - seq_cov = seq_cov.mean(axis=1, dtype="float32") - elif options.sum_stat in ["mean_sqrt", "avg_sqrt"]: - seq_cov = seq_cov.mean(axis=1, dtype="float32") - seq_cov = -1 + np.sqrt(1 + seq_cov) - elif options.sum_stat == "median": - seq_cov = seq_cov.median(axis=1) - elif options.sum_stat == "max": - seq_cov = seq_cov.max(axis=1) - elif options.sum_stat == "peak": - seq_cov = seq_cov.mean(axis=1, dtype="float32") - seq_cov = np.clip(np.sqrt(seq_cov * 4), 0, 1) - else: - print( - 'ERROR: Unrecognized 
summary statistic "%s".' % options.sum_stat, - file=sys.stderr, - ) - exit(1) - - # clip - if options.clip_soft is not None: - clip_mask = seq_cov > options.clip_soft - seq_cov[clip_mask] = ( - options.clip_soft - - 1 - + np.sqrt(seq_cov[clip_mask] - options.clip_soft + 1) - ) - if options.clip is not None: - seq_cov = np.clip(seq_cov, -options.clip, options.clip) + + #apply original transform (from borzoi manuscript) + if options.transform_old: + # sum pool + seq_cov = seq_cov_nt.reshape(target_length, options.pool_width) + if options.sum_stat == 'sum': + seq_cov = seq_cov.sum(axis=1, dtype='float32') + elif options.sum_stat == 'sum_sqrt': + seq_cov = seq_cov.sum(axis=1, dtype='float32') + seq_cov = seq_cov**0.75 + elif options.sum_stat in ['mean', 'avg']: + seq_cov = seq_cov.mean(axis=1, dtype='float32') + elif options.sum_stat in ['mean_sqrt', 'avg_sqrt']: + seq_cov = seq_cov.mean(axis=1, dtype='float32') + seq_cov = seq_cov**0.75 + elif options.sum_stat == 'median': + seq_cov = seq_cov.median(axis=1) + elif options.sum_stat == 'max': + seq_cov = seq_cov.max(axis=1) + elif options.sum_stat == 'peak': + seq_cov = seq_cov.mean(axis=1, dtype='float32') + seq_cov = np.clip(np.sqrt(seq_cov*4), 0, 1) + else: + print( + 'ERROR: Unrecognized summary statistic "%s".' % options.sum_stat, + file=sys.stderr + ) + exit(1) + + # clip + if options.clip_soft is not None: + clip_mask = seq_cov > options.clip_soft + seq_cov[clip_mask] = ( + options.clip_soft + + np.sqrt(seq_cov[clip_mask] - options.clip_soft) + ) + if options.clip is not None: + seq_cov = np.clip(seq_cov, -options.clip, options.clip) + + # scale + seq_cov = options.scale * seq_cov + + else : #apply new (updated) transform + + # scale + seq_cov_nt = options.scale * seq_cov_nt + + # sum pool + seq_cov = seq_cov_nt.reshape(target_length, options.pool_width) + if options.sum_stat == "sum": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + elif options.sum_stat == "sum_sqrt": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + seq_cov = -1 + np.sqrt(1 + seq_cov) + elif options.sum_stat == "sum_exp75": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + seq_cov = -1 + (1 + seq_cov) ** 0.75 + elif options.sum_stat in ["mean", "avg"]: + seq_cov = seq_cov.mean(axis=1, dtype="float32") + elif options.sum_stat in ["mean_sqrt", "avg_sqrt"]: + seq_cov = seq_cov.mean(axis=1, dtype="float32") + seq_cov = -1 + np.sqrt(1 + seq_cov) + elif options.sum_stat == "median": + seq_cov = seq_cov.median(axis=1) + elif options.sum_stat == "max": + seq_cov = seq_cov.max(axis=1) + elif options.sum_stat == "peak": + seq_cov = seq_cov.mean(axis=1, dtype="float32") + seq_cov = np.clip(np.sqrt(seq_cov * 4), 0, 1) + else: + print( + 'ERROR: Unrecognized summary statistic "%s".' 
% options.sum_stat, + file=sys.stderr, + ) + exit(1) + + # clip + if options.clip_soft is not None: + clip_mask = seq_cov > options.clip_soft + seq_cov[clip_mask] = ( + options.clip_soft + - 1 + + np.sqrt(seq_cov[clip_mask] - options.clip_soft + 1) + ) + if options.clip is not None: + seq_cov = np.clip(seq_cov, -options.clip, options.clip) # clip float16 min/max seq_cov = np.clip(seq_cov, np.finfo(np.float16).min, np.finfo(np.float16).max) diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 4a46772..c2c83b5 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -410,22 +410,20 @@ def gradients( pos_mask_denom=None, chunk_size=None, batch_size=1, - track_scale=1.0, - track_transform=1.0, + track_scale=1., + track_transform=1., clip_soft=None, - pseudo_count=0.0, - no_transform=False, + pseudo_count=0., + untransform_old=False, + no_untransform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True, - smooth_grad=False, - n_samples=5, - sample_prob=0.875, dtype="float16", ): - """Compute input gradients for sequences (GPU-friendly).""" + """Compute input gradients for sequences.""" # start time t0 = time.time() @@ -512,52 +510,6 @@ def gradients( actual_chunk_size = seq_1hot_chunk.shape[0] - # sample noisy (discrete) perturbations of the input pattern chunk - if smooth_grad: - seq_1hot_chunk_corrupted = np.repeat( - np.copy(seq_1hot_chunk), n_samples, axis=0 - ) - - for example_ix in range(seq_1hot_chunk.shape[0]): - for sample_ix in range(n_samples): - corrupt_index = np.nonzero( - np.random.rand(seq_1hot_chunk.shape[1]) >= sample_prob - )[0] - - rand_nt_index = np.random.choice( - [0, 1, 2, 3], size=(corrupt_index.shape[0],) - ) - - seq_1hot_chunk_corrupted[ - example_ix * n_samples + sample_ix, corrupt_index, : - ] = 0.0 - seq_1hot_chunk_corrupted[ - example_ix * n_samples + sample_ix, - corrupt_index, - rand_nt_index, - ] = 1.0 - - seq_1hot_chunk = seq_1hot_chunk_corrupted - target_slice_chunk = np.repeat( - np.copy(target_slice_chunk), n_samples, axis=0 - ) - pos_slice_chunk = np.repeat(np.copy(pos_slice_chunk), n_samples, axis=0) - - if pos_mask is not None: - pos_mask_chunk = np.repeat( - np.copy(pos_mask_chunk), n_samples, axis=0 - ) - - if use_ratio and pos_slice_denom is not None: - pos_slice_denom_chunk = np.repeat( - np.copy(pos_slice_denom_chunk), n_samples, axis=0 - ) - - if pos_mask_denom is not None: - pos_mask_denom_chunk = np.repeat( - np.copy(pos_mask_denom_chunk), n_samples, axis=0 - ) - # convert to tf tensors seq_1hot_chunk = tf.convert_to_tensor(seq_1hot_chunk, dtype=tf.float32) target_slice_chunk = tf.convert_to_tensor( @@ -581,7 +533,7 @@ def gradients( # batching parameters num_batches = int( np.ceil( - actual_chunk_size * (n_samples if smooth_grad else 1) / batch_size + actual_chunk_size / batch_size ) ) @@ -630,7 +582,8 @@ def gradients( track_transform, clip_soft, pseudo_count, - no_transform, + untransform_old, + no_untransform, use_mean, use_ratio, use_logodds, @@ -646,22 +599,6 @@ def gradients( # concat gradient batches grads = np.concatenate(grad_batches, axis=0) - # aggregate noisy gradient perturbations - if smooth_grad: - grads_smoothed = np.zeros( - (grads.shape[0] // n_samples, grads.shape[1], grads.shape[2]), - dtype="float32", - ) - - for example_ix in range(grads_smoothed.shape[0]): - for sample_ix in range(n_samples): - grads_smoothed[example_ix, ...] += grads[ - example_ix * n_samples + sample_ix, ... 
- ] - - grads = grads_smoothed / float(n_samples) - grads = grads.astype(dtype) - grad_chunks.append(grads) # collect garbage @@ -688,17 +625,19 @@ def gradients_func( pos_mask=None, pos_slice_denom=None, pos_mask_denom=True, - track_scale=1.0, - track_transform=1.0, + track_scale=1., + track_transform=1., clip_soft=None, - pseudo_count=0.0, - no_transform=False, + pseudo_count=0., + untransform_old=False, + no_untransform=False, use_mean=False, use_ratio=False, use_logodds=False, subtract_avg=True, input_gate=True, ): + """Compute gradient of the model prediction with respect to the input sequence.""" with tf.GradientTape() as tape: tape.watch(seq_1hot) @@ -707,18 +646,32 @@ def gradients_func( model(seq_1hot, training=False), target_slice, axis=-1, batch_dims=1 ) - if not no_transform: - # undo scale - preds = preds / track_scale + if not no_untransform: + if untransform_old: + # undo scale + preds = preds / track_scale - # undo soft_clip - if clip_soft is not None: - preds = tf.where( - preds > clip_soft, (preds - clip_soft) ** 2 + clip_soft, preds - ) + # undo clip_soft + if clip_soft is not None: + preds = tf.where( + preds > clip_soft, (preds - clip_soft) ** 2 + clip_soft, preds + ) - # undo sqrt - preds = preds ** (1.0 / track_transform) + # undo sqrt + preds = preds ** (1. / track_transform) + else: + # undo clip_soft + if clip_soft is not None: + preds = tf.where( + preds > clip_soft, (preds - clip_soft + 1) ** 2 + clip_soft - 1, preds + ) + + # undo sqrt + preds = -1 + (preds + 1) ** (1. / track_transform) + + # scale + preds = preds / track_scale + # aggregate over tracks (average) preds = tf.reduce_mean(preds, axis=-1) @@ -758,7 +711,7 @@ def gradients_func( preds_agg_denom = tf.reduce_mean(preds_slice_denom, axis=-1) # compute final statistic to take gradient of - if no_transform: + if no_untransform: score_ratios = preds_agg elif not use_ratio: score_ratios = tf.math.log(preds_agg + pseudo_count + 1e-6) @@ -772,7 +725,7 @@ def gradients_func( score_ratios = tf.math.log( ((preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count)) / ( - 1.0 + 1. - ( (preds_agg + pseudo_count) / (preds_agg_denom + pseudo_count) @@ -794,107 +747,6 @@ def gradients_func( return grads - def gradients_orig( - self, seq_1hot, head_i=None, pos_slice=None, batch_size=8, dtype="float16" - ): - """Compute input gradients for each task. - - Args: - seq_1hot (np.array): 1-hot encoded sequence. - head_i (int): Model head index. - pos_slice ([int]): Sequence positions to consider. - batch_size (int): number of tasks to compute gradients for at once. - dtype: Returned data type. - Returns: - Gradients for each task. 
- """ - # choose model - if self.ensemble is not None: - model = self.ensemble - elif head_i is not None: - model = self.models[head_i] - else: - model = self.model - - # verify tensor shape - seq_1hot = seq_1hot.astype("float32") - seq_1hot = tf.convert_to_tensor(seq_1hot, dtype=tf.float32) - if len(seq_1hot.shape) < 3: - seq_1hot = tf.expand_dims(seq_1hot, axis=0) - - # batching parameters - num_targets = model.output_shape[-1] - num_batches = int(np.ceil(num_targets / batch_size)) - - ti_start = 0 - grads = [] - for bi in range(num_batches): - # sequence input - sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence") - - # predict - predictions = model(sequence) - - # slice - ti_end = min(num_targets, ti_start + batch_size) - target_slice = np.arange(ti_start, ti_end) - predictions_slice = tf.gather(predictions, target_slice, axis=-1) - - # replace model - model_batch = tf.keras.Model(inputs=sequence, outputs=predictions_slice) - - # compute gradients - t0 = time.time() - grads_batch = self.gradients_func(model_batch, seq_1hot, pos_slice) - print("Batch gradient computation in %ds" % (time.time() - t0)) - - # convert numpy dtype - grads_batch = grads_batch.numpy().astype(dtype) - grads.append(grads_batch) - - # next batch - ti_start += batch_size - - # concat target batches - grads = np.concatenate(grads, axis=-1) - - return grads - - @tf.function - def gradients_func_orig(self, model, seq_1hot, pos_slice): - """Compute input gradients for each task. - - Args: - model (tf.keras.Model): Model to compute gradients for. - seq_1hot (tf.Tensor): 1-hot encoded sequence. - pos_slice ([int]): Sequence positions to consider. - - Returns: - grads (tf.Tensor): Gradients for each task. - """ - with tf.GradientTape() as tape: - tape.watch(seq_1hot) - - # predict - preds = model(seq_1hot, training=False) - - if pos_slice is not None: - # slice specified positions - preds = tf.gather(preds, pos_slice, axis=-2) - - # sum across positions - preds = tf.reduce_sum(preds, axis=-2) - - # compute jacboian - grads = tape.jacobian(preds, seq_1hot) - grads = tf.squeeze(grads) - grads = tf.transpose(grads, [1, 2, 0]) - - # zero mean each position - grads = grads - tf.reduce_mean(grads, axis=-2, keepdims=True) - - return grads - def num_targets(self, head_i=None): """Return number of targets.""" if head_i is None: From 5203ce54250ed79e9b2a23d7f8e2791675945285 Mon Sep 17 00:00:00 2001 From: Johannes Linder Date: Mon, 30 Sep 2024 09:24:24 -0700 Subject: [PATCH 2/7] Added missing stream script and pyfaidx dependency. --- pyproject.toml | 1 + src/baskerville/stream.py | 211 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 src/baskerville/stream.py diff --git a/pyproject.toml b/pyproject.toml index a8f08d4..c5e83ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "tabulate~=0.8.10", "tensorflow~=2.15.0", "tqdm~=4.65.0", + "pyfaidx~=0.7.1", ] [project.optional-dependencies] diff --git a/src/baskerville/stream.py b/src/baskerville/stream.py new file mode 100644 index 0000000..79a235c --- /dev/null +++ b/src/baskerville/stream.py @@ -0,0 +1,211 @@ +# Copyright 2017 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from __future__ import print_function +import pdb + +import numpy as np +import tensorflow as tf + +from baskerville import dna + + +class PredStreamGen: + """ Interface to acquire predictions via a buffered stream mechanism + rather than getting them all at once and using excessive memory. + Accepts generator and constructs stream batches from it. """ + def __init__(self, model, seqs_gen, batch_size, stream_seqs=32, verbose=False): + self.model = model + self.seqs_gen = seqs_gen + self.stream_seqs = stream_seqs + self.batch_size = batch_size + self.verbose = verbose + + self.stream_start = 0 + self.stream_end = 0 + self.stream_preds = [] + + + def __getitem__(self, i): + # acquire predictions, if needed + if i >= self.stream_end: + # update start + self.stream_start = self.stream_end + + # predict + del self.stream_preds + self.stream_preds = self.model.predict(self.make_dataset()) + + # update end + self.stream_end = self.stream_start + self.stream_preds.shape[0] + + if self.verbose: + print('Predicting %d-%d' % (self.stream_start, self.stream_end), flush=True) + + return self.stream_preds[i - self.stream_start] + + def make_dataset(self): + """ Construct Dataset object for this stream chunk. """ + seqs_1hot = [] + stream_end = self.stream_start+self.stream_seqs + for si in range(self.stream_start, stream_end): + try: + seqs_1hot.append(self.seqs_gen.__next__()) + except StopIteration: + continue + + seqs_1hot = np.array(seqs_1hot) + + dataset = tf.data.Dataset.from_tensor_slices((seqs_1hot,)) + dataset = dataset.batch(self.batch_size) + return dataset + + +class PredStreamIter: + """ Interface to acquire predictions via a buffered stream mechanism + rather than getting them all at once and using excessive memory. + Accepts iterator and constructs stream batches from it. + [I don't recall whether I've ever gotten this one working.""" + def __init__(self, model, dataset_iter, stream_seqs=128, verbose=False): + self.model = model + self.dataset_iter = dataset_iter + self.stream_seqs = stream_seqs + self.verbose = verbose + + self.stream_start = 0 + self.stream_end = 0 + + + def __getitem__(self, i): + # acquire predictions, if needed + if i >= self.stream_end: + # update start + self.stream_start = self.stream_end + + if self.verbose: + print('Predicting from %d' % self.stream_start, flush=True) + + # predict + self.stream_preds = self.model.predict(self.fetch_batch()) + + # update end + self.stream_end = self.stream_start + self.stream_preds.shape[0] + + return self.stream_preds[i - self.stream_start] + + def fetch_batch(self): + """Fetch a batch of data from the dataset iterator.""" + x = [next(self.dataset_iter)] + while x[-1] and len(x) < self.stream_seqs: + x.append(next(self.dataset_iter)) + return x + + +class PredStreamSonnet: + """ Interface to acquire predictions via a buffered stream mechanism + rather than getting them all at once and using excessive memory. + Accepts generator and constructs stream batches from it. 
""" + def __init__(self, model, seqs_gen, batch_size=4, stream_size=32, + rc=False, shifts=[0], slice_center=None, + species='human', return_augm=False, verbose=False): + self.model = model + self.seqs_gen = seqs_gen + self.batch_size = batch_size + self.stream_size = stream_size + self.rc = rc + self.shifts = shifts + self.ensembled = len(self.shifts) + int(self.rc)*len(self.shifts) + self.slice_center = slice_center + self.species = species + self.verbose = verbose + self.return_augm = return_augm + + self.stream_start = 0 + self.stream_end = 0 + + + def __getitem__(self, i): + # acquire predictions, if needed + if i >= self.stream_end: + # update start + self.stream_start = self.stream_end + + if self.verbose: + print('Predicting from %d' % self.stream_start, flush=True) + + # get next sequences + seqs_1hot = self.next_seqs() + + # predict stream + stream_preds = [] + si = 0 + while si < seqs_1hot.shape[0]: + spreds = self.model.predict_on_batch(seqs_1hot[si:si+self.batch_size]) + spreds = spreds[self.species].numpy() + stream_preds.append(spreds) + si += self.batch_size + stream_preds = np.concatenate(stream_preds, axis=0) + + # slice center + if self.slice_center is not None: + _, seq_len, _ = stream_preds.shape + mid_pos = seq_len // 2 + slice_start = mid_pos - self.slice_center//2 + slice_end = slice_start + self.slice_center + stream_preds = stream_preds[:,slice_start:slice_end,:] + + # reshape to expose augmentations + ens_seqs, seq_len, num_targets = stream_preds.shape + num_seqs = ens_seqs // self.ensembled + stream_preds = np.reshape(stream_preds, + (num_seqs, self.ensembled, seq_len, num_targets)) + + if self.return_augm: + # move augmentations to the back + self.stream_preds = np.transpose(stream_preds, [0,2,3,1]) + else: + # average augmentations + self.stream_preds = stream_preds.mean(axis=1) + + # update end + self.stream_end = self.stream_start + self.stream_preds.shape[0] + + return self.stream_preds[i - self.stream_start] + + def next_seqs(self): + """ Construct array of sequences for this stream chunk. """ + + # extract next sequences from generator + seqs_1hot = [] + stream_end = self.stream_start+self.stream_size + for si in range(self.stream_start, stream_end): + try: + seqs_1hot.append(self.seqs_gen.__next__()) + except StopIteration: + continue + + # initialize ensemble + seqs_1hot_ens = [] + + # add rc/shifts + for seq_1hot in seqs_1hot: + for shift in self.shifts: + seq_1hot_aug = dna_io.hot1_augment(seq_1hot, shift=shift) + seqs_1hot_ens.append(seq_1hot_aug) + if self.rc: + seq_1hot_aug = dna_io.hot1_rc(seq_1hot_aug) + seqs_1hot_ens.append(seq_1hot_aug) + + seqs_1hot_ens = np.array(seqs_1hot_ens, dtype='float32') + return seqs_1hot_ens From 8c872fba433d6699f1edc4e9923b5ea4ea355792 Mon Sep 17 00:00:00 2001 From: Johannes Linder Date: Mon, 30 Sep 2024 15:19:30 -0700 Subject: [PATCH 3/7] Updated environment variables. --- env_vars.sh | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100755 env_vars.sh diff --git a/env_vars.sh b/env_vars.sh new file mode 100755 index 0000000..6a78108 --- /dev/null +++ b/env_vars.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# set these variables before running the script +LOCAL_BASKERVILLE_PATH="/home/jlinder/baskerville" +LOCAL_USER="jlinder" + +# create env_vars sh scripts in local conda env +mkdir -p "$CONDA_PREFIX/etc/conda/activate.d" +mkdir -p "$CONDA_PREFIX/etc/conda/deactivate.d" + +file_vars_act="$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh" +if ! 
[ -e $file_vars_act ]; then
+    echo '#!/bin/sh' > $file_vars_act
+fi
+
+file_vars_deact="$CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh"
+if ! [ -e $file_vars_deact ]; then
+    echo '#!/bin/sh' > $file_vars_deact
+fi
+
+# append env variable exports to /activate.d/env_vars.sh
+echo "export BASKERVILLE_DIR=$LOCAL_BASKERVILLE_PATH" >> $file_vars_act
+echo 'export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH' >> $file_vars_act
+echo 'export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH' >> $file_vars_act
+
+echo "export BASKERVILLE_CONDA=/home/$LOCAL_USER/anaconda3/etc/profile.d/conda.sh" >> $file_vars_act
+
+# append env variable unsets to /deactivate.d/env_vars.sh
+echo 'unset BASKERVILLE_DIR' >> $file_vars_deact
+echo 'unset BASKERVILLE_CONDA' >> $file_vars_deact
+
+# finally activate env variables
+source $file_vars_act

From 68ddb76de7d08a5422522addb86e4745779908b5 Mon Sep 17 00:00:00 2001
From: johli
Date: Tue, 1 Oct 2024 15:33:11 -0700
Subject: [PATCH 4/7] Update README.md

---
 README.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/README.md b/README.md
index 727cb17..98f010d 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,18 @@ Documentation page: https://calico.github.io/baskerville/index.html
 `cd baskerville`
 `pip install .`
 
+To set up the required environment variables:
+`cd baskerville`
+`conda activate `
+`./env_vars.sh`
+
+Alternatively, the environment variables can be set manually:
+```sh
+export BASKERVILLE_DIR=/home//baskerville
+export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH
+export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH
+```
+
 ---
 
 #### Contacts

From cd13ef5415bcb5b2692c01f4cbf21b205160b05e Mon Sep 17 00:00:00 2001
From: johli
Date: Fri, 4 Oct 2024 10:25:21 -0700
Subject: [PATCH 5/7] Update README.md

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 98f010d..4dca2f7 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,8 @@ To set up the required environment variables:
 `conda activate `
 `./env_vars.sh`
 
+*Note:* Change the two lines of code at the top of './env_vars.sh' to your username and local path.
+
 Alternatively, the environment variables can be set manually:
 ```sh
 export BASKERVILLE_DIR=/home//baskerville

From ac8813969e98c1ea08b0f57ce819c8877d5235df Mon Sep 17 00:00:00 2001
From: Johannes Linder
Date: Sun, 6 Oct 2024 20:20:12 -0700
Subject: [PATCH 6/7] Added backward support for old gene output_slice function.

---
 src/baskerville/gene.py                    | 53 ++++++++++++++++++++--
 src/baskerville/scripts/hound_data_read.py |  2 +-
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/src/baskerville/gene.py b/src/baskerville/gene.py
index 0645385..8baa4c0 100644
--- a/src/baskerville/gene.py
+++ b/src/baskerville/gene.py
@@ -52,10 +52,11 @@ class Gene:
     """Class for managing genes in an isoform-agnostic way, taking
     the union of exons across isoforms."""
 
-    def __init__(self, chrom, strand, kv):
+    def __init__(self, chrom, strand, kv, name=None):
         self.chrom = chrom
         self.strand = strand
         self.kv = kv
+        self.name = name
         self.exons = IntervalTree()
 
     def add_exon(self, start, end):
@@ -77,10 +78,53 @@ def span(self):
         exon_starts = [exon.begin for exon in self.exons]
         exon_ends = [exon.end for exon in self.exons]
         return min(exon_starts), max(exon_ends)
+
+    def output_slice_old(self, seq_start, seq_len, model_stride, span=False):
+        gene_slice = []
+
+        if span:
+            gene_start, gene_end = self.span()
+
+            # clip left boundaries
+            gene_seq_start = max(0, gene_start - seq_start)
+            gene_seq_end = max(0, gene_end - seq_start)
+
+            # requires >50% overlap
+            slice_start = int(np.round(gene_seq_start / model_stride))
+            slice_end = int(np.round(gene_seq_end / model_stride))
+
+            # clip right boundaries
+            slice_max = int(seq_len/model_stride)
+            slice_start = min(slice_start, slice_max)
+            slice_end = min(slice_end, slice_max)
+
+            gene_slice = range(slice_start, slice_end)
+
+        else:
+            for exon in self.get_exons():
+                # clip left boundaries
+                exon_seq_start = max(0, exon.begin - seq_start)
+                exon_seq_end = max(0, exon.end - seq_start)
+
+                # requires >50% overlap
+                slice_start = int(np.round(exon_seq_start / model_stride))
+                slice_end = int(np.round(exon_seq_end / model_stride))
+
+                # clip right boundaries
+                slice_max = int(seq_len/model_stride)
+                slice_start = min(slice_start, slice_max)
+                slice_end = min(slice_end, slice_max)
+
+                gene_slice.extend(range(slice_start, slice_end))
+
+        return np.array(gene_slice)
 
     def output_slice(
-        self, seq_start, seq_len, model_stride, span=False, majority_overlap=False
+        self, seq_start, seq_len, model_stride, span=False, majority_overlap=False, old_version=False
     ):
+        if old_version :
+            return self.output_slice_old(seq_start, seq_len, model_stride, span=span)
+
         gene_slice = []
 
         def clip_boundaries(slice_start, slice_end):
@@ -162,10 +206,13 @@ def read_gtf(self, gtf_file):
                 strand = a[6]
                 kv = gtf_kv(a[8])
                 gene_id = kv["gene_id"]
+                gene_name = None
+                if 'gene_name' in kv:
+                    gene_name = kv['gene_name']
 
                 # initialize gene
                 if gene_id not in self.genes:
-                    self.genes[gene_id] = Gene(chrom, strand, kv)
+                    self.genes[gene_id] = Gene(chrom, strand, kv, gene_name)
 
                 # add exon
                 self.genes[gene_id].add_exon(start - 1, end)
diff --git a/src/baskerville/scripts/hound_data_read.py b/src/baskerville/scripts/hound_data_read.py
index aad44b9..20b7471 100755
--- a/src/baskerville/scripts/hound_data_read.py
+++ b/src/baskerville/scripts/hound_data_read.py
@@ -188,7 +188,7 @@ def main():
         if options.crop_bp > 0:
             seq_cov_nt = seq_cov_nt[options.crop_bp : -options.crop_bp]
 
-        #apply original transform (from borzoi manuscript)
+        # apply original transform (from borzoi manuscript)
         if options.transform_old:
             # sum pool
             seq_cov = seq_cov_nt.reshape(target_length, options.pool_width)

From 195a66da5965991c5b64306550a9729841577534 Mon Sep 17 00:00:00 2001
From: Johannes Linder
Date: Mon, 7 Oct 2024 19:29:57 -0700
Subject: [PATCH 7/7] Removed unused classes in stream script. Cleaned up env_vars shell script.

--- README.md | 4 +- env_vars.sh | 4 +- src/baskerville/stream.py | 139 -------------------------------------- 3 files changed, 5 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index 4dca2f7..0d144c8 100644 --- a/README.md +++ b/README.md @@ -27,13 +27,15 @@ To set up the required environment variables: `conda activate ` `./env_vars.sh` -*Note:* Change the two lines of code at the top of './env_vars.sh' to your username and local path. +*Note:* Change the two lines of code at the top of './env_vars.sh' to the correct local paths. Alternatively, the environment variables can be set manually: ```sh export BASKERVILLE_DIR=/home//baskerville export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH + +export BASKERVILLE_CONDA=/home//anaconda3/etc/profile.d/conda.sh ``` --- diff --git a/env_vars.sh b/env_vars.sh index 6a78108..1d13764 100755 --- a/env_vars.sh +++ b/env_vars.sh @@ -2,7 +2,7 @@ # set these variables before running the script LOCAL_BASKERVILLE_PATH="/home/jlinder/baskerville" -LOCAL_USER="jlinder" +LOCAL_CONDA_PATH="/home/jlinder/anaconda3/etc/profile.d/conda.sh" # create env_vars sh scripts in local conda env mkdir -p "$CONDA_PREFIX/etc/conda/activate.d" @@ -23,7 +23,7 @@ echo "export BASKERVILLE_DIR=$LOCAL_BASKERVILLE_PATH" >> $file_vars_act echo 'export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH' >> $file_vars_act echo 'export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH' >> $file_vars_act -echo "export BASKERVILLE_CONDA=/home/$LOCAL_USER/anaconda3/etc/profile.d/conda.sh" >> $file_vars_act +echo "export BASKERVILLE_CONDA=$LOCAL_CONDA_PATH" >> $file_vars_act # append env variable unsets to /deactivate.d/env_vars.sh echo 'unset BASKERVILLE_DIR' >> $file_vars_deact diff --git a/src/baskerville/stream.py b/src/baskerville/stream.py index 79a235c..14827b8 100644 --- a/src/baskerville/stream.py +++ b/src/baskerville/stream.py @@ -70,142 +70,3 @@ def make_dataset(self): dataset = tf.data.Dataset.from_tensor_slices((seqs_1hot,)) dataset = dataset.batch(self.batch_size) return dataset - - -class PredStreamIter: - """ Interface to acquire predictions via a buffered stream mechanism - rather than getting them all at once and using excessive memory. - Accepts iterator and constructs stream batches from it. - [I don't recall whether I've ever gotten this one working.""" - def __init__(self, model, dataset_iter, stream_seqs=128, verbose=False): - self.model = model - self.dataset_iter = dataset_iter - self.stream_seqs = stream_seqs - self.verbose = verbose - - self.stream_start = 0 - self.stream_end = 0 - - - def __getitem__(self, i): - # acquire predictions, if needed - if i >= self.stream_end: - # update start - self.stream_start = self.stream_end - - if self.verbose: - print('Predicting from %d' % self.stream_start, flush=True) - - # predict - self.stream_preds = self.model.predict(self.fetch_batch()) - - # update end - self.stream_end = self.stream_start + self.stream_preds.shape[0] - - return self.stream_preds[i - self.stream_start] - - def fetch_batch(self): - """Fetch a batch of data from the dataset iterator.""" - x = [next(self.dataset_iter)] - while x[-1] and len(x) < self.stream_seqs: - x.append(next(self.dataset_iter)) - return x - - -class PredStreamSonnet: - """ Interface to acquire predictions via a buffered stream mechanism - rather than getting them all at once and using excessive memory. - Accepts generator and constructs stream batches from it. 
""" - def __init__(self, model, seqs_gen, batch_size=4, stream_size=32, - rc=False, shifts=[0], slice_center=None, - species='human', return_augm=False, verbose=False): - self.model = model - self.seqs_gen = seqs_gen - self.batch_size = batch_size - self.stream_size = stream_size - self.rc = rc - self.shifts = shifts - self.ensembled = len(self.shifts) + int(self.rc)*len(self.shifts) - self.slice_center = slice_center - self.species = species - self.verbose = verbose - self.return_augm = return_augm - - self.stream_start = 0 - self.stream_end = 0 - - - def __getitem__(self, i): - # acquire predictions, if needed - if i >= self.stream_end: - # update start - self.stream_start = self.stream_end - - if self.verbose: - print('Predicting from %d' % self.stream_start, flush=True) - - # get next sequences - seqs_1hot = self.next_seqs() - - # predict stream - stream_preds = [] - si = 0 - while si < seqs_1hot.shape[0]: - spreds = self.model.predict_on_batch(seqs_1hot[si:si+self.batch_size]) - spreds = spreds[self.species].numpy() - stream_preds.append(spreds) - si += self.batch_size - stream_preds = np.concatenate(stream_preds, axis=0) - - # slice center - if self.slice_center is not None: - _, seq_len, _ = stream_preds.shape - mid_pos = seq_len // 2 - slice_start = mid_pos - self.slice_center//2 - slice_end = slice_start + self.slice_center - stream_preds = stream_preds[:,slice_start:slice_end,:] - - # reshape to expose augmentations - ens_seqs, seq_len, num_targets = stream_preds.shape - num_seqs = ens_seqs // self.ensembled - stream_preds = np.reshape(stream_preds, - (num_seqs, self.ensembled, seq_len, num_targets)) - - if self.return_augm: - # move augmentations to the back - self.stream_preds = np.transpose(stream_preds, [0,2,3,1]) - else: - # average augmentations - self.stream_preds = stream_preds.mean(axis=1) - - # update end - self.stream_end = self.stream_start + self.stream_preds.shape[0] - - return self.stream_preds[i - self.stream_start] - - def next_seqs(self): - """ Construct array of sequences for this stream chunk. """ - - # extract next sequences from generator - seqs_1hot = [] - stream_end = self.stream_start+self.stream_size - for si in range(self.stream_start, stream_end): - try: - seqs_1hot.append(self.seqs_gen.__next__()) - except StopIteration: - continue - - # initialize ensemble - seqs_1hot_ens = [] - - # add rc/shifts - for seq_1hot in seqs_1hot: - for shift in self.shifts: - seq_1hot_aug = dna_io.hot1_augment(seq_1hot, shift=shift) - seqs_1hot_ens.append(seq_1hot_aug) - if self.rc: - seq_1hot_aug = dna_io.hot1_rc(seq_1hot_aug) - seqs_1hot_ens.append(seq_1hot_aug) - - seqs_1hot_ens = np.array(seqs_1hot_ens, dtype='float32') - return seqs_1hot_ens