From 805cc671293879d4b6bd83e2c26685597e98f2f0 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Fri, 13 Dec 2024 11:49:18 +0800
Subject: [PATCH] add trainer_state

---
 paddlenlp/trainer/trainer.py          | 12 +++++++-----
 paddlenlp/trainer/trainer_callback.py |  2 +-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 630c91e3b362..77d2ecd92e71 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -973,7 +973,7 @@ def _inner_training_loop(
         self.state.num_train_epochs = num_train_epochs
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()
-        self.state.trained_samples = 0
+        self.state.consumed_samples = 0
 
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
@@ -1049,9 +1049,10 @@ def _inner_training_loop(
                         self._skip_steps_since_last_logged += 1
 
                         self.state.epoch = epoch + (step + 1) / steps_in_epoch
-                        self.state.trained_samples = (
-                            (epoch * steps_in_epoch + step + 1)
+                        self.state.consumed_samples = (
+                            self.state.global_step
                             * args.per_device_train_batch_size
+                            * args.gradient_accumulation_steps
                             * args.dataset_world_size
                         )
 
@@ -1236,9 +1237,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
-                    self.state.trained_samples = (
-                        (epoch * steps_in_epoch + step + 1)
+                    self.state.consumed_samples = (
+                        self.state.global_step
                         * args.per_device_train_batch_size
+                        * args.gradient_accumulation_steps
                         * args.dataset_world_size
                     )
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)

diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index beaed9844e3a..532b46d52df5 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -85,7 +85,7 @@ class TrainerState:
 
     epoch: Optional[float] = None
    global_step: int = 0
-    trained_samples: int = 0
+    consumed_samples: int = 0
     max_steps: int = 0
     num_train_epochs: int = 0
     total_flos: float = 0
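
Note (editorial, not part of the patch): the diff replaces the old per-step count
(epoch * steps_in_epoch + step + 1) with a count derived from global_step, so that
consumed samples track optimizer steps rather than micro-batches. A minimal sketch of
that arithmetic follows, assuming a hypothetical standalone helper name
compute_consumed_samples; per_device_train_batch_size, gradient_accumulation_steps,
and dataset_world_size stand in for the values the diff reads from args.

    def compute_consumed_samples(
        global_step: int,
        per_device_train_batch_size: int,
        gradient_accumulation_steps: int,
        dataset_world_size: int,
    ) -> int:
        """Samples consumed so far: each optimizer step covers one per-device
        batch times the accumulation steps, across all data-parallel ranks."""
        return (
            global_step
            * per_device_train_batch_size
            * gradient_accumulation_steps
            * dataset_world_size
        )

    # Example: 100 optimizer steps, batch 8, 4 accumulation steps, 2 data-parallel ranks
    # -> 100 * 8 * 4 * 2 = 6400 samples consumed.
    assert compute_consumed_samples(100, 8, 4, 2) == 6400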