From 21271c957d5dc0cc06eaa9f55d371457362ec8da Mon Sep 17 00:00:00 2001 From: hy395 Date: Tue, 23 Apr 2024 18:36:28 -0700 Subject: [PATCH] add --f16 for float16 inference --- src/baskerville/HY_helper.py | 3 +-- src/baskerville/scripts/borzoi_test_genes.py | 23 ++++++++++++++++++-- src/baskerville/scripts/hound_eval.py | 23 +++++++++++++++++--- src/baskerville/scripts/hound_eval_spec.py | 23 ++++++++++++++++++-- src/baskerville/seqnn.py | 6 +++++ 5 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/baskerville/HY_helper.py b/src/baskerville/HY_helper.py index 2d8b665..d4de926 100644 --- a/src/baskerville/HY_helper.py +++ b/src/baskerville/HY_helper.py @@ -4,7 +4,6 @@ import pyBigWig - def make_seq_1hot(genome_open, chrm, start, end, seq_len): if start < 0: seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) @@ -18,7 +17,7 @@ def make_seq_1hot(genome_open, chrm, start, end, seq_len): seq_1hot = dna_io.dna_1hot(seq_dna) return seq_1hot -#Helper function to get (padded) one-hot +# Helper function to get (padded) one-hot def process_sequence(fasta_file, chrom, start, end, seq_len=524288) : fasta_open = pysam.Fastafile(fasta_file) diff --git a/src/baskerville/scripts/borzoi_test_genes.py b/src/baskerville/scripts/borzoi_test_genes.py index 1e2b853..83f1dec 100755 --- a/src/baskerville/scripts/borzoi_test_genes.py +++ b/src/baskerville/scripts/borzoi_test_genes.py @@ -27,6 +27,7 @@ from qnorm import quantile_normalize from scipy.stats import pearsonr from sklearn.metrics import explained_variance_score +from tensorflow.keras import mixed_precision from baskerville import pygene from baskerville import dataset @@ -77,6 +78,13 @@ def main(): action="store_true", help="Aggregate entire gene span [Default: %default]", ) + parser.add_option( + "--f16", + dest="f16", + default=False, + action="store_true", + help="use mixed precision for inference", + ) parser.add_option( "-t", dest="targets_file", @@ -155,8 +163,19 @@ def main(): ) # initialize model - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, options.head_i) + ################### + # mixed precision # + ################### + if options.f16: + mixed_precision.set_global_policy('mixed_float16') # first set global policy + seqnn_model = seqnn.SeqNN(params_model) # then create model + seqnn_model.restore(model_file, options.head_i) + seqnn_model.append_activation() # add additional activation to cast float16 output to float32 + else: + # initialize model + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file, options.head_i) + seqnn_model.build_slice(targets_df.index) seqnn_model.build_ensemble(options.rc, options.shifts) diff --git a/src/baskerville/scripts/hound_eval.py b/src/baskerville/scripts/hound_eval.py index 8db7851..b7fca0d 100755 --- a/src/baskerville/scripts/hound_eval.py +++ b/src/baskerville/scripts/hound_eval.py @@ -23,6 +23,7 @@ from scipy.stats import spearmanr import tensorflow as tf from tqdm import tqdm +from tensorflow.keras import mixed_precision from baskerville import bed from baskerville import dataset @@ -85,6 +86,12 @@ def main(): type=int, help="Step across positions [Default: %(default)s]", ) + parser.add_argument( + "--f16", + default=False, + action="store_true", + help="use mixed precision for inference", + ) parser.add_argument( "-t", "--targets_file", @@ -140,9 +147,19 @@ def main(): tfr_pattern=args.tfr_pattern, ) - # initialize model - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(args.model_file, args.head_i) + ################### + # mixed precision # + ################### + if args.f16: + mixed_precision.set_global_policy('mixed_float16') # first set global policy + seqnn_model = seqnn.SeqNN(params_model) # then create model + seqnn_model.restore(args.model_file, args.head_i) + seqnn_model.append_activation() # add additional activation to cast float16 output to float32 + else: + # initialize model + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(args.model_file, args.head_i) + seqnn_model.build_ensemble(args.rc, args.shifts) ####################################################### diff --git a/src/baskerville/scripts/hound_eval_spec.py b/src/baskerville/scripts/hound_eval_spec.py index 0732da6..2a4608f 100755 --- a/src/baskerville/scripts/hound_eval_spec.py +++ b/src/baskerville/scripts/hound_eval_spec.py @@ -25,6 +25,7 @@ from qnorm import quantile_normalize from scipy.stats import pearsonr import tensorflow as tf +from tensorflow.keras import mixed_precision from baskerville import dataset from baskerville import seqnn @@ -74,6 +75,13 @@ def main(): type="int", help="Step across positions [Default: %default]", ) + parser.add_option( + "--f16", + dest="f16", + default=False, + action="store_true", + help="use mixed precision for inference", + ) parser.add_option( "--save", dest="save", @@ -190,8 +198,19 @@ def main(): ) # initialize model - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, options.head_i) + ################### + # mixed precision # + ################### + if options.f16: + mixed_precision.set_global_policy('mixed_float16') # set global policy + seqnn_model = seqnn.SeqNN(params_model) # create model + seqnn_model.restore(model_file, options.head_i) + seqnn_model.append_activation() # add additional activation to cast float16 output to float32 + else: + # initialize model + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file, options.head_i) + seqnn_model.build_slice(targets_df.index) if options.step > 1: seqnn_model.step(options.step) diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 48aa300..1ffca86 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -219,6 +219,12 @@ def build_embed(self, conv_layer_i: int, batch_norm: bool = True): inputs=self.model.inputs, outputs=conv_layer.output ) + def append_activation(self): + """add additional activation to convert float16 output to float32, required for mixed precision""" + model_0 = self.model + new_outputs = tf.keras.layers.Activation('linear', dtype='float32')(model_0.layers[-1].output) + self.model = tf.keras.Model(inputs=model_0.layers[0].input, outputs=new_outputs) + def build_ensemble(self, ensemble_rc: bool = False, ensemble_shifts=[0]): """Build ensemble of models computing on augmented input sequences.""" shift_bool = len(ensemble_shifts) > 1 or ensemble_shifts[0] != 0