Skip to content

Commit

Permalink
replace load_fsdp_monolith_ with load_monolith_ (#3288)
Browse files Browse the repository at this point in the history
* replace load_monolith_ with load_fsdp_monolith_

* change load_fsdp_monolith_rank0_only to load_monolith_rank0_only

* lint

* added deprecation warning for old thingy

* versioned deprecation warning

---------

Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
milocress and mvpatel2000 authored May 15, 2024
1 parent 4806293 commit 423706b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 22 deletions.
34 changes: 23 additions & 11 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

from composer.utils.warnings import VersionedDeprecationWarning

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
Expand Down Expand Up @@ -469,18 +471,18 @@ def __init__(
self.fsdp_config = fsdp_config
self.fsdp_auto_wrap = fsdp_auto_wrap

if self.load_fsdp_monolith_rank0_only:
if self.load_monolith_rank0_only:
assert fsdp_config is not None
error_message = ''
if fsdp_config['use_orig_params'] == True:
error_message += textwrap.dedent(
"load_fsdp_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
"Either set fsdp_config['use_orig_params'] = False or set load_fsdp_monolith_rank0_only = False. ",
"load_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
"Either set fsdp_config['use_orig_params'] = False or set load_monolith_rank0_only = False. ",
)
if fsdp_config['sync_module_states'] == False:
error_message += textwrap.dedent(
"load_fsdp_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_fsdp_monolith_rank0_only = False. ",
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
Expand All @@ -490,9 +492,9 @@ def __init__(
dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
if rank0_on_meta_tensor.item() == 1:
error_message += textwrap.dedent(
'load_fsdp_monolith_rank0_only requires the rank 0 model to be on cpu or gpu, '
'load_monolith_rank0_only requires the rank 0 model to be on cpu or gpu, '
'but detected model device as meta. Either move the model to cpu or gpu, or set '
'load_fsdp_monolith_rank0_only = False. ',
'load_monolith_rank0_only = False. ',
)
if error_message != '':
raise ValueError(error_message)
Expand Down Expand Up @@ -806,6 +808,16 @@ def fsdp_device_mesh(self):

@property
def load_fsdp_monolith_rank0_only(self):
warnings.warn(
VersionedDeprecationWarning(
'load_fsdp_monolith_rank0_only is deprecated. Use load_monolith_rank0_only instead.',
'0.24',
),
)
return self.load_monolith_rank0_only

@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
self.fsdp_config['load_monolith_rank0_only'] == True
Expand Down Expand Up @@ -1157,7 +1169,7 @@ def _legacy_load_optim_state(self, state_dict: Dict[str, Any]):
log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
# Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
# as the context manager does not seem to pass rank0_only=True for the optimizer config
if self.load_fsdp_monolith_rank0_only:
if self.load_monolith_rank0_only:
optim_state_dict = _legacy_optim_state_dict_to_load(
optim_state_dict=optim_state_dict,
model=self.model,
Expand Down Expand Up @@ -1249,7 +1261,7 @@ def load_model_state(
missing_keys, unexpected_keys = [], []
try:
# Load model if it exists
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}',
)
Expand Down Expand Up @@ -1279,7 +1291,7 @@ def load_model_state(
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_fsdp_monolith_rank0_only:
if self.load_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
Expand Down Expand Up @@ -1344,7 +1356,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
# Loading FSDP monolith on rank 0 only requires FSDP.scatter_full_optim_state_dict
# as the context manager does not seem to pass rank0_only=True for the optimizer config
if self.load_fsdp_monolith_rank0_only:
if self.load_monolith_rank0_only:
optim_state_dict = _legacy_optim_state_dict_to_load(
optim_state_dict=optim_state_dict,
model=self.model,
Expand Down
6 changes: 3 additions & 3 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,7 +1558,7 @@ def __init__(
# checkpoint on rank 0 only, in which case the model be loaded before it is wrapped.

# FSDP wrap if not using monolith checkpoint on rank 0 only
if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_monolith_rank0_only:
with reproducibility.seed_context(self.state.rank_zero_seed):
prepare_fsdp_module(
model,
Expand Down Expand Up @@ -1727,8 +1727,8 @@ def __init__(
self.state.run_name = run_name

# FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
# load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
# load_monolith_rank0_only=True but no checkpoint was loaded.
if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_monolith_rank0_only:
with reproducibility.seed_context(self.state.rank_zero_seed):
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)

Expand Down
8 changes: 4 additions & 4 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,17 +899,17 @@ def filter_func(state_dict: dict) -> None:
def safe_torch_load(
composer_states_filepath: Union[Path, str],
map_location: str = 'cpu',
load_fsdp_monolith_rank0_only: bool = False,
load_monolith_rank0_only: bool = False,
) -> dict[str, Any]:
"""Load a torch checkpoint, catching errors due to backwards compatibility issues.
Args:
composer_states_filepath: The path to the checkpoint file.
map_location: The location to load the checkpoint to.
load_fsdp_monolith_rank0_only: Whether the checkpoint is a monolith FSDP checkpoint.
load_monolith_rank0_only: Whether the checkpoint is a monolith FSDP checkpoint.
"""
try:
if load_fsdp_monolith_rank0_only:
if load_monolith_rank0_only:
log.info(
'Loading monolith FSDP checkpoint. Only rank 0 will load and broadcast non-weight/optimizer state.',
)
Expand Down Expand Up @@ -966,7 +966,7 @@ def _restore_checkpoint(
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = safe_torch_load(
composer_states_filepath=composer_states_filepath,
load_fsdp_monolith_rank0_only=state.load_fsdp_monolith_rank0_only,
load_monolith_rank0_only=state.load_monolith_rank0_only,
)
if ignore_keys:
# Filter provided list of key paths
Expand Down
8 changes: 4 additions & 4 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class FSDPConfig:
sharded_ckpt_prefix_dir: str = 'ba{batch}'
sync_module_states: bool = True
use_orig_params: bool = False
load_fsdp_monolith_rank0_only: bool = False
load_monolith_rank0_only: bool = False
save_planner: Optional[Any] = None
load_planner: Optional[Any] = None

Expand Down Expand Up @@ -286,7 +286,7 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize(
'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_fsdp_monolith_rank0_only',
'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only',
[
['adam', False, 'amp_bf16', False, False, False],
['adamw', False, 'amp_bf16', False, False, False],
Expand All @@ -305,7 +305,7 @@ def test_fsdp_full_state_dict_load(
optimizer: str,
save_weights_only: bool,
load_weights_only: bool,
load_fsdp_monolith_rank0_only: bool,
load_monolith_rank0_only: bool,
):
if autoresume:
run_name = 'my-cool-autoresume-run'
Expand All @@ -314,7 +314,7 @@ def test_fsdp_full_state_dict_load(
save_folder = tmp_path
save_filename = 'rank{rank}.pt'

fsdp_config = FSDPConfig(load_fsdp_monolith_rank0_only=load_fsdp_monolith_rank0_only)
fsdp_config = FSDPConfig(load_monolith_rank0_only=load_monolith_rank0_only)

trainer1 = get_trainer(
save_folder=str(save_folder),
Expand Down

0 comments on commit 423706b

Please sign in to comment.