From 92b68e4d8c5d59e6ba25d12fd9acfe10287be689 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 13 Apr 2023 03:10:00 +0300 Subject: [PATCH] fix(sft_trainer): `total_steps` calculation when distributed (#432) --- trlx/trainer/accelerate_sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index ae3a23f0c..cb471ca61 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -64,7 +64,7 @@ def prepare_learning(self): ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) self.n_updates_per_batch = 1 - self.total_steps = self.config.train.epochs * len(train_dataloader) + self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) def make_experience(self, samples, seq_length):