address comments
pacman100 committed Oct 12, 2023
1 parent 1799fbc commit 6618083
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/usage_guides/fsdp.md
@@ -96,7 +96,7 @@ all-gather while executing in the forward pass. only use with Static graphs.
Useful in cases such as parameter-efficient fine-tuning.
Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)
-`CPU RAM Efficient Model loading`: If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for 🤗 Transformers models. When using this, `Sync Module States` needs to be True.
+`CPU RAM Efficient Model loading`: If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for 🤗 Transformers models. When using this, `Sync Module States` needs to be True, else all the processes except the main process would have random empty weights leading to unexpected behaviour during training.
`Sync Module States`: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0
```
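For context, the two options described above map to environment variables that `accelerate launch` exports from the corresponding `--fsdp_*` flags (see the `launch.py` diff below). A minimal sketch of how the pair is meant to be used together — setting the variables by hand here is purely illustrative; in practice `accelerate config` or the launcher flags set them, and the script is started through `accelerate launch`:

```python
import os

# Normally exported by `accelerate launch` from the corresponding --fsdp_* flags;
# set by hand here only as an illustration.
os.environ["ACCELERATE_USE_FSDP"] = "true"
# Only the main process loads the pretrained checkpoint; the other ranks start with empty weights.
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
# Must accompany the setting above, otherwise the non-main ranks keep random empty weights.
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"

from accelerate import Accelerator

# The FSDP plugin picks these values up from the environment when the script
# runs under `accelerate launch`.
accelerator = Accelerator()
```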
8 changes: 3 additions & 5 deletions src/accelerate/utils/launch.py
@@ -31,7 +31,6 @@
from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from ..utils.other import is_port_in_use, merge_dicts
from .dataclasses import DistributedType, SageMakerDistributedType
-from .environment import str_to_bool


def _filter_args(args, parser, default_args=[]):
@@ -175,6 +174,9 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:

    if args.use_fsdp:
        current_env["ACCELERATE_USE_FSDP"] = "true"
+        if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
+            raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
+
        current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
        current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
        current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)
@@ -190,10 +192,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
if str_to_bool(current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"]) and not str_to_bool(
current_env["FSDP_SYNC_MODULE_STATES"]
):
raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")

    if args.use_megatron_lm:
        prefix = "MEGATRON_LM_"
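The validation itself simply moved earlier: instead of exporting the values and re-parsing the lowercased strings with `str_to_bool`, the launcher now checks the parsed argument booleans directly and fails fast. A self-contained sketch of the same pattern — the flag names mirror the real ones, but the `store_true` parsing is simplified for illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fsdp_cpu_ram_efficient_loading", action="store_true")
parser.add_argument("--fsdp_sync_module_states", action="store_true")

# Simulate passing only the RAM-efficient-loading flag.
args = parser.parse_args(["--fsdp_cpu_ram_efficient_loading"])

# Check the booleans before they are serialized into environment strings,
# so no str_to_bool round-trip is needed.
if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
    raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
```

Running the sketch raises the error immediately, which mirrors the commit's intent: the misconfiguration is caught before the FSDP_* settings are exported to the environment.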
