diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index ae3a23f0c..cb471ca61 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -64,7 +64,7 @@ def prepare_learning(self): ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) self.n_updates_per_batch = 1 - self.total_steps = self.config.train.epochs * len(train_dataloader) + self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) def make_experience(self, samples, seq_length):