diff --git a/ffn/jax/train.py b/ffn/jax/train.py index 4149747..581b74a 100644 --- a/ffn/jax/train.py +++ b/ffn/jax/train.py @@ -338,7 +338,9 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None: # pylint:enable=protected-access -def _get_ocp_args(train_iter: DataIterator) -> DataIterator: +def _get_ocp_args( + train_iter: DataIterator, restore: bool = True +) -> DataIterator: if isinstance(train_iter, tf.data.Iterator): return DatasetArgs(train_iter) @@ -346,7 +348,7 @@ def _get_ocp_args(train_iter: DataIterator) -> DataIterator: def _make_ckpt_args(state, train_iter: DataIterator) -> ocp.args.CheckpointArgs: return ocp.args.Composite( train_state=ocp.args.StandardSave(state), - train_iter=_get_ocp_args(train_iter), + train_iter=_get_ocp_args(train_iter, restore=False), )