diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 5ab148bbe8..5c1ec9114a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -100,7 +100,7 @@ class TrainConfig: optimizer: Dict[str, Any] = MISSING scheduler: Dict[str, Any] = MISSING train_loader: Dict[str, Any] = MISSING - device_train_batch_size: int = MISSING + device_train_batch_size: Union[int, float] = MISSING device_eval_batch_size: int = MISSING max_duration: Union[int, str] = MISSING eval_interval: Union[int, str] = MISSING @@ -183,7 +183,6 @@ class TrainConfig: # Fields created by `update_batch_size_info` n_gpus: int = MISSING - device_train_batch_size: int = MISSING device_train_grad_accum: str = MISSING