diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py
index 35aad61..4184268 100644
--- a/src/baskerville/helpers/gcs_utils.py
+++ b/src/baskerville/helpers/gcs_utils.py
@@ -254,9 +254,9 @@ def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -
     Returns: new filepath in the local machine
     """
     if is_dir:
-        download_folder_from_gcs(filepath, temp_dir)
         dir_name = filepath.split("/")[-1]
-        return temp_dir
+        download_folder_from_gcs(filepath, f"{temp_dir}/{dir_name}")
+        return f"{temp_dir}/{dir_name}"
     else:
         _, filename = split_gcs_uri(filepath)
         if "/" in filename:
diff --git a/src/baskerville/helpers/tensorrt_helpers.py b/src/baskerville/helpers/tensorrt_helpers.py
index 1703d39..2614286 100644
--- a/src/baskerville/helpers/tensorrt_helpers.py
+++ b/src/baskerville/helpers/tensorrt_helpers.py
@@ -34,7 +34,8 @@ def __init__(self, saved_model_dir=None):
     def predict(self, input_data):
         if self.loaded_model_fn is None:
             raise (Exception("Haven't loaded a model"))
-        x = tf.constant(input_data.astype("float32"))
+        # x = tf.constant(input_data.astype("float32"))
+        x = tf.cast(input_data, tf.float32)
         labeling = self.loaded_model_fn(x)
         try:
             preds = labeling["predictions"].numpy()
@@ -48,7 +49,7 @@ def predict(self, input_data):
                 raise (
                     Exception("Failed to get predictions from saved model object")
                 )
-        return tf.squeeze(preds, axis=0)
+        return preds
 
     def load_model(self, saved_model_dir):
         saved_model_loaded = tf.saved_model.load(
@@ -58,7 +59,7 @@ def load_model(self, saved_model_dir):
         self.loaded_model_fn = wrapper_fp32
 
     def __call__(self, input_data):
-        return self.loaded_model_fn.predict(input_data)
+        return tf.expand_dims(self.predict(input_data), axis=0)
 
 
 class ModelOptimizer:
@@ -153,7 +154,7 @@ def main():
     seqnn_model.build_ensemble(True)
 
     # save this model to a directory
-    seqnn_model.model.save(f"{args.output_dir}/original_model")
+    seqnn_model.ensemble.save(f"{args.output_dir}/original_model")
 
     # Convert the model
     opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py
index 1ccad45..c3df976 100755
--- a/src/baskerville/scripts/hound_snp.py
+++ b/src/baskerville/scripts/hound_snp.py
@@ -110,6 +110,13 @@ def main():
         action="store_true",
         help="Only run on GPU",
     )
+    parser.add_option(
+        "--tensorrt",
+        dest="tensorrt",
+        default=False,
+        action="store_true",
+        help="Model type is tensorrt optimized",
+    )
     (options, args) = parser.parse_args()
 
     if options.gcs:
@@ -162,6 +169,14 @@ def main():
     else:
         parser.error("Must provide parameters and model files and QTL VCF file")
 
+    # check if the model type is correct
+    if options.tensorrt:
+        if model_file.endswith(".h5"):
+            raise SystemExit("Model type is tensorrt but model file is keras")
+        is_dir_model = True
+    else:
+        is_dir_model = False
+
     if not os.path.isdir(options.out_dir):
         os.mkdir(options.out_dir)
 
@@ -188,7 +203,7 @@ def main():
     if options.gcs:
         params_file = download_rename_inputs(params_file, temp_dir)
         vcf_file = download_rename_inputs(vcf_file, temp_dir)
-        model_file = download_rename_inputs(model_file, temp_dir)
+        model_file = download_rename_inputs(model_file, temp_dir, is_dir_model)
         if options.genome_fasta is not None:
             options.genome_fasta = download_rename_inputs(
                 options.genome_fasta, temp_dir
diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py
index 6ecc4b4..1f2725b 100644
--- a/src/baskerville/snps.py
+++ b/src/baskerville/snps.py
@@ -13,6 +13,7 @@
 from baskerville import dataset
 from baskerville import seqnn
 from baskerville import vcf as bvcf
+from baskerville.helpers.tensorrt_helpers import OptimizedModel
 
 
 def score_snps(params_file, model_file, vcf_file, worker_index, options):
@@ -63,13 +64,25 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
 
     # setup model
     # can we sum on GPU?
-    sum_length = options.snp_stats == "SAD"
-    seqnn_model = seqnn.SeqNN(params_model)
-    seqnn_model.restore(model_file)
-    seqnn_model.build_slice(targets_df.index)
-    if sum_length:
-        seqnn_model.build_sad()
-    seqnn_model.build_ensemble(options.rc)
+
+    # load model
+    seqnn_model = seqnn.SeqNN(params_model)
+    if options.tensorrt:
+        seqnn_model.model = OptimizedModel(model_file)
+        input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list())
+    else:
+        seqnn_model.restore(model_file)
+        seqnn_model.build_slice(targets_df.index)
+        sum_length = options.snp_stats == "SAD"
+        if sum_length:
+            seqnn_model.build_sad()
+        seqnn_model.build_ensemble(options.rc)
+        input_shape = seqnn_model.model.input_shape
+
+    # make dummy predictions to warm up model
+    dummy_input_shape = (1,) + input_shape[1:]
+    dummy_input = np.random.random(dummy_input_shape).astype(np.float32)
+    dummy_output = seqnn_model.model(dummy_input)
 
     # shift outside seqnn
     num_shifts = len(options.shifts)