From a41581cc8e77219a6d9e8a214d8c912e3bccdfdc Mon Sep 17 00:00:00 2001
From: Abhinav Venigalla
Date: Tue, 2 Jan 2024 21:20:53 +0000
Subject: [PATCH] add check

---
 llmfoundry/utils/config_utils.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 6680154e87..4963dcaa24 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -120,6 +120,13 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
         # Set defaults for mixed initialization
         fsdp_config.setdefault('use_orig_params', False)
         fsdp_config.setdefault('load_monolith_rank0_only', True)
+    # Always set `sync_module_states` to True when using hybrid sharding
+    if fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \
+        and not fsdp_config.get('sync_module_states', False):
+        warnings.warn((
+            'Setting `sync_module_states = True` for FSDP. This is required '
+            'when using hybrid sharding.'))
+        fsdp_config['sync_module_states'] = True
 
     # no mixed precision needed for weights when they're already 16 bits
     master_dtype = model_cfg.get('master_weights_dtype')