diff --git a/composer/core/state.py b/composer/core/state.py
index 5a44f0d005..67fb7e493e 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -30,6 +30,8 @@
 from torch.utils.data import DataLoader, Dataset
 from torchmetrics import Metric
 
+from composer.utils.warnings import VersionedDeprecationWarning
+
 if version.parse(torch.__version__) >= version.parse('2.3.0'):
     from torch.amp.grad_scaler import GradScaler  # type: ignore
 else:
@@ -469,18 +471,18 @@ def __init__(
         self.fsdp_config = fsdp_config
         self.fsdp_auto_wrap = fsdp_auto_wrap
 
-        if self.load_fsdp_monolith_rank0_only:
+        if self.load_monolith_rank0_only:
             assert fsdp_config is not None
             error_message = ''
             if fsdp_config['use_orig_params'] == True:
                 error_message += textwrap.dedent(
-                    "load_fsdp_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
-                    "Either set fsdp_config['use_orig_params'] = False or set load_fsdp_monolith_rank0_only = False. ",
+                    "load_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
+                    "Either set fsdp_config['use_orig_params'] = False or set load_monolith_rank0_only = False. ",
                 )
             if fsdp_config['sync_module_states'] == False:
                 error_message += textwrap.dedent(
-                    "load_fsdp_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
-                    "Either set fsdp_config['sync_module_states'] = True or set load_fsdp_monolith_rank0_only = False. ",
+                    "load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
+                    "Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
                 )
             # Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
             rank0_on_meta = 0
@@ -490,9 +492,9 @@ def __init__(
             dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
             if rank0_on_meta_tensor.item() == 1:
                 error_message += textwrap.dedent(
-                    'load_fsdp_monolith_rank0_only requires the rank 0 model to be on cpu or gpu, '
+                    'load_monolith_rank0_only requires the rank 0 model to be on cpu or gpu, '
                     'but detected model device as meta. Either move the model to cpu or gpu, or set '
-                    'load_fsdp_monolith_rank0_only = False. ',
+                    'load_monolith_rank0_only = False. ',
                 )
             if error_message != '':
                 raise ValueError(error_message)
@@ -806,6 +808,16 @@ def fsdp_device_mesh(self):
 
     @property
     def load_fsdp_monolith_rank0_only(self):
+        warnings.warn(
+            VersionedDeprecationWarning(
+                'load_fsdp_monolith_rank0_only is deprecated. Use load_monolith_rank0_only instead.',
+                '0.24',
+            ),
+        )
+        return self.load_monolith_rank0_only
+
+    @property
+    def load_monolith_rank0_only(self):
         return (
             self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
             self.fsdp_config['load_monolith_rank0_only'] == True
@@ -1157,7 +1169,7 @@ def _legacy_load_optim_state(self, state_dict: Dict[str, Any]):
                 log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
                 # Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
                 # as the context manager does not seem to pass rank0_only=True for the optimizer config
-                if self.load_fsdp_monolith_rank0_only:
+                if self.load_monolith_rank0_only:
                     optim_state_dict = _legacy_optim_state_dict_to_load(
                         optim_state_dict=optim_state_dict,
                         model=self.model,
@@ -1249,7 +1261,7 @@ def load_model_state(
         missing_keys, unexpected_keys = [], []
         try:
             # Load model if it exists
-            if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
+            if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only:
                 log.debug(
                     f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}',
                 )
@@ -1279,7 +1291,7 @@ def load_model_state(
             log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
 
         # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
-        if self.load_fsdp_monolith_rank0_only:
+        if self.load_monolith_rank0_only:
             assert self.fsdp_config is not None
             log.info('Wrapping model with FSDP after loading model_state.')
             from composer.trainer.dist_strategy import prepare_fsdp_module
@@ -1344,7 +1356,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
                 log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
                 # Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
                 # as the context manager does not seem to pass rank0_only=True for the optimizer config
-                if self.load_fsdp_monolith_rank0_only:
+                if self.load_monolith_rank0_only:
                     optim_state_dict = _legacy_optim_state_dict_to_load(
                         optim_state_dict=optim_state_dict,
                         model=self.model,
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 9e9b79408b..26a9f87e5e 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1558,7 +1558,7 @@ def __init__(
         # checkpoint on rank 0 only, in which case the model be loaded before it is wrapped.
 
         # FSDP wrap if not using monolith checkpoint on rank 0 only
-        if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
+        if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_monolith_rank0_only:
             with reproducibility.seed_context(self.state.rank_zero_seed):
                 prepare_fsdp_module(
                     model,
@@ -1727,8 +1727,8 @@ def __init__(
             self.state.run_name = run_name
 
         # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
-        # load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
-        if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
+        # load_monolith_rank0_only=True but no checkpoint was loaded.
+        if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_monolith_rank0_only:
             with reproducibility.seed_context(self.state.rank_zero_seed):
                 prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 3b8cd69d6e..a9e7e8335b 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -899,17 +899,17 @@ def filter_func(state_dict: dict) -> None:
 def safe_torch_load(
     composer_states_filepath: Union[Path, str],
     map_location: str = 'cpu',
-    load_fsdp_monolith_rank0_only: bool = False,
+    load_monolith_rank0_only: bool = False,
 ) -> dict[str, Any]:
     """Load a torch checkpoint, catching errors due to backwards compatibility issues.
 
     Args:
         composer_states_filepath: The path to the checkpoint file.
         map_location: The location to load the checkpoint to.
-        load_fsdp_monolith_rank0_only: Whether the checkpoint is a monolith FSDP checkpoint.
+        load_monolith_rank0_only: Whether the checkpoint is a monolith FSDP checkpoint.
     """
     try:
-        if load_fsdp_monolith_rank0_only:
+        if load_monolith_rank0_only:
             log.info(
                 'Loading monolith FSDP checkpoint. Only rank 0 will load and broadcast non-weight/optimizer state.',
             )
@@ -966,7 +966,7 @@ def _restore_checkpoint(
     # Now, all ranks load the checkpoint that local rank zero downloaded
     state_dict = safe_torch_load(
         composer_states_filepath=composer_states_filepath,
-        load_fsdp_monolith_rank0_only=state.load_fsdp_monolith_rank0_only,
+        load_monolith_rank0_only=state.load_monolith_rank0_only,
     )
     if ignore_keys:
         # Filter provided list of key paths
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 00c1686278..d4c3cc8261 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -84,7 +84,7 @@ class FSDPConfig:
     sharded_ckpt_prefix_dir: str = 'ba{batch}'
     sync_module_states: bool = True
     use_orig_params: bool = False
-    load_fsdp_monolith_rank0_only: bool = False
+    load_monolith_rank0_only: bool = False
     save_planner: Optional[Any] = None
     load_planner: Optional[Any] = None
 
@@ -286,7 +286,7 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @world_size(2)
 @pytest.mark.parametrize(
-    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_fsdp_monolith_rank0_only',
+    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only',
     [
         ['adam', False, 'amp_bf16', False, False, False],
         ['adamw', False, 'amp_bf16', False, False, False],
@@ -305,7 +305,7 @@ def test_fsdp_full_state_dict_load(
     optimizer: str,
     save_weights_only: bool,
     load_weights_only: bool,
-    load_fsdp_monolith_rank0_only: bool,
+    load_monolith_rank0_only: bool,
 ):
     if autoresume:
         run_name = 'my-cool-autoresume-run'
@@ -314,7 +314,7 @@
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(load_fsdp_monolith_rank0_only=load_fsdp_monolith_rank0_only)
+    fsdp_config = FSDPConfig(load_monolith_rank0_only=load_monolith_rank0_only)
 
     trainer1 = get_trainer(
         save_folder=str(save_folder),
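For reviewers, the core of this change is the deprecated-property alias added to composer/core/state.py: the old load_fsdp_monolith_rank0_only attribute now warns and forwards to the new load_monolith_rank0_only property, while every call site (trainer, checkpoint utilities, tests) switches to the new name. Below is a minimal, self-contained sketch of that alias pattern, not the actual Composer implementation: the _StateLike class is hypothetical, the standard DeprecationWarning stands in for composer's internal VersionedDeprecationWarning (which additionally records the removal version, '0.24' in the diff above), and a plain attribute replaces State's real FSDP-config-derived logic.

    import warnings


    class _StateLike:
        """Hypothetical stand-in for composer.core.state.State (sketch only)."""

        def __init__(self, load_monolith_rank0_only: bool = False) -> None:
            # In the real State, this value is derived from the FSDP config;
            # a plain attribute keeps the sketch self-contained.
            self._load_monolith_rank0_only = load_monolith_rank0_only

        @property
        def load_monolith_rank0_only(self) -> bool:
            # New, preferred name.
            return self._load_monolith_rank0_only

        @property
        def load_fsdp_monolith_rank0_only(self) -> bool:
            # Old name: emit a deprecation warning, then delegate so existing
            # callers keep working until the alias is removed.
            warnings.warn(
                'load_fsdp_monolith_rank0_only is deprecated. Use load_monolith_rank0_only instead.',
                DeprecationWarning,
                stacklevel=2,
            )
            return self.load_monolith_rank0_only


    state = _StateLike(load_monolith_rank0_only=True)
    assert state.load_monolith_rank0_only is True

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        assert state.load_fsdp_monolith_rank0_only is True  # still works, but warns
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Keeping the old name as a forwarding property rather than deleting it outright lets downstream code that still reads state.load_fsdp_monolith_rank0_only keep working for one deprecation cycle while surfacing the rename.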