Skip to content

Commit

Permalink
add log dir for keras train
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Apr 24, 2024
1 parent 9031b4e commit 1cd10d7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/baskerville/scripts/hound_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def main():
default="train_out",
help="Output directory [Default: %(default)s]",
)
parser.add_argument(
"-log",
"--log_dir",
default="log_out",
help="Log directory [Default: %(default)s]",
)
parser.add_argument(
"--restore",
default=None,
Expand Down Expand Up @@ -150,7 +156,7 @@ def main():

# initialize trainer
seqnn_trainer = trainer.Trainer(
params_train, train_data, eval_data, args.out_dir
params_train, train_data, eval_data, args.out_dir, args.log_dir
)

# compile model
Expand Down Expand Up @@ -182,6 +188,7 @@ def main():
train_data,
eval_data,
args.out_dir,
args.log_dir,
strategy,
params_train["num_gpu"],
args.keras_fit,
Expand Down
3 changes: 2 additions & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
train_data,
eval_data,
out_dir: str,
log_dir: str,
strategy=None,
num_gpu: int = 1,
keras_fit: bool = False,
Expand Down Expand Up @@ -188,7 +189,7 @@ def fit_keras(self, seqnn_model):

callbacks = [
early_stop,
tf.keras.callbacks.TensorBoard(self.out_dir),
tf.keras.callbacks.TensorBoard(self.log_dir),
tf.keras.callbacks.ModelCheckpoint("%s/model_check.h5" % self.out_dir),
save_best,
]
Expand Down

0 comments on commit 1cd10d7

Please sign in to comment.