From d8ba05865a15360507f7f82cb0c61516fba08dcf Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 26 Dec 2023 12:07:40 -0800 Subject: [PATCH] fix config Signed-off-by: Yi Dong --- examples/nlp/gpt/train_gpt_sft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index 4382ff597..fa51b942e 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -158,10 +158,10 @@ def main(cfg) -> None: if cfg.model.data.get("sample", False): # if it is negative, num_samples is None - if cfg.sft_trainer.max_steps < 0: + if cfg.trainer.sft.max_steps < 0: num_samples = None else: - num_samples = cfg.sft_trainer.max_steps * train_data_cfg.global_batch_size + num_samples = cfg.trainer.sft.max_steps * train_data_cfg.global_batch_size else: num_samples = None train_ds = build_sft_dataset( @@ -173,7 +173,7 @@ def main(cfg) -> None: special_tokens=cfg.model.data.chat_prompt_tokens, ) if cfg.model.data.get("sample", False): - num_samples = cfg.sft_trainer.limit_val_batches * val_data_cfg.global_batch_size + num_samples = cfg.trainer.sft.limit_val_batches * val_data_cfg.global_batch_size else: num_samples = None validation_ds = build_sft_dataset(