diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 0d43274a633..6bf5d803f94 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -653,9 +653,7 @@ def main(): dl_preparation_check() if state.distributed_type != DistributedType.XLA: central_dl_preparation_check() - # Skip this test because the TorchXLA's MpDeviceLoaderWrapper does not - # have the 'batch_sampler' attribute. - custom_sampler_check() + custom_sampler_check() # Trainings are not exactly the same in DeepSpeed and CPU mode if state.distributed_type == DistributedType.DEEPSPEED: