Skip to content

Commit

Permalink
rewrite prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Jan 16, 2024
1 parent 2d6442e commit cb75751
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import numpy as np
import pandas as pd
from baskerville import seqnn
from baskerville import seqnn, layers


precision_dict = {
Expand Down Expand Up @@ -59,7 +59,15 @@ def load_model(self, saved_model_dir):
self.loaded_model_fn = wrapper_fp32

def __call__(self, input_data):
return tf.expand_dims(self.predict(input_data), axis=0)
# need to do the prediction for ensemble model here
x = tf.cast(input_data, tf.float32)
sequences_rev = layers.EnsembleReverseComplement()([x])
preds = [
layers.SwitchReverse(None)([self.predict(seq), rp])
for (seq, rp) in sequences_rev
]
preds_avg = tf.keras.layers.Average()(preds)
return tf.expand_dims(preds_avg, axis=0)


class ModelOptimizer:
Expand Down Expand Up @@ -151,10 +159,10 @@ def main():
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(args.model_fn)
seqnn_model.build_slice(np.array(targets_df.index))
seqnn_model.build_ensemble(True)
# seqnn_model.build_ensemble(True)

# save this model to a directory
seqnn_model.ensemble.save(f"{args.output_dir}/original_model")
seqnn_model.model.save(f"{args.output_dir}/original_model")

# Convert the model
opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
Expand Down

0 comments on commit cb75751

Please sign in to comment.