diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py
index 96847c5409..644eb1417d 100644
--- a/llmfoundry/data/utils.py
+++ b/llmfoundry/data/utils.py
@@ -136,7 +136,7 @@ def get_num_tokens_in_batch(batch: Batch) -> int:
 def get_text_collator(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
-    dataset_batch_size: int = -1,
+    dataset_batch_size: int,
 ) -> Tuple[Union[transformers.DataCollatorForLanguageModeling,
                  ConcatenatedSequenceCollatorWrapper], int]:
     eos_token_id = cfg.dataset.get('eos_token_id', None)
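
For context, a minimal sketch of how a call site is affected now that `dataset_batch_size` has no default. The cfg contents, tokenizer choice, and batch size value below are illustrative assumptions, not taken from this diff:

```python
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.utils import get_text_collator

# Illustrative config; real dataset cfgs carry more fields.
cfg = OmegaConf.create({'dataset': {'eos_token_id': 50256}})
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Previously dataset_batch_size defaulted to -1, so the sentinel could flow
# silently into the returned batch size. Omitting it now fails fast:
#   get_text_collator(cfg, tokenizer)
#   TypeError: get_text_collator() missing 1 required positional argument:
#       'dataset_batch_size'

# Callers must pass the batch size explicitly:
collate_fn, dataset_batch_size = get_text_collator(
    cfg,
    tokenizer,
    dataset_batch_size=8,  # illustrative per-device value
)
```

Making the parameter required means a forgotten batch size surfaces as an immediate `TypeError` at the call site rather than a `-1` sentinel propagating into downstream batch-size accounting.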