Refactoring Trainer, adds save_only_model arg and simplifying FSDP integration #27652

Merged
merged 13 commits into from
Nov 24, 2023
33 changes: 24 additions & 9 deletions docs/source/en/main_classes/trainer.md
Expand Up @@ -426,8 +426,7 @@ To read more about it and the benefits, check out the [Fully Sharded Data Parall
We have integrated PyTorch's latest Fully Sharded Data Parallel (FSDP) training feature.
All you need to do is enable it through the config.

**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released)
as the model saving with FSDP activated is only available with recent fixes.
**Required PyTorch version for FSDP support**: PyTorch >=2.1.0

**Usage**:

Expand All @@ -440,6 +439,8 @@ as the model saving with FSDP activated is only available with recent fixes.
- SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
For this, add `--fsdp shard_grad_op` to the command line arguments.
- NO_SHARD : No sharding. For this, add `--fsdp no_shard` to the command line arguments.
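- HYBRID_SHARD : Applies FULL_SHARD within each node and replicates parameters across nodes. For this, add `--fsdp hybrid_shard` to the command line arguments.
- HYBRID_SHARD_ZERO2 : Applies SHARD_GRAD_OP within each node and replicates parameters across nodes. For this, add `--fsdp hybrid_shard_zero2` to the command line arguments.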
- To offload the parameters and gradients to the CPU,
add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
Expand All @@ -449,25 +450,39 @@ as the model saving with FSDP activated is only available with recent fixes.
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either a location of
FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
- If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy.
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
- For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP wrapped units.
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
The remaining layers, including the shared embeddings, are conveniently wrapped in the same outermost FSDP unit.
Therefore, use this for transformer based models.
- For size based auto wrap policy, please add `fsdp_min_num_params` in the config file.
- For size based auto wrap policy, please add `min_num_params` in the config file.
It specifies FSDP's minimum number of parameters for auto wrapping.
- `fsdp_backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
- `backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
`backward_pre` and `backward_post` are available options.
For more information, refer to `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
- `fsdp_forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
- `forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
- `limit_all_gathers` can be specified in the config file.
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
- `activation_checkpointing` can be specified in the config file.
If `"True"`, enables FSDP activation checkpointing, a technique that reduces memory usage by clearing the activations of
certain layers and recomputing them during the backward pass. Effectively, this trades extra computation time
for reduced memory usage.
- `use_orig_params` can be specified in the config file.
If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. This is useful in cases such as parameter-efficient fine-tuning, and it also makes it possible to use different optimizer param groups. This should be `True` when creating the optimizer object before preparing/wrapping the model with FSDP.
For details, refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
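
As a rough illustration of the options listed above, the following sketch writes a config file that could be passed via `--fsdp_config`. The key names are the unprefixed ones used in this section. The values (the `BertLayer` class name, the prefetch choice), the file name, and the helper `write_fsdp_config` are assumptions chosen for illustration. Boolean-like options are written as strings to mirror the quoted `"True"` in the text; whether JSON booleans are also accepted is not confirmed here.

```python
import json
from pathlib import Path

# Illustrative values only; adjust the layer class name to the model being trained.
fsdp_config = {
    # Transformer based auto wrap policy: class names are case-sensitive.
    # For the size based policy, "min_num_params" would be set instead.
    "transformer_layer_cls_to_wrap": ["BertLayer"],
    "backward_prefetch": "backward_pre",
    "forward_prefetch": "false",
    "use_orig_params": "true",
}


def write_fsdp_config(path: str, config: dict) -> Path:
    """Serialize the config so it can be passed as --fsdp_config <path>."""
    target = Path(path)
    target.write_text(json.dumps(config, indent=2))
    return target
```

Calling `write_fsdp_config("fsdp_config.json", fsdp_config)` produces a file that can then be referenced on the command line together with the `--fsdp` sharding strategy.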

**Saving and loading**
Saving entire intermediate checkpoints using the `FULL_STATE_DICT` state_dict_type with CPU offloading on rank 0 takes a lot of time and often results in NCCL timeout errors because the broadcast hangs indefinitely. However, at the end of training, we want the whole model state dict instead of the sharded state dict, which is only compatible with FSDP. Use the `SHARDED_STATE_DICT` (default) state_dict_type, the format recommended by the PyTorch team, to save the intermediate checkpoints and optimizer states.

Saving the final checkpoint in transformers format using the default `safetensors` format requires the changes below.
```python
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)
```
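
The snippet above can be wrapped in a small helper so that the state dict type is switched only when FSDP is active. This is a sketch: `save_final_model` is a hypothetical name, and it relies only on the attributes used in the snippet (`is_fsdp_enabled`, `accelerator.state.fsdp_plugin.set_state_dict_type`, `save_model`).

```python
def save_final_model(trainer, output_dir: str) -> None:
    """Save the final checkpoint, requesting a full state dict when FSDP is active."""
    # getattr guards against trainer objects that lack the is_fsdp_enabled attribute.
    if getattr(trainer, "is_fsdp_enabled", False):
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model(output_dir)
```

Intermediate checkpoints are unaffected by this helper; it is meant to be called once, after training has finished.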

**Few caveats to be aware of**
- it is incompatible with `generate`, thus is incompatible with `--predict_with_generate`
Expand All @@ -492,15 +507,15 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co
https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
- `xla_fsdp_grad_ckpt`. When `True`, uses gradient checkpointing over each nested XLA FSDP wrapped layer.
This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
`fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
`min_num_params` or `transformer_layer_cls_to_wrap`.
- You can either use transformer based auto wrap policy or size based auto wrap policy.
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
- For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP wrapped units.
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
The remaining layers, including the shared embeddings, are conveniently wrapped in the same outermost FSDP unit.
Therefore, use this for transformer based models.
- For size based auto wrap policy, please add `fsdp_min_num_params` in the config file.
- For size based auto wrap policy, please add `min_num_params` in the config file.
It specifies FSDP's minimum number of parameters for auto wrapping.


2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -96,7 +96,7 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow>=10.0.1,<=15.0",
"accelerate>=0.20.3",
"accelerate>=0.21.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"codecarbon==1.2.0",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Expand Up @@ -3,7 +3,7 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow>=10.0.1,<=15.0",
"accelerate": "accelerate>=0.20.3",
"accelerate": "accelerate>=0.21.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"codecarbon": "codecarbon==1.2.0",
39 changes: 21 additions & 18 deletions src/transformers/modeling_utils.py
Expand Up @@ -132,8 +132,12 @@ def is_fsdp_enabled():
)


def is_fsdp_enabled_and_dist_rank_0():
    return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0
def is_local_dist_rank_0():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and int(os.environ.get("LOCAL_RANK", -1)) == 0
    )


if is_sagemaker_mp_enabled():
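
The renamed helper in this hunk no longer depends on FSDP being enabled; it only asks whether the process is local rank 0 of an initialized distributed run. Below is a torch-free sketch of the same logic, with the two `torch.distributed` checks passed in as plain booleans. The function name `is_local_rank_0` and its parameters are hypothetical, introduced so the logic can be exercised without a process group.

```python
import os
from typing import Mapping, Optional


def is_local_rank_0(
    dist_available: bool,
    dist_initialized: bool,
    env: Optional[Mapping[str, str]] = None,
) -> bool:
    """Return True only in an initialized distributed run whose LOCAL_RANK is 0."""
    env = os.environ if env is None else env
    # A missing LOCAL_RANK falls back to -1, so such processes return False.
    return dist_available and dist_initialized and int(env.get("LOCAL_RANK", -1)) == 0
```

Note that `LOCAL_RANK` is per node, so in a multi-node run one process on each node satisfies this check.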
Expand Down Expand Up @@ -474,13 +478,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
return safe_load_file(checkpoint_file)
try:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
map_location = "meta"
else:
map_location = "cpu"

return torch.load(checkpoint_file, map_location=map_location)
except Exception as e:
try:
Expand Down Expand Up @@ -3904,7 +3907,18 @@ def _find_mismatched_keys(
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
Expand All @@ -3922,17 +3936,6 @@ def _find_mismatched_keys(
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
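
Taken together, the `modeling_utils.py` hunks change which ranks read real checkpoint values under FSDP. The sketch below restates the two updated conditions with plain values so they can be checked in isolation. The function names and parameters are hypothetical; only the boolean logic is taken from the diff.

```python
def choose_map_location(
    zero3_enabled: bool,
    fsdp_enabled: bool,
    dist_initialized: bool,
    global_rank: int,
    local_rank_0: bool,
) -> str:
    """Mirror the updated condition in load_state_dict, using plain values."""
    zero3_non_zero_rank = zero3_enabled and dist_initialized and global_rank > 0
    fsdp_non_local_zero = fsdp_enabled and not local_rank_0
    return "meta" if (zero3_non_zero_rank or fsdp_non_local_zero) else "cpu"


def loads_real_weights(fsdp_enabled: bool, local_rank_0: bool) -> bool:
    """Mirror the low_cpu_mem_usage branch.

    FSDP ranks other than local rank 0 only allocate empty CPU tensors for
    parameters on the meta device; every other case loads checkpoint values.
    """
    return not (fsdp_enabled and not local_rank_0)
```

For example, under FSDP a process with global rank 8 that is local rank 0 on its node now maps to `"cpu"`, whereas the previous condition, which compared the global rank against 0, would have sent it to `"meta"`.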
