diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 52804f3520..79cdc225b2 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -184,7 +184,7 @@ class TrainConfig: # Fields created by `update_batch_size_info` n_gpus: int = MISSING - device_train_grad_accum: str = MISSING + device_train_grad_accum: Union[str, int] = MISSING TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)}