Skip to content

Commit

Permalink
Merge branch 'main' of github.com:calico/baskerville into main
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Jun 11, 2024
2 parents 7fe2b78 + 260a3ac commit b6f5c3c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/baskerville/scripts/hound_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def main():
default="train_out",
help="Output directory [Default: %(default)s]",
)
parser.add_argument(
"-l",
"--log_dir",
default="log_out",
help="Tensorboard log directory [Default: %(default)s]",
)
parser.add_argument(
"--restore",
default=None,
Expand Down Expand Up @@ -150,7 +156,7 @@ def main():

# initialize trainer
seqnn_trainer = trainer.Trainer(
params_train, train_data, eval_data, args.out_dir
params_train, train_data, eval_data, args.out_dir, args.log_dir
)

# compile model
Expand Down Expand Up @@ -182,6 +188,7 @@ def main():
train_data,
eval_data,
args.out_dir,
args.log_dir,
strategy,
params_train["num_gpu"],
args.keras_fit,
Expand Down
45 changes: 44 additions & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
train_data,
eval_data,
out_dir: str,
log_dir: str,
strategy=None,
num_gpu: int = 1,
keras_fit: bool = False,
Expand All @@ -112,6 +113,7 @@ def __init__(
if type(self.eval_data) is not list:
self.eval_data = [self.eval_data]
self.out_dir = out_dir
self.log_dir = log_dir
self.strategy = strategy
self.num_gpu = num_gpu
self.batch_size = self.train_data[0].batch_size
Expand Down Expand Up @@ -205,7 +207,7 @@ def fit_keras(self, seqnn_model):

callbacks = [
early_stop,
tf.keras.callbacks.TensorBoard(self.out_dir),
tf.keras.callbacks.TensorBoard(self.log_dir, histogram_freq=1),
tf.keras.callbacks.ModelCheckpoint("%s/model_check.h5" % self.out_dir),
save_best,
]
Expand Down Expand Up @@ -414,6 +416,12 @@ def eval_step1_distr(xd, yd):
# training loop

first_step = True
# set up summary writer
train_log_dir = self.log_dir + "/train"
valid_log_dir = self.log_dir + "/valid"
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

for ei in range(epoch_start, self.train_epochs_max):
if ei >= self.train_epochs_min and np.min(unimproved) > self.patience:
break
Expand Down Expand Up @@ -446,6 +454,13 @@ def eval_step1_distr(xd, yd):
for di in range(self.num_datasets):
print(" Data %d" % di, end="")
model = seqnn_model.models[di]
with train_summary_writer.as_default():
tf.summary.scalar(
"loss", train_loss[di].result().numpy(), step=ei
)
tf.summary.scalar("r", train_r[di].result().numpy(), step=ei)
tf.summary.scalar("r2", train_r2[di].result().numpy(), step=ei)
train_summary_writer.flush()

# print training accuracy
print(
Expand All @@ -467,6 +482,14 @@ def eval_step1_distr(xd, yd):
else:
eval_step1_distr(x, y)

with valid_summary_writer.as_default():
tf.summary.scalar(
"loss", valid_loss[di].result().numpy(), step=ei
)
tf.summary.scalar("r", valid_r[di].result().numpy(), step=ei)
tf.summary.scalar("r2", valid_r2[di].result().numpy(), step=ei)
valid_summary_writer.flush()

# print validation accuracy
print(
" - valid_loss: %.4f" % valid_loss[di].result().numpy(), end=""
Expand Down Expand Up @@ -604,6 +627,12 @@ def eval_step_distr(xd, yd):
valid_best = -np.inf
unimproved = 0

# set up summary writer
train_log_dir = self.log_dir + "/train"
valid_log_dir = self.log_dir + "/valid"
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

# training loop
for ei in range(epoch_start, self.train_epochs_max):
if ei >= self.train_epochs_min and unimproved > self.patience:
Expand Down Expand Up @@ -632,6 +661,13 @@ def eval_step_distr(xd, yd):
train_loss_epoch = train_loss.result().numpy()
train_r_epoch = train_r.result().numpy()
train_r2_epoch = train_r2.result().numpy()

with train_summary_writer.as_default():
tf.summary.scalar("loss", train_loss_epoch, step=ei)
tf.summary.scalar("r", train_r_epoch, step=ei)
tf.summary.scalar("r2", train_r2_epoch, step=ei)
train_summary_writer.flush()

print(
"Epoch %d - %ds - train_loss: %.4f - train_r: %.4f - train_r2: %.4f"
% (
Expand All @@ -648,6 +684,13 @@ def eval_step_distr(xd, yd):
valid_loss_epoch = valid_loss.result().numpy()
valid_r_epoch = valid_r.result().numpy()
valid_r2_epoch = valid_r2.result().numpy()

with valid_summary_writer.as_default():
tf.summary.scalar("loss", valid_loss_epoch, step=ei)
tf.summary.scalar("r", valid_r_epoch, step=ei)
tf.summary.scalar("r2", valid_r2_epoch, step=ei)
valid_summary_writer.flush()

print(
" - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f"
% (valid_loss_epoch, valid_r_epoch, valid_r2_epoch),
Expand Down
4 changes: 4 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def test_train(clean_data):
"src/baskerville/scripts/hound_train.py",
"-o",
"tests/data/train1",
"-l",
"tests/data/train1/logs",
"tests/data/params.json",
"tests/data/tiny/hg38",
]
Expand All @@ -33,6 +35,8 @@ def test_train2(clean_data):
"src/baskerville/scripts/hound_train.py",
"-o",
"tests/data/train2",
"-l",
"tests/data/train2/logs",
"tests/data/params.json",
"tests/data/tiny/hg38",
"tests/data/tiny/mm10",
Expand Down

0 comments on commit b6f5c3c

Please sign in to comment.