Skip to content

Commit

Permalink
Merge pull request #25 from calico/trt-ensemble-redo
Browse files Browse the repository at this point in the history
Re-add changes for ensemble before optimizing
  • Loading branch information
lruizcalico authored May 8, 2024
2 parents 69e46af + 895eda2 commit 8ab0462
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 60 deletions.
107 changes: 63 additions & 44 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt
import tensorflow as tf
import tensorrt as trt
import argparse
import json
import pdb
import time

import numpy as np
import pandas as pd
from baskerville import seqnn, dataset
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt

from baskerville import seqnn


precision_dict = {
Expand All @@ -14,8 +17,6 @@
"INT8": tf_trt.TrtPrecisionMode.INT8,
}

# For TF-TRT:


class ModelOptimizer:
"""
Expand All @@ -27,8 +28,6 @@ class ModelOptimizer:
def __init__(self, input_saved_model_dir, calibration_data=None):
self.input_saved_model_dir = input_saved_model_dir
self.calibration_data = None
self.loaded_model = None

if not calibration_data is None:
self.set_calibration_data(calibration_data)

Expand All @@ -38,81 +37,101 @@ def calibration_input_fn():

self.calibration_data = calibration_input_fn

def convert(
self,
output_saved_model_dir,
precision="FP32",
max_workspace_size_bytes=8000000000,
**kwargs,
):
def convert(self, precision="FP32"):
t0 = time.time()
print("Converting the model.")

if precision == "INT8" and self.calibration_data is None:
raise (Exception("No calibration data set!"))

trt_precision = precision_dict[precision]
conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_precision,
max_workspace_size_bytes=max_workspace_size_bytes,
use_calibration=precision == "INT8",
max_workspace_size_bytes=8000000000,
)
converter = tf_trt.TrtGraphConverterV2(
self.converter = tf_trt.TrtGraphConverterV2(
input_saved_model_dir=self.input_saved_model_dir,
conversion_params=conversion_params,
)

if precision == "INT8":
converter.convert(calibration_input_fn=self.calibration_data)
self.func = self.converter.convert(
calibration_input_fn=self.calibration_data
)
else:
converter.convert()
self.func = self.converter.convert()
print("Done in %ds" % (time.time() - t0))

converter.save(output_saved_model_dir=output_saved_model_dir)
def build(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print("Building TRT engines for shape:", input_shape)

return output_saved_model_dir
def input_fn():
x = np.random.random(input_shape).astype(np.float32)
x = tf.cast(x, tf.float32)
yield x

def predict(self, input_data):
if self.loaded_model is None:
self.load_default_model()
self.converter.build(input_fn)
print("Done in %ds" % (time.time() - t0))

return self.loaded_model.predict(input_data)
def build_func(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print("Building TRT engines for shape:", input_shape)
x = np.random.random(input_shape)
x = tf.cast(x, tf.float32)
self.func(x)
print("Done in %ds" % (time.time() - t0))

def load_default_model(self):
self.loaded_model = tf.keras.models.load_model("resnet50_saved_model")
def save(self, output_dir):
self.converter.save(output_saved_model_dir=output_dir)


def main():
parser = argparse.ArgumentParser(
description="Convert a seqnn model to TensorRT model."
)
parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)")
parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file")
parser.add_argument(
"targets_file", type=str, help="Path to the target variants file"
"-t", "--targets_file", default=None, help="Path to the target variants file"
)
parser.add_argument(
"output_dir",
type=str,
"-o",
"--out_dir",
default="trt_out",
help="Output directory for storing saved models (original & converted)",
)
parser.add_argument(
"params_file", type=str, help="Path to the JSON parameters file"
)
parser.add_argument("model_file", help="Trained model HDF5.")
args = parser.parse_args()

# Load target variants
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)

# Load parameters
with open(args.params_fn) as params_open:
with open(args.params_file) as params_open:
params = json.load(params_open)
params_model = params["model"]

# Load keras model into seqnn class
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(args.model_fn)
seqnn_model.build_slice(np.array(targets_df.index))
# seqnn_model.build_ensemble(True)
seqnn_model = seqnn.SeqNN(params["model"])
seqnn_model.restore(args.model_file)

# Load target variants
if args.targets_file is not None:
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)
seqnn_model.build_slice(np.array(targets_df.index))

# ensemble rc
seqnn_model.build_ensemble(True)

# save this model to a directory
seqnn_model.model.save(f"{args.output_dir}/original_model")
seqnn_model.ensemble.save(f"{args.out_dir}/original")

# Convert the model
opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
opt_model.convert(f"{args.output_dir}/model_FP32", precision="FP32")
opt_model = ModelOptimizer(f"{args.out_dir}/original")
opt_model.convert(precision="FP32")
# opt_model.build(seqnn_model.seq_length)
opt_model.save(f"{args.out_dir}/convert")


if __name__ == "__main__":
Expand Down
18 changes: 2 additions & 16 deletions src/baskerville/helpers/trt_optimized_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from baskerville import layers


class OptimizedModel:
Expand All @@ -19,7 +18,6 @@ def __init__(self, saved_model_dir=None, strand_pair=[]):
def predict(self, input_data):
if self.loaded_model_fn is None:
raise (Exception("Haven't loaded a model"))
# x = tf.constant(input_data.astype("float32"))
x = tf.cast(input_data, tf.float32)
labeling = self.loaded_model_fn(x)
try:
Expand All @@ -43,17 +41,5 @@ def load_model(self, saved_model_dir):
wrapper_fp32 = saved_model_loaded.signatures["serving_default"]
self.loaded_model_fn = wrapper_fp32

def __call__(self, input_data):
# need to do the prediction for ensemble model here
x = tf.cast(input_data, tf.float32)
sequences_rev = layers.EnsembleReverseComplement()([x])
if len(self.strand_pair) == 0:
strand_pair = None
else:
strand_pair = self.strand_pair[0]
preds = [
layers.SwitchReverse(strand_pair)([self.predict(seq), rp])
for (seq, rp) in sequences_rev
]
preds_avg = tf.keras.layers.Average()(preds)
return preds_avg
def __call__(self, x):
return self.predict(x)

0 comments on commit 8ab0462

Please sign in to comment.