Skip to content

Commit

Permalink
FSDP grad accum fix (#34645)
Browse files Browse the repository at this point in the history
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
  • Loading branch information
winglian authored Nov 15, 2024
1 parent 52ea4aa commit b0c0ba7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2488,7 +2488,7 @@ def _inner_training_loop(
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
Expand Down
12 changes: 12 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,18 @@ def test_basic_run(self, sharding_strategy, dtype):
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
@require_torch_multi_accelerator
@slow
def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
output_dir = self.get_auto_remove_tmp_dir()
args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(dtypes)
@require_torch_multi_accelerator
@slow
Expand Down

0 comments on commit b0c0ba7

Please sign in to comment.