Skip to content

Commit

Permalink
Merge pull request #24 from calico/revert-23-trt-ensemble
Browse files: browse the repository at this point in the history
Revert "Ensemble RC before TensorRT optimization"
  • Loading branch information
lruizcalico authored Apr 11, 2024
2 parents 3dd87cc + 7b122f3 commit 9031b4e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 65 deletions.
107 changes: 44 additions & 63 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt
import tensorflow as tf
import tensorrt as trt
import argparse
import json
import pdb
import time

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt

from baskerville import seqnn
from baskerville import seqnn, dataset


precision_dict = {
Expand All @@ -17,6 +14,8 @@
"INT8": tf_trt.TrtPrecisionMode.INT8,
}

# For TF-TRT:


class ModelOptimizer:
"""
Expand All @@ -28,6 +27,8 @@ class ModelOptimizer:
def __init__(self, input_saved_model_dir, calibration_data=None):
self.input_saved_model_dir = input_saved_model_dir
self.calibration_data = None
self.loaded_model = None

if not calibration_data is None:
self.set_calibration_data(calibration_data)

Expand All @@ -37,101 +38,81 @@ def calibration_input_fn():

self.calibration_data = calibration_input_fn

def convert(self, precision="FP32"):
t0 = time.time()
print("Converting the model.")

def convert(
self,
output_saved_model_dir,
precision="FP32",
max_workspace_size_bytes=8000000000,
**kwargs,
):
if precision == "INT8" and self.calibration_data is None:
raise (Exception("No calibration data set!"))

trt_precision = precision_dict[precision]
conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_precision,
max_workspace_size_bytes=max_workspace_size_bytes,
use_calibration=precision == "INT8",
max_workspace_size_bytes=8000000000,
)
self.converter = tf_trt.TrtGraphConverterV2(
converter = tf_trt.TrtGraphConverterV2(
input_saved_model_dir=self.input_saved_model_dir,
conversion_params=conversion_params,
)

if precision == "INT8":
self.func = self.converter.convert(
calibration_input_fn=self.calibration_data
)
converter.convert(calibration_input_fn=self.calibration_data)
else:
self.func = self.converter.convert()
print("Done in %ds" % (time.time() - t0))
converter.convert()

def build(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print("Building TRT engines for shape:", input_shape)
converter.save(output_saved_model_dir=output_saved_model_dir)

def input_fn():
x = np.random.random(input_shape).astype(np.float32)
x = tf.cast(x, tf.float32)
yield x
return output_saved_model_dir

self.converter.build(input_fn)
print("Done in %ds" % (time.time() - t0))
def predict(self, input_data):
if self.loaded_model is None:
self.load_default_model()

def build_func(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print("Building TRT engines for shape:", input_shape)
x = np.random.random(input_shape)
x = tf.cast(x, tf.float32)
self.func(x)
print("Done in %ds" % (time.time() - t0))
return self.loaded_model.predict(input_data)

def save(self, output_dir):
self.converter.save(output_saved_model_dir=output_dir)
def load_default_model(self, saved_model_dir="resnet50_saved_model"):
    """Load a Keras model and store it on ``self.loaded_model``.

    Args:
        saved_model_dir: Path passed to ``tf.keras.models.load_model``.
            Defaults to ``"resnet50_saved_model"``, the path this method
            previously hard-coded, so zero-argument calls behave as before.
    """
    # NOTE(review): the default does not derive from
    # self.input_saved_model_dir -- confirm it is the intended fallback.
    self.loaded_model = tf.keras.models.load_model(saved_model_dir)


def main():
parser = argparse.ArgumentParser(
description="Convert a seqnn model to TensorRT model."
)
parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)")
parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file")
parser.add_argument(
"-t", "--targets_file", default=None, help="Path to the target variants file"
"targets_file", type=str, help="Path to the target variants file"
)
parser.add_argument(
"-o",
"--out_dir",
default="trt_out",
"output_dir",
type=str,
help="Output directory for storing saved models (original & converted)",
)
parser.add_argument(
"params_file", type=str, help="Path to the JSON parameters file"
)
parser.add_argument("model_file", help="Trained model HDF5.")
args = parser.parse_args()

# Load target variants
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)

# Load parameters
with open(args.params_file) as params_open:
with open(args.params_fn) as params_open:
params = json.load(params_open)

params_model = params["model"]
# Load keras model into seqnn class
seqnn_model = seqnn.SeqNN(params["model"])
seqnn_model.restore(args.model_file)

# Load target variants
if args.targets_file is not None:
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)
seqnn_model.build_slice(np.array(targets_df.index))

# ensemble rc
seqnn_model.build_ensemble(True)
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(args.model_fn)
seqnn_model.build_slice(np.array(targets_df.index))
# seqnn_model.build_ensemble(True)

# save this model to a directory
seqnn_model.model.save(f"{args.out_dir}/original")
seqnn_model.model.save(f"{args.output_dir}/original_model")

# Convert the model
opt_model = ModelOptimizer(f"{args.out_dir}/original")
opt_model.convert(precision="FP32")
# opt_model.build(seqnn_model.seq_length)
opt_model.save(f"{args.out_dir}/convert")
opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
opt_model.convert(f"{args.output_dir}/model_FP32", precision="FP32")


if __name__ == "__main__":
Expand Down
18 changes: 16 additions & 2 deletions src/baskerville/helpers/trt_optimized_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from baskerville import layers


class OptimizedModel:
Expand All @@ -18,6 +19,7 @@ def __init__(self, saved_model_dir=None, strand_pair=[]):
def predict(self, input_data):
if self.loaded_model_fn is None:
raise (Exception("Haven't loaded a model"))
# x = tf.constant(input_data.astype("float32"))
x = tf.cast(input_data, tf.float32)
labeling = self.loaded_model_fn(x)
try:
Expand All @@ -41,5 +43,17 @@ def load_model(self, saved_model_dir):
wrapper_fp32 = saved_model_loaded.signatures["serving_default"]
self.loaded_model_fn = wrapper_fp32

def __call__(self, x):
return self.predict(x)
def __call__(self, input_data):
    """Run ``predict`` on every ensemble member and average the outputs.

    The input is cast to float32 and passed, wrapped in a one-element
    list, through ``layers.EnsembleReverseComplement``. Each pair that
    layer yields is predicted with ``self.predict`` and post-processed by
    ``layers.SwitchReverse``; the results are combined with a Keras
    ``Average`` layer.

    Args:
        input_data: Model input; anything ``tf.cast`` accepts.

    Returns:
        The output of ``tf.keras.layers.Average`` over the per-member
        predictions.
    """
    inputs_f32 = tf.cast(input_data, tf.float32)
    ensemble_members = layers.EnsembleReverseComplement()([inputs_f32])

    # Only the first strand-pair entry is used; an empty list means none.
    pair = self.strand_pair[0] if len(self.strand_pair) != 0 else None

    member_preds = []
    # presumably each member is (sequence, reverse marker) -- TODO confirm
    for member_seq, rev_marker in ensemble_members:
        switch = layers.SwitchReverse(pair)
        member_preds.append(switch([self.predict(member_seq), rev_marker]))

    return tf.keras.layers.Average()(member_preds)

0 comments on commit 9031b4e

Please sign in to comment.