From 11e2e99cfca3afe1cefe02111f40665b692b86fb Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 23 Oct 2023 08:12:26 -0400 Subject: [PATCH] Let iterable dataset shard have a len (#2066) --- src/accelerate/data_loader.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3390b8cb803..2e9b0906ca1 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -271,6 +271,13 @@ def __init__( self.process_index = process_index self.split_batches = split_batches + def __len__(self): + # We will just raise the downstream error if the underlying dataset is not sized + if self.drop_last: + return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size + else: + return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size + def __iter__(self): real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes) process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size