From c2d8e245e9fa603b29986cb3b677cb0d44b41f6a Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 1 Nov 2023 15:03:59 -0400 Subject: [PATCH] Fix issue with tests (#2111) --- src/accelerate/test_utils/scripts/test_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index e5acff0e7ed..cd2c1302d46 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -339,7 +339,7 @@ def mock_training(length, batch_size, generator): set_seed(42) generator.manual_seed(42) train_set = RegressionDataset(length=length, seed=42) - if AcceleratorState().num_processes > 1: + if not ((AcceleratorState().num_processes == 1) and (AcceleratorState().device.type == "cpu")): # The SeedableRandomSampler is needed during distributed setups # for full reproducability across processes with the `DataLoader` sampler = SeedableRandomSampler(