From fee788db7ca29eb9e6d5b5dd0355ea4f97b25979 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Tue, 23 Jan 2024 19:21:48 -0800 Subject: [PATCH] add support to tensorrt --- setup.cfg | 3 + src/baskerville/helpers/gcs_utils.py | 4 +- src/baskerville/helpers/tensorrt_helpers.py | 119 ++++++++++++++++++ .../helpers/trt_optimized_model.py | 59 +++++++++ src/baskerville/scripts/hound_snp.py | 17 ++- src/baskerville/snps.py | 26 ++-- 6 files changed, 218 insertions(+), 10 deletions(-) create mode 100644 src/baskerville/helpers/tensorrt_helpers.py create mode 100644 src/baskerville/helpers/trt_optimized_model.py diff --git a/setup.cfg b/setup.cfg index cb2a2f0..a5662f6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,5 +45,8 @@ dev = black==22.3.0 pytest==7.1.2 +gpu = + tensorrt==8.6.1 + [options.packages.find] where = src diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py index 35aad61..4184268 100644 --- a/src/baskerville/helpers/gcs_utils.py +++ b/src/baskerville/helpers/gcs_utils.py @@ -254,9 +254,9 @@ def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) - Returns: new filepath in the local machine """ if is_dir: - download_folder_from_gcs(filepath, temp_dir) dir_name = filepath.split("/")[-1] - return temp_dir + download_folder_from_gcs(filepath, f"{temp_dir}/{dir_name}") + return f"{temp_dir}/{dir_name}" else: _, filename = split_gcs_uri(filepath) if "/" in filename: diff --git a/src/baskerville/helpers/tensorrt_helpers.py b/src/baskerville/helpers/tensorrt_helpers.py new file mode 100644 index 0000000..504c2f3 --- /dev/null +++ b/src/baskerville/helpers/tensorrt_helpers.py @@ -0,0 +1,119 @@ +from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt +import tensorflow as tf +import tensorrt as trt +import argparse +import json +import numpy as np +import pandas as pd +from baskerville import seqnn, dataset + + +precision_dict = { + "FP32": tf_trt.TrtPrecisionMode.FP32, + "FP16": tf_trt.TrtPrecisionMode.FP16, + "INT8": tf_trt.TrtPrecisionMode.INT8, +} + +# For TF-TRT: + + +class ModelOptimizer: + """ + Class of converter for tensorrt + Args: + input_saved_model_dir: Folder with saved model of the input model + """ + + def __init__(self, input_saved_model_dir, calibration_data=None): + self.input_saved_model_dir = input_saved_model_dir + self.calibration_data = None + self.loaded_model = None + + if not calibration_data is None: + self.set_calibration_data(calibration_data) + + def set_calibration_data(self, calibration_data): + def calibration_input_fn(): + yield (tf.constant(calibration_data.astype("float32")),) + + self.calibration_data = calibration_input_fn + + def convert( + self, + output_saved_model_dir, + precision="FP32", + max_workspace_size_bytes=8000000000, + **kwargs, + ): + if precision == "INT8" and self.calibration_data is None: + raise (Exception("No calibration data set!")) + + trt_precision = precision_dict[precision] + conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( + precision_mode=trt_precision, + max_workspace_size_bytes=max_workspace_size_bytes, + use_calibration=precision == "INT8", + ) + converter = tf_trt.TrtGraphConverterV2( + input_saved_model_dir=self.input_saved_model_dir, + conversion_params=conversion_params, + ) + + if precision == "INT8": + converter.convert(calibration_input_fn=self.calibration_data) + else: + converter.convert() + + converter.save(output_saved_model_dir=output_saved_model_dir) + + return output_saved_model_dir + + def predict(self, input_data): + if self.loaded_model is None: + self.load_default_model() + + return self.loaded_model.predict(input_data) + + def load_default_model(self): + self.loaded_model = tf.keras.models.load_model("resnet50_saved_model") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert a seqnn model to TensorRT model." + ) + parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)") + parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file") + parser.add_argument( + "targets_file", type=str, help="Path to the target variants file" + ) + parser.add_argument( + "output_dir", + type=str, + help="Output directory for storing saved models (original & converted)", + ) + args = parser.parse_args() + + # Load target variants + targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0) + + # Load parameters + with open(args.params_fn) as params_open: + params = json.load(params_open) + params_model = params["model"] + # Load keras model into seqnn class + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(args.model_fn) + seqnn_model.build_slice(np.array(targets_df.index)) + # seqnn_model.build_ensemble(True) + + # save this model to a directory + seqnn_model.model.save(f"{args.output_dir}/original_model") + + # Convert the model + opt_model = ModelOptimizer(f"{args.output_dir}/original_model") + opt_model.convert(f"{args.output_dir}/model_FP32", precision="FP32") + + +if __name__ == "__main__": + main() diff --git a/src/baskerville/helpers/trt_optimized_model.py b/src/baskerville/helpers/trt_optimized_model.py new file mode 100644 index 0000000..2032b72 --- /dev/null +++ b/src/baskerville/helpers/trt_optimized_model.py @@ -0,0 +1,59 @@ +import tensorflow as tf +from tensorflow.python.saved_model import tag_constants +from baskerville import layers + + +class OptimizedModel: + """ + Class of model optimized with tensorrt + Args: + saved_model_dir: Folder with saved model + """ + + def __init__(self, saved_model_dir=None, strand_pair=[]): + self.loaded_model_fn = None + self.strand_pair = strand_pair + if not saved_model_dir is None: + self.load_model(saved_model_dir) + + def predict(self, input_data): + if self.loaded_model_fn is None: + raise (Exception("Haven't loaded a model")) + # x = tf.constant(input_data.astype("float32")) + x = tf.cast(input_data, tf.float32) + labeling = self.loaded_model_fn(x) + try: + preds = labeling["predictions"].numpy() + except: + try: + preds = labeling["probs"].numpy() + except: + try: + preds = labeling[next(iter(labeling.keys()))] + except: + raise ( + Exception("Failed to get predictions from saved model object") + ) + return preds + + def load_model(self, saved_model_dir): + saved_model_loaded = tf.saved_model.load( + saved_model_dir, tags=[tag_constants.SERVING] + ) + wrapper_fp32 = saved_model_loaded.signatures["serving_default"] + self.loaded_model_fn = wrapper_fp32 + + def __call__(self, input_data): + # need to do the prediction for ensemble model here + x = tf.cast(input_data, tf.float32) + sequences_rev = layers.EnsembleReverseComplement()([x]) + if len(self.strand_pair) == 0: + strand_pair = None + else: + strand_pair = self.strand_pair[0] + preds = [ + layers.SwitchReverse(strand_pair)([self.predict(seq), rp]) + for (seq, rp) in sequences_rev + ] + preds_avg = tf.keras.layers.Average()(preds) + return preds_avg diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index 1ccad45..c3df976 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -110,6 +110,13 @@ def main(): action="store_true", help="Only run on GPU", ) + parser.add_option( + "--tensorrt", + dest="tensorrt", + default=False, + action="store_true", + help="Model type is tensorrt optimized", + ) (options, args) = parser.parse_args() if options.gcs: @@ -162,6 +169,14 @@ def main(): else: parser.error("Must provide parameters and model files and QTL VCF file") + # check if the model type is correct + if options.tensorrt: + if model_file.endswith(".h5"): + raise SystemExit("Model type is tensorrt but model file is keras") + is_dir_model = True + else: + is_dir_model = False + if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) @@ -188,7 +203,7 @@ def main(): if options.gcs: params_file = download_rename_inputs(params_file, temp_dir) vcf_file = download_rename_inputs(vcf_file, temp_dir) - model_file = download_rename_inputs(model_file, temp_dir) + model_file = download_rename_inputs(model_file, temp_dir, is_dir_model) if options.genome_fasta is not None: options.genome_fasta = download_rename_inputs( options.genome_fasta, temp_dir diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 6ecc4b4..f6cff3e 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -13,6 +13,7 @@ from baskerville import dataset from baskerville import seqnn from baskerville import vcf as bvcf +from baskerville.helpers.trt_optimized_model import OptimizedModel def score_snps(params_file, model_file, vcf_file, worker_index, options): @@ -63,14 +64,25 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # setup model # can we sum on GPU? - sum_length = options.snp_stats == "SAD" - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file) - seqnn_model.build_slice(targets_df.index) - if sum_length: - seqnn_model.build_sad() - seqnn_model.build_ensemble(options.rc) + + # load model + sum_length = options.snp_stats == "SAD" + if options.tensorrt: + seqnn_model.model = OptimizedModel(model_file, seqnn_model.strand_pair) + input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list()) + else: + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + if sum_length: + seqnn_model.build_sad() + seqnn_model.build_ensemble(options.rc) + input_shape = seqnn_model.model.input_shape + + # make dummy predictions to warm up model + dummy_input_shape = (1,) + input_shape[1:] + dummy_input = np.random.random(dummy_input_shape).astype(np.float32) + dummy_output = seqnn_model(dummy_input) # shift outside seqnn num_shifts = len(options.shifts)