From 95b718b115a652c5f490fbddb2d8933797230f12 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Nov 2024 12:27:25 -0500
Subject: [PATCH 1/2] add gradient accumulation steps tests for fsdp

---
 tests/fsdp/test_fsdp.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 7e14cc8c9e6fc9..74a3bfe04b7506 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -224,6 +224,18 @@ def test_basic_run(self, sharding_strategy, dtype):
         cmd = launcher + script + args + fsdp_args
         execute_subprocess_async(cmd, env=self.get_env())
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    @require_torch_multi_accelerator
+    @slow
+    def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
+        launcher = get_launcher(distributed=True, use_accelerate=False)
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
+        fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        cmd = launcher + script + args + fsdp_args
+        execute_subprocess_async(cmd, env=self.get_env())
+
     @parameterized.expand(dtypes)
     @require_torch_multi_accelerator
     @slow

From edd102fc25fe42a2d2dbd2c7c2b6ca2d40353f09 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Nov 2024 12:28:27 -0500
Subject: [PATCH 2/2] invert no_sync context to fix training for fsdp

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d41b7181be6334..fec4bc4d6b283c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2488,7 +2488,7 @@ def _inner_training_loop(
                     # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                     context = (
                         functools.partial(self.accelerator.no_sync, model=model)
-                        if i == len(batch_samples) - 1
+                        if i != len(batch_samples) - 1
                         else contextlib.nullcontext
                     )
                     with context():