diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 8fd76b3fe56..01a348f82bd 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -539,6 +539,10 @@ def total_batch_size(self): def total_dataset_length(self): return self._loader.total_dataset_length + @property + def batch_sampler(self): + return self._loader.batch_sampler + class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): """