Skip to content

Commit

Permalink
fix config (#67)
Browse files Browse the repository at this point in the history
Signed-off-by: Yi Dong <[email protected]>
  • Loading branch information
yidong72 authored Dec 26, 2023
1 parent adafc6e commit 6b3464f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/nlp/gpt/train_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def main(cfg) -> None:

if cfg.model.data.get("sample", False):
# if it is negative, num_samples is None
if cfg.sft_trainer.max_steps < 0:
if cfg.trainer.sft.max_steps < 0:
num_samples = None
else:
num_samples = cfg.sft_trainer.max_steps * train_data_cfg.global_batch_size
num_samples = cfg.trainer.sft.max_steps * train_data_cfg.global_batch_size
else:
num_samples = None
train_ds = build_sft_dataset(
Expand All @@ -173,7 +173,7 @@ def main(cfg) -> None:
special_tokens=cfg.model.data.chat_prompt_tokens,
)
if cfg.model.data.get("sample", False):
num_samples = cfg.sft_trainer.limit_val_batches * val_data_cfg.global_batch_size
num_samples = cfg.trainer.sft.limit_val_batches * val_data_cfg.global_batch_size
else:
num_samples = None
validation_ds = build_sft_dataset(
Expand Down

0 comments on commit 6b3464f

Please sign in to comment.