Skip to content

Commit

Permalink
add support to tensorrt
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Jan 24, 2024
1 parent 929993d commit fee788d
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 10 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,8 @@ dev =
black==22.3.0
pytest==7.1.2

gpu =
tensorrt==8.6.1

[options.packages.find]
where = src
4 changes: 2 additions & 2 deletions src/baskerville/helpers/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -
Returns: new filepath in the local machine
"""
if is_dir:
download_folder_from_gcs(filepath, temp_dir)
dir_name = filepath.split("/")[-1]
return temp_dir
download_folder_from_gcs(filepath, f"{temp_dir}/{dir_name}")
return f"{temp_dir}/{dir_name}"
else:
_, filename = split_gcs_uri(filepath)
if "/" in filename:
Expand Down
119 changes: 119 additions & 0 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt
import tensorflow as tf
import tensorrt as trt
import argparse
import json
import numpy as np
import pandas as pd
from baskerville import seqnn, dataset


# Map user-facing precision strings to TF-TRT precision-mode constants.
# INT8 additionally requires calibration data (see ModelOptimizer.convert).
precision_dict = {
    "FP32": tf_trt.TrtPrecisionMode.FP32,
    "FP16": tf_trt.TrtPrecisionMode.FP16,
    "INT8": tf_trt.TrtPrecisionMode.INT8,
}

# For TF-TRT:


class ModelOptimizer:
    """Convert a TensorFlow SavedModel into a TF-TRT optimized SavedModel.

    Args:
        input_saved_model_dir (str): directory containing the SavedModel
            to convert.
        calibration_data (np.ndarray, optional): representative input batch
            used for INT8 calibration; only needed for INT8 precision.
    """

    def __init__(self, input_saved_model_dir, calibration_data=None):
        self.input_saved_model_dir = input_saved_model_dir
        self.calibration_data = None
        self.loaded_model = None

        if calibration_data is not None:
            self.set_calibration_data(calibration_data)

    def set_calibration_data(self, calibration_data):
        """Wrap a calibration batch in the generator form TF-TRT expects."""

        def calibration_input_fn():
            # TF-TRT consumes calibration inputs as tuples of tensors.
            yield (tf.constant(calibration_data.astype("float32")),)

        self.calibration_data = calibration_input_fn

    def convert(
        self,
        output_saved_model_dir,
        precision="FP32",
        max_workspace_size_bytes=8000000000,
        **kwargs,
    ):
        """Convert the input SavedModel and save the optimized model.

        Args:
            output_saved_model_dir (str): where to write the converted model.
            precision (str): one of "FP32", "FP16" or "INT8".
            max_workspace_size_bytes (int): TensorRT workspace budget.

        Returns:
            str: output_saved_model_dir, for chaining.

        Raises:
            Exception: if INT8 precision is requested without calibration data.
        """
        if precision == "INT8" and self.calibration_data is None:
            raise Exception("No calibration data set!")

        trt_precision = precision_dict[precision]
        conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt_precision,
            max_workspace_size_bytes=max_workspace_size_bytes,
            # Calibration is only meaningful for INT8 quantization.
            use_calibration=precision == "INT8",
        )
        converter = tf_trt.TrtGraphConverterV2(
            input_saved_model_dir=self.input_saved_model_dir,
            conversion_params=conversion_params,
        )

        if precision == "INT8":
            converter.convert(calibration_input_fn=self.calibration_data)
        else:
            converter.convert()

        converter.save(output_saved_model_dir=output_saved_model_dir)

        return output_saved_model_dir

    def predict(self, input_data):
        """Predict with the (lazily loaded) source Keras model."""
        if self.loaded_model is None:
            self.load_default_model()

        return self.loaded_model.predict(input_data)

    def load_default_model(self):
        # Bug fix: previously loaded a hard-coded "resnet50_saved_model"
        # directory (tutorial leftover) that does not exist in this project;
        # load the SavedModel this optimizer was constructed with instead.
        self.loaded_model = tf.keras.models.load_model(self.input_saved_model_dir)


def main():
    """Command-line entry point: convert a trained seqnn model to TensorRT.

    Restores the Keras model described by the JSON parameters file, applies
    the target slice, exports it as a SavedModel, then writes an FP32
    TF-TRT-converted copy next to it.
    """
    parser = argparse.ArgumentParser(
        description="Convert a seqnn model to TensorRT model."
    )
    parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)")
    parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file")
    parser.add_argument(
        "targets_file", type=str, help="Path to the target variants file"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Output directory for storing saved models (original & converted)",
    )
    args = parser.parse_args()

    # Target variants: the index column selects which model outputs to keep.
    targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)

    # Model architecture parameters from the JSON config.
    with open(args.params_fn) as params_open:
        model_params = json.load(params_open)["model"]

    # Rebuild the seqnn model, restore trained weights, slice to targets.
    seqnn_model = seqnn.SeqNN(model_params)
    seqnn_model.restore(args.model_fn)
    seqnn_model.build_slice(np.array(targets_df.index))

    # Export the restored model as a SavedModel for the converter.
    original_dir = f"{args.output_dir}/original_model"
    seqnn_model.model.save(original_dir)

    # Run the TF-TRT conversion at FP32 precision.
    optimizer = ModelOptimizer(original_dir)
    optimizer.convert(f"{args.output_dir}/model_FP32", precision="FP32")


if __name__ == "__main__":
    main()
59 changes: 59 additions & 0 deletions src/baskerville/helpers/trt_optimized_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from baskerville import layers


class OptimizedModel:
    """Wrapper around a TensorRT-optimized SavedModel.

    Args:
        saved_model_dir (str, optional): directory holding the optimized
            SavedModel; loaded immediately when provided.
        strand_pair (list, optional): strand-pair index arrays passed to
            SwitchReverse when averaging forward/reverse-complement
            predictions in __call__. Defaults to an empty list.
    """

    def __init__(self, saved_model_dir=None, strand_pair=None):
        self.loaded_model_fn = None
        # Bug fix: the original used a mutable default argument
        # (strand_pair=[]), shared across instances; use a None sentinel.
        self.strand_pair = [] if strand_pair is None else strand_pair
        if saved_model_dir is not None:
            self.load_model(saved_model_dir)

    def predict(self, input_data):
        """Run the serving signature and extract a predictions array.

        Raises:
            Exception: if no model has been loaded, or no output can be
                extracted from the signature's return value.
        """
        if self.loaded_model_fn is None:
            raise Exception("Haven't loaded a model")
        x = tf.cast(input_data, tf.float32)
        labeling = self.loaded_model_fn(x)
        # Prefer well-known output names; .numpy() converts to ndarray.
        # (Replaces nested bare `except:` blocks with targeted key checks.)
        for key in ("predictions", "probs"):
            if key in labeling:
                return labeling[key].numpy()
        # Fall back to the first output of the signature. NOTE(review): this
        # branch returns the raw tensor, not .numpy(), matching the original.
        try:
            return labeling[next(iter(labeling.keys()))]
        except Exception as exc:
            raise Exception(
                "Failed to get predictions from saved model object"
            ) from exc

    def load_model(self, saved_model_dir):
        """Load the SavedModel and keep its serving_default signature."""
        saved_model_loaded = tf.saved_model.load(
            saved_model_dir, tags=[tag_constants.SERVING]
        )
        self.loaded_model_fn = saved_model_loaded.signatures["serving_default"]

    def __call__(self, input_data):
        """Ensemble prediction over forward and reverse-complement strands."""
        x = tf.cast(input_data, tf.float32)
        sequences_rev = layers.EnsembleReverseComplement()([x])
        # assumes strand_pair[0] holds the index mapping SwitchReverse needs
        # — TODO confirm against seqnn_model.strand_pair usage in snps.py
        strand_pair = self.strand_pair[0] if self.strand_pair else None
        preds = [
            layers.SwitchReverse(strand_pair)([self.predict(seq), rp])
            for (seq, rp) in sequences_rev
        ]
        return tf.keras.layers.Average()(preds)
17 changes: 16 additions & 1 deletion src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ def main():
action="store_true",
help="Only run on GPU",
)
parser.add_option(
"--tensorrt",
dest="tensorrt",
default=False,
action="store_true",
help="Model type is tensorrt optimized",
)
(options, args) = parser.parse_args()

if options.gcs:
Expand Down Expand Up @@ -162,6 +169,14 @@ def main():
else:
parser.error("Must provide parameters and model files and QTL VCF file")

# check if the model type is correct
if options.tensorrt:
if model_file.endswith(".h5"):
raise SystemExit("Model type is tensorrt but model file is keras")
is_dir_model = True
else:
is_dir_model = False

if not os.path.isdir(options.out_dir):
os.mkdir(options.out_dir)

Expand All @@ -188,7 +203,7 @@ def main():
if options.gcs:
params_file = download_rename_inputs(params_file, temp_dir)
vcf_file = download_rename_inputs(vcf_file, temp_dir)
model_file = download_rename_inputs(model_file, temp_dir)
model_file = download_rename_inputs(model_file, temp_dir, is_dir_model)
if options.genome_fasta is not None:
options.genome_fasta = download_rename_inputs(
options.genome_fasta, temp_dir
Expand Down
26 changes: 19 additions & 7 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from baskerville import dataset
from baskerville import seqnn
from baskerville import vcf as bvcf
from baskerville.helpers.trt_optimized_model import OptimizedModel


def score_snps(params_file, model_file, vcf_file, worker_index, options):
Expand Down Expand Up @@ -63,14 +64,25 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
# setup model

# can we sum on GPU?
sum_length = options.snp_stats == "SAD"

seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(model_file)
seqnn_model.build_slice(targets_df.index)
if sum_length:
seqnn_model.build_sad()
seqnn_model.build_ensemble(options.rc)

# load model
sum_length = options.snp_stats == "SAD"
if options.tensorrt:
seqnn_model.model = OptimizedModel(model_file, seqnn_model.strand_pair)
input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list())
else:
seqnn_model.restore(model_file)
seqnn_model.build_slice(targets_df.index)
if sum_length:
seqnn_model.build_sad()
seqnn_model.build_ensemble(options.rc)
input_shape = seqnn_model.model.input_shape

# make dummy predictions to warm up model
dummy_input_shape = (1,) + input_shape[1:]
dummy_input = np.random.random(dummy_input_shape).astype(np.float32)
dummy_output = seqnn_model(dummy_input)

# shift outside seqnn
num_shifts = len(options.shifts)
Expand Down

0 comments on commit fee788d

Please sign in to comment.