Skip to content

Commit

Permalink
fix(sft_trainer): total_steps calculation when distributed (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxreciprocate authored Apr 13, 2023
1 parent adb3be2 commit 92b68e4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def prepare_learning(self):
) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader)

self.n_updates_per_batch = 1
self.total_steps = self.config.train.epochs * len(train_dataloader)
self.total_steps = self.config.train.epochs * len(self.train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def make_experience(self, samples, seq_length):
Expand Down

0 comments on commit 92b68e4

Please sign in to comment.