Commit

simplify conditional check as per the comment
pacman100 committed Nov 24, 2023
1 parent 4237570 commit bdef4ac
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions src/transformers/modeling_utils.py
@@ -132,8 +132,12 @@ def is_fsdp_enabled():
     )


-def is_fsdp_enabled_and_local_dist_rank_0():
-    return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0
+def is_local_dist_rank_0():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and int(os.environ.get("LOCAL_RANK", -1)) == 0
+    )


 if is_sagemaker_mp_enabled():
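For context, a rough standalone sketch (not part of the commit) of how the renamed helper behaves: it no longer checks that FSDP is enabled, only that torch.distributed is usable, a process group is initialized, and this process is the first rank on its node, so call sites now pair it with is_fsdp_enabled() explicitly.

# Hypothetical standalone sketch; the real function lives in
# src/transformers/modeling_utils.py.
import os
import torch

def is_local_dist_rank_0():
    # True only when torch.distributed is usable, a process group has been
    # initialized, and this process is local rank 0 on its node.
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and int(os.environ.get("LOCAL_RANK", -1)) == 0
    )

if __name__ == "__main__":
    # Outside a distributed run this is simply False, because no process
    # group is initialized -- even if LOCAL_RANK happens to be set.
    os.environ["LOCAL_RANK"] = "0"
    print(is_local_dist_rank_0())  # False in a single-process run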
@@ -475,7 +479,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
     try:
         if (
             is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
-        ) or (is_fsdp_enabled() and not is_fsdp_enabled_and_local_dist_rank_0()):
+        ) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
             map_location = "meta"
         else:
             map_location = "cpu"
@@ -3903,7 +3907,18 @@ def _find_mismatched_keys(
                     ignore_mismatched_sizes,
                 )
                 if low_cpu_mem_usage:
-                    if not is_fsdp_enabled() or is_fsdp_enabled_and_local_dist_rank_0():
+                    if is_fsdp_enabled() and not is_local_dist_rank_0():
+                        for key, param in model_to_load.state_dict().items():
+                            if param.device == torch.device("meta"):
+                                if not (is_quantized):
+                                    set_module_tensor_to_device(
+                                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
+                                else:
+                                    set_module_quantized_tensor_to_device(
+                                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
+                    else:
                         new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                             model_to_load,
                             state_dict,
@@ -3921,17 +3936,6 @@ def _find_mismatched_keys(
                             keep_in_fp32_modules=keep_in_fp32_modules,
                         )
                         error_msgs += new_error_msgs
-                    else:
-                        for key, param in model_to_load.state_dict().items():
-                            if param.device == torch.device("meta"):
-                                if not (is_quantized):
-                                    set_module_tensor_to_device(
-                                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
-                                    )
-                                else:
-                                    set_module_quantized_tensor_to_device(
-                                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
-                                    )
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

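The reordered branch above materializes meta parameters as uninitialized CPU tensors on FSDP ranks other than local rank 0; local rank 0 loads the real weights, which FSDP can later broadcast to the other ranks (e.g. via sync_module_states=True). A minimal sketch of that pattern, assuming accelerate's set_module_tensor_to_device is available and using a toy module; everything else here is illustrative:

import torch
from torch import nn
from accelerate.utils import set_module_tensor_to_device

# Build a toy module directly on the meta device: parameters have shapes but
# no storage, mirroring a model instantiated with low_cpu_mem_usage=True.
with torch.device("meta"):
    model = nn.Linear(4, 4)

for key, param in model.state_dict().items():
    if param.device == torch.device("meta"):
        # Swap each meta tensor for an *uninitialized* CPU tensor of the same
        # shape; real values are expected to arrive later (e.g. broadcast from
        # local rank 0 when FSDP wraps the model).
        set_module_tensor_to_device(
            model, key, "cpu", torch.empty(*param.size(), dtype=torch.float32)
        )

print(model.weight.device)  # cpu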