diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index 75e538e..91f00d9 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -17,7 +17,8 @@ import numpy as np import tensorflow as tf - +import tempfile +from baskerville.helpers.gcs_utils import is_gcs_path, upload_folder_gcs from baskerville import metrics @@ -119,6 +120,15 @@ def __init__( self.batch_size = self.train_data[0].batch_size self.compiled = False + # if log_dir is in gcs then create a local temp dir + if is_gcs_path(self.log_dir): + folder_name = self.log_dir.split("/")[-1] + self.log_dir = tempfile.mkdtemp() + folder_name + self.gcs_log_dir = log_dir + self.gcs = True + else: + self.gcs = False + # early stopping self.patience = self.params.get("patience", 20) @@ -498,6 +508,10 @@ def eval_step1_distr(xd, yd): print(" - valid_r2: %.4f" % valid_r2[di].result().numpy(), end="") early_stop_stat = valid_r[di].result().numpy() + # upload to gcs + if self.gcs: + upload_folder_gcs(self.log_dir, self.gcs_log_dir) + # checkpoint managers[di].save() model.save( @@ -697,6 +711,10 @@ def eval_step_distr(xd, yd): end="", ) + # upload to gcs + if self.gcs: + upload_folder_gcs(self.log_dir, self.gcs_log_dir) + # checkpoint manager.save() seqnn_model.save("%s/model_check.h5" % self.out_dir)