modify snps and hound_snps
lruizcalico committed Jan 15, 2024
1 parent a39b879 commit d95ecbd
Showing 4 changed files with 42 additions and 14 deletions.
src/baskerville/helpers/gcs_utils.py (4 changes: 2 additions & 2 deletions)
@@ -254,9 +254,9 @@ def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -
     Returns: new filepath in the local machine
     """
     if is_dir:
-        download_folder_from_gcs(filepath, temp_dir)
         dir_name = filepath.split("/")[-1]
-        return temp_dir
+        download_folder_from_gcs(filepath, f"{temp_dir}/{dir_name}")
+        return f"{temp_dir}/{dir_name}"
     else:
         _, filename = split_gcs_uri(filepath)
         if "/" in filename:
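For context, a minimal usage sketch of the revised directory branch; the bucket, folder, and scratch paths below are placeholders, not values from this commit:

    from baskerville.helpers.gcs_utils import download_rename_inputs

    # With is_dir=True the helper now mirrors the GCS folder into a local
    # subdirectory named after the last path component and returns that path.
    local_model_dir = download_rename_inputs(
        "gs://example-bucket/models/trt_model", "/tmp/scratch", is_dir=True
    )
    # local_model_dir == "/tmp/scratch/trt_model"
    # (before this change the call downloaded into /tmp/scratch and returned it)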
src/baskerville/helpers/tensorrt_helpers.py (9 changes: 5 additions & 4 deletions)
@@ -34,7 +34,8 @@ def __init__(self, saved_model_dir=None):
     def predict(self, input_data):
         if self.loaded_model_fn is None:
             raise (Exception("Haven't loaded a model"))
-        x = tf.constant(input_data.astype("float32"))
+        # x = tf.constant(input_data.astype("float32"))
+        x = tf.cast(input_data, tf.float32)
         labeling = self.loaded_model_fn(x)
         try:
             preds = labeling["predictions"].numpy()
@@ -48,7 +49,7 @@ def predict(self, input_data):
                     raise (
                         Exception("Failed to get predictions from saved model object")
                     )
-        return tf.squeeze(preds, axis=0)
+        return preds
 
     def load_model(self, saved_model_dir):
         saved_model_loaded = tf.saved_model.load(
@@ -58,7 +59,7 @@ def load_model(self, saved_model_dir):
         self.loaded_model_fn = wrapper_fp32
 
     def __call__(self, input_data):
-        return self.loaded_model_fn.predict(input_data)
+        return tf.expand_dims(self.predict(input_data), axis=0)
 
 
 class ModelOptimizer:
@@ -153,7 +154,7 @@ def main():
     seqnn_model.build_ensemble(True)
 
     # save this model to a directory
-    seqnn_model.model.save(f"{args.output_dir}/original_model")
+    seqnn_model.ensemble.save(f"{args.output_dir}/original_model")
 
     # Convert the model
     opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
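A rough sketch of how the wrapper behaves after this change, assuming a TensorRT-converted SavedModel directory; the path and input shape below are illustrative, not taken from the commit:

    import numpy as np
    from baskerville.helpers.tensorrt_helpers import OptimizedModel

    # Hypothetical path to a converted SavedModel produced by ModelOptimizer.
    model = OptimizedModel("trt_out/original_model")

    # predict() now casts the input to float32 with tf.cast and returns the raw
    # predictions without squeezing; __call__() routes through predict() and
    # prepends a leading axis with tf.expand_dims.
    batch = np.random.random((1, 524288, 4)).astype(np.float32)  # illustrative shape
    preds = model(batch)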
src/baskerville/scripts/hound_snp.py (17 changes: 16 additions & 1 deletion)
@@ -110,6 +110,13 @@ def main():
         action="store_true",
         help="Only run on GPU",
     )
+    parser.add_option(
+        "--tensorrt",
+        dest="tensorrt",
+        default=False,
+        action="store_true",
+        help="Model type is tensorrt optimized",
+    )
     (options, args) = parser.parse_args()
 
     if options.gcs:
@@ -162,6 +169,14 @@ def main():
     else:
         parser.error("Must provide parameters and model files and QTL VCF file")
 
+    # check if the model type is correct
+    if options.tensorrt:
+        if model_file.endswith(".h5"):
+            raise SystemExit("Model type is tensorrt but model file is keras")
+        is_dir_model = True
+    else:
+        is_dir_model = False
+
     if not os.path.isdir(options.out_dir):
         os.mkdir(options.out_dir)
 
@@ -188,7 +203,7 @@
     if options.gcs:
         params_file = download_rename_inputs(params_file, temp_dir)
         vcf_file = download_rename_inputs(vcf_file, temp_dir)
-        model_file = download_rename_inputs(model_file, temp_dir)
+        model_file = download_rename_inputs(model_file, temp_dir, is_dir_model)
         if options.genome_fasta is not None:
             options.genome_fasta = download_rename_inputs(
                 options.genome_fasta, temp_dir
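Put together, the new flag only changes how the model artifact is validated and fetched; a short sketch of that flow, with placeholder values standing in for the parsed CLI inputs:

    # A TensorRT-optimized model is a SavedModel *directory*, so it must be
    # downloaded from GCS with is_dir=True; a Keras .h5 checkpoint is a file.
    model_file = "gs://example-bucket/trt_model"  # placeholder URI

    if options.tensorrt:
        if model_file.endswith(".h5"):
            raise SystemExit("Model type is tensorrt but model file is keras")
        is_dir_model = True
    else:
        is_dir_model = False

    model_file = download_rename_inputs(model_file, temp_dir, is_dir_model)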
src/baskerville/snps.py (26 changes: 19 additions & 7 deletions)
@@ -13,6 +13,7 @@
 from baskerville import dataset
 from baskerville import seqnn
 from baskerville import vcf as bvcf
+from baskerville.helpers.tensorrt_helpers import OptimizedModel
 
 
 def score_snps(params_file, model_file, vcf_file, worker_index, options):
@@ -63,14 +64,25 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
     # setup model
 
-    # can we sum on GPU?
-    sum_length = options.snp_stats == "SAD"
 
     seqnn_model = seqnn.SeqNN(params_model)
-    seqnn_model.restore(model_file)
-    seqnn_model.build_slice(targets_df.index)
-    if sum_length:
-        seqnn_model.build_sad()
-    seqnn_model.build_ensemble(options.rc)
+
+    # load model
+    if options.tensorrt:
+        seqnn_model.ensemble = OptimizedModel(model_file)
+        input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list())
+    else:
+        seqnn_model.restore(model_file)
+        seqnn_model.build_slice(targets_df.index)
+        sum_length = options.snp_stats == "SAD"
+        if sum_length:
+            seqnn_model.build_sad()
+        seqnn_model.build_ensemble(options.rc)
+        input_shape = seqnn_model.model.input_shape
+
+    # make dummy predictions to warm up model
+    dummy_input_shape = (1,) + input_shape[1:]
+    dummy_input = np.random.random(dummy_input_shape).astype(np.float32)
+    dummy_output = seqnn_model.model(dummy_input)
 
     # shift outside seqnn
     num_shifts = len(options.shifts)
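For reference, a standalone sketch of the warm-up step the new scoring path performs; the function and argument names here are illustrative, not part of the commit:

    import numpy as np

    def warm_up(model, input_shape):
        """Run one throwaway batch so any lazy graph tracing or TensorRT
        engine construction happens before SNP scoring begins."""
        dummy = np.random.random((1,) + tuple(input_shape[1:])).astype(np.float32)
        return model(dummy)

    # In score_snps, input_shape comes from the loaded SavedModel signature in
    # the TensorRT branch and from seqnn_model.model.input_shape otherwise.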
