diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 652bbc0c4b5..a3d7ecdf827 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -328,6 +328,7 @@ def __len__(self): assert isinstance( dl.batch_sampler, CustomBatchSampler ), "Custom sampler was changed after calling `prepare_data_loader`" + state._reset_state() def mock_training(length, batch_size, generator):