From 6b3464fab18aa263cbe7790c1b06744bf795fb4d Mon Sep 17 00:00:00 2001 From: Yi Dong <43824965+yidong72@users.noreply.github.com> Date: Tue, 26 Dec 2023 16:21:57 -0500 Subject: [PATCH] fix config (#67) Signed-off-by: Yi Dong --- examples/nlp/gpt/train_gpt_sft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index 4382ff597..fa51b942e 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -158,10 +158,10 @@ def main(cfg) -> None: if cfg.model.data.get("sample", False): # if it is negative, num_samples is None - if cfg.sft_trainer.max_steps < 0: + if cfg.trainer.sft.max_steps < 0: num_samples = None else: - num_samples = cfg.sft_trainer.max_steps * train_data_cfg.global_batch_size + num_samples = cfg.trainer.sft.max_steps * train_data_cfg.global_batch_size else: num_samples = None train_ds = build_sft_dataset( @@ -173,7 +173,7 @@ def main(cfg) -> None: special_tokens=cfg.model.data.chat_prompt_tokens, ) if cfg.model.data.get("sample", False): - num_samples = cfg.sft_trainer.limit_val_batches * val_data_cfg.global_batch_size + num_samples = cfg.trainer.sft.limit_val_batches * val_data_cfg.global_batch_size else: num_samples = None validation_ds = build_sft_dataset(