diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 83a9a7d8ea..e7521bc343 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -3,7 +3,7 @@ """Dataloader builder utilities.""" -from typing import Any, Dict +from typing import Any, Dict, Union from composer import DataSpec from transformers import PreTrainedTokenizerBase @@ -19,7 +19,7 @@ def build_dataloader( cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, + device_batch_size: Union[int, float], ) -> DataSpec: """Builds a dataloader from a config.