Skip to content

Commit

Permalink
add trainer state
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Dec 12, 2024
1 parent 70e80c7 commit 6655e12
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
11 changes: 11 additions & 0 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ def _inner_training_loop(
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
self.state.trained_samples = 0

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

Expand Down Expand Up @@ -1048,6 +1049,11 @@ def _inner_training_loop(
self._skip_steps_since_last_logged += 1

self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.state.trained_samples = (
(epoch * steps_in_epoch + step + 1)
* args.per_device_train_batch_size
* args.dataset_world_size
)

if self.state.global_step == 1 and self.args.logging_first_step:
self.control.should_log = True
Expand Down Expand Up @@ -1230,6 +1236,11 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.state.trained_samples = (
(epoch * steps_in_epoch + step + 1)
* args.per_device_train_batch_size
* args.dataset_world_size
)
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
self._print_timer()
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class TrainerState:

epoch: Optional[float] = None
global_step: int = 0
trained_samples: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: float = 0
Expand Down

0 comments on commit 6655e12

Please sign in to comment.