From 6655e12075872cc79625ce2e182ab47b4942e345 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Thu, 12 Dec 2024 15:01:01 +0800
Subject: [PATCH] add trainer state

---
 paddlenlp/trainer/trainer.py          | 11 +++++++++++
 paddlenlp/trainer/trainer_callback.py |  1 +
 2 files changed, 12 insertions(+)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 57c655736f25..630c91e3b362 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -973,6 +973,7 @@ def _inner_training_loop(
         self.state.num_train_epochs = num_train_epochs
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()
+        self.state.trained_samples = 0

         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

@@ -1048,6 +1049,11 @@ def _inner_training_loop(
                     self._skip_steps_since_last_logged += 1

                 self.state.epoch = epoch + (step + 1) / steps_in_epoch
+                self.state.trained_samples = (
+                    (epoch * steps_in_epoch + step + 1)
+                    * args.per_device_train_batch_size
+                    * args.dataset_world_size
+                )

                 if self.state.global_step == 1 and self.args.logging_first_step:
                     self.control.should_log = True
@@ -1230,6 +1236,11 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 self.state.global_step += 1
                 self.state.epoch = epoch + (step + 1) / steps_in_epoch
+                self.state.trained_samples = (
+                    (epoch * steps_in_epoch + step + 1)
+                    * args.per_device_train_batch_size
+                    * args.dataset_world_size
+                )

                 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                 self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
                 self._print_timer()
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index b263c7930daf..beaed9844e3a 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -85,6 +85,7 @@ class TrainerState:

     epoch: Optional[float] = None
     global_step: int = 0
+    trained_samples: int = 0
     max_steps: int = 0
     num_train_epochs: int = 0
     total_flos: float = 0
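
Note (not part of the patch): a minimal sketch of the arithmetic behind the new
trained_samples field, written as a standalone function for illustration. It
assumes, as the hunks above suggest, that every step consumes exactly
per_device_train_batch_size samples on each of dataset_world_size data-parallel
ranks; the function name trained_samples here is hypothetical, not a PaddleNLP API.

# Illustrative sketch only; mirrors the expression added in the patch.
def trained_samples(epoch, step, steps_in_epoch,
                    per_device_train_batch_size, dataset_world_size):
    # Steps completed so far across all finished epochs plus the current one
    # (step is 0-indexed within the epoch, hence the +1).
    completed_steps = epoch * steps_in_epoch + step + 1
    # Each step feeds per_device_train_batch_size samples to every
    # data-parallel rank, so multiply by both factors.
    return completed_steps * per_device_train_batch_size * dataset_world_size

# Example: after step 9 (0-indexed) of epoch 1, with 100 steps per epoch,
# batch size 8, and 4 data-parallel ranks, 110 steps * 32 samples = 3520.
assert trained_samples(1, 9, 100, 8, 4) == 3520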