Skip to content

Commit

Permalink
Move out logic for getting steps_per_epoch as a free function
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed May 13, 2022
1 parent 1a4d278 commit fc49c1a
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions training/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def compile_model(model: Model) -> None:
)


def get_steps_per_epoch(
        landmark_dict: Dict[str, Sequence[Sequence[float]]],
        fraction: float = 0.7
) -> int:
    """
    Compute the number of training steps per epoch from the landmark data.

    The value is a fraction of the mean number of samples per class.
    The default fraction (0.7) is kind of arbitrary, kept for backward
    compatibility with the original hard-coded constant.

    Args:
        landmark_dict: Mapping from class name to its list of landmark
            samples (each sample is a sequence of floats).
        fraction: Fraction of the mean per-class sample count to use.

    Returns:
        Steps per epoch, clamped to at least 1 so the value is always
        valid for `keras.Model.fit` (0 would be rejected / degenerate).
    """
    mean_data_size = int(np.mean([len(v) for v in landmark_dict.values()]))
    # max(1, ...) guards against tiny datasets where the scaled mean
    # truncates to 0.
    return max(1, int(mean_data_size * fraction))


def train_and_save_weights(
landmark_dict: Dict[str, List[List[float]]],
model: Model,
Expand All @@ -139,10 +148,6 @@ def train_and_save_weights(
)
ds_train = ds_train.batch(16).prefetch(tf.data.AUTOTUNE)

# Kind of arbitrary here.
mean_data_size = int(np.mean([len(v) for v in landmark_dict.values()]))
steps_per_epoch = int(mean_data_size * 0.7)

callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=weights_path, monitor="loss", mode="min",
Expand All @@ -156,7 +161,7 @@ def train_and_save_weights(
model.fit(
ds_train,
epochs=500,
steps_per_epoch=steps_per_epoch,
steps_per_epoch=get_steps_per_epoch(landmark_dict),
callbacks=callbacks,
verbose=1,
)
Expand Down

0 comments on commit fc49c1a

Please sign in to comment.