diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 44d2db1f98..52804f3520 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -452,13 +452,20 @@ def calculate_batch_size_info(
     return device_batch_size, device_microbatch_size, device_grad_accum
 
 
-def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
-    data_replication_degree = 1
-    device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
-        cfg['global_train_batch_size'],
-        cfg['device_train_microbatch_size'],
-        data_replication_degree=data_replication_degree,
-    )
+def update_config_with_batch_size_info(
+    cfg: Dict[str, Any],
+    device_train_batch_size: Union[int, float],
+    device_train_microbatch_size: Union[int, float, Literal['auto']],
+    device_train_grad_accum: Union[int, Literal['auto']],
+) -> Dict[str, Any]:
+    """Update the config with batch size information.
+
+    Args:
+        cfg (Dict[str, Any]): The config to update.
+
+    Returns:
+        Dict[str, Any]: The updated config.
+    """
     cfg['n_gpus'] = dist.get_world_size()
     cfg['device_train_batch_size'] = device_train_batch_size
     cfg['device_train_microbatch_size'] = device_train_microbatch_size
@@ -473,6 +480,22 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return cfg
 
 
+def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    data_replication_degree = 1
+    device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
+        cfg['global_train_batch_size'],
+        cfg['device_train_microbatch_size'],
+        data_replication_degree=data_replication_degree,
+    )
+    cfg = update_config_with_batch_size_info(
+        cfg,
+        device_train_batch_size,
+        device_train_microbatch_size,
+        device_train_grad_accum,
+    )
+    return cfg
+
+
 def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]):
     # Restrict model init_device to 'meta' and 'cpu',
     # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py
new file mode 100644
index 0000000000..1b78d0077b
--- /dev/null
+++ b/tests/utils/test_config_utils.py
@@ -0,0 +1,15 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.utils.config_utils import update_config_with_batch_size_info
+
+
+def test_update_config_with_batch_size_info():
+    config = {}
+    config = update_config_with_batch_size_info(config, 1, 2, 3)
+
+    assert config['n_gpus'] == 1
+    assert config['device_train_batch_size'] == 1
+    assert config['device_train_microbatch_size'] == 2
+    assert config['device_train_grad_accum'] == 3
+    assert config['device_eval_batch_size'] == 2
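
For context, a minimal usage sketch of the refactored entry point (not part of the patch). The config values are hypothetical, and the sketch assumes a single-process run where composer's dist.get_world_size() returns 1, with calculate_batch_size_info splitting the global batch evenly across ranks and deriving grad accum as batch size divided by microbatch size:

from llmfoundry.utils.config_utils import update_batch_size_info

# Illustrative config values; only the two keys read by update_batch_size_info are set.
cfg = {
    'global_train_batch_size': 8,
    'device_train_microbatch_size': 2,
}
cfg = update_batch_size_info(cfg)

# With a world size of 1, the whole global batch lands on one device.
assert cfg['device_train_batch_size'] == 8
assert cfg['device_train_grad_accum'] == 4  # 8 / 2
assert cfg['device_eval_batch_size'] == 2   # follows the train microbatch size, as in the new test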