Wrong judgement of batch_sampler in prepare_data_loader #2091
This is indeed due to a recent change in accelerate. As a quick fix, could you make your batch sampler inherit from `BatchSampler`?
Thanks for the reply. It indeed functions if I inherit from `BatchSampler`:

```python
import numpy as np
import accelerate
from torch.utils.data import BatchSampler, DataLoader, Sampler

class MySampler2(BatchSampler):
    """
    MySampler2 is an iterator which can be used as a `batch_sampler` in DataLoader.
    It yields a batch of sample indices each time, and can only handle uniform
    sampling. It works with `num_workers` in DataLoader because each worker
    loads a batch of samples at a time.
    """
    def __init__(self, sampler, dataset_length: int, batch_size: int, shuffle: bool = True) -> None:
        super().__init__(sampler, batch_size=batch_size, drop_last=True)
        self.batch_size = batch_size
        self.data_index = np.arange(dataset_length)
        self.shuffle = shuffle

    def __iter__(self):
        output = [np.arange(10, 10 + self.batch_size)] * len(self)
        yield from output

    def __len__(self):
        return (len(self.data_index) + self.batch_size - 1) // self.batch_size

dataset = MyDataset(np.arange(10000))  # MyDataset as defined earlier in this issue
sampler = Sampler(dataset)  # this is pytorch's Sampler
batch_sampler2 = MySampler2(sampler, len(dataset), 32)
dataloader_pt2 = DataLoader(dataset, batch_sampler=batch_sampler2)

# Original pytorch DataLoader, no problem. Pytorch version 1.13.1
for d in dataloader_pt2:
    print(d)
    break

# Accelerate DataLoader, attribute not found. Accelerate version 0.24.0
dataloader_al2 = accelerate.data_loader.prepare_data_loader(dataloader_pt2)
for d in dataloader_al2:
    print(d)
    break
```

It produces the expected output.
As long as you override `__iter__` and `__len__`, it should work.
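For instance, a minimal sketch of the idea (hypothetical names; a `SequentialSampler` stands in for your real sampler):

```python
from torch.utils.data import BatchSampler, SequentialSampler

class MinimalBatchSampler(BatchSampler):
    """Subclassing BatchSampler keeps the `.sampler` attribute that newer
    accelerate versions expect; __iter__/__len__ define the batching."""

    def __init__(self, dataset_length: int, batch_size: int) -> None:
        # super().__init__ stores self.sampler, self.batch_size, self.drop_last
        super().__init__(SequentialSampler(range(dataset_length)),
                         batch_size=batch_size, drop_last=True)

    def __iter__(self):
        indices = list(self.sampler)
        # yield full batches only, consistent with drop_last=True
        for start in range(0, len(indices) - self.batch_size + 1, self.batch_size):
            yield indices[start:start + self.batch_size]

    def __len__(self):
        return len(self.sampler) // self.batch_size
```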
Thanks for the solution. And I wonder whether this is a bug in `prepare_data_loader` that will be fixed?
Yes, it’s a bug that we’ll look at fixing.
System Info

Information

Tasks

- An officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)

Reproduction
I defined my own batch sampler, aiming to output a batch of indices each time. It works well when passed as the `batch_sampler` argument to `torch.utils.data.DataLoader`. It also functions after wrapping the dataloader with `accelerate.data_loader.prepare_data_loader()` for version `accelerate==0.20.3`. But after upgrading `accelerate` to the latest version, `0.24.0`, it raises an `AttributeError: 'MySampler' object has no attribute 'sampler'`. The code snippets are pasted below.

I have compared the source code of `accelerate/data_loader.py` in versions 0.20.3 and 0.24.0, and found that the main difference relevant to this bug is the `sampler_is_batch_sampler` variable in the `prepare_data_loader()` function. In line 718 of version 0.20.3, `sampler_is_batch_sampler` is set to `False`, while in line 834 of version 0.24.0 it is set as `sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)`. The condition may not be right in my case, where the `sampler` of the dataloader is `None` and `batch_sampler` is set to my own sampler (which has no `sampler` member) instead.

The output of the snippets is the `AttributeError` shown above.
Expected behavior
The expected output should be normal, with no error raised.