From 1cd10d70f3f665396561839cca1bcedec226e781 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Tue, 23 Apr 2024 20:09:25 -0700 Subject: [PATCH] add log dir for keras train --- src/baskerville/scripts/hound_train.py | 9 ++++++++- src/baskerville/trainer.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py index e7ec150..7f2a20f 100755 --- a/src/baskerville/scripts/hound_train.py +++ b/src/baskerville/scripts/hound_train.py @@ -56,6 +56,12 @@ def main(): default="train_out", help="Output directory [Default: %(default)s]", ) + parser.add_argument( + "-log", + "--log_dir", + default="log_out", + help="Log directory [Default: %(default)s]", + ) parser.add_argument( "--restore", default=None, @@ -150,7 +156,7 @@ def main(): # initialize trainer seqnn_trainer = trainer.Trainer( - params_train, train_data, eval_data, args.out_dir + params_train, train_data, eval_data, args.out_dir, args.log_dir ) # compile model @@ -182,6 +188,7 @@ def main(): train_data, eval_data, args.out_dir, + args.log_dir, strategy, params_train["num_gpu"], args.keras_fit, diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index 5c55f52..775fe64 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -91,6 +91,7 @@ def __init__( train_data, eval_data, out_dir: str, + log_dir: str, strategy=None, num_gpu: int = 1, keras_fit: bool = False, @@ -188,7 +189,7 @@ def fit_keras(self, seqnn_model): callbacks = [ early_stop, - tf.keras.callbacks.TensorBoard(self.out_dir), + tf.keras.callbacks.TensorBoard(self.log_dir), tf.keras.callbacks.ModelCheckpoint("%s/model_check.h5" % self.out_dir), save_best, ]