Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSDP grad accum fix #34645

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2488,7 +2488,7 @@ def _inner_training_loop(
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
Expand Down
12 changes: 12 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,18 @@ def test_basic_run(self, sharding_strategy, dtype):
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
@require_torch_multi_accelerator
@slow
def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
output_dir = self.get_auto_remove_tmp_dir()
args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
cmd = launcher + script + args + fsdp_args
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(dtypes)
@require_torch_multi_accelerator
@slow
Expand Down
Loading