diff --git a/tutorials/torch_ddp_example.py b/tutorials/torch_ddp_example.py index 15c3df6b..58b66d3d 100644 --- a/tutorials/torch_ddp_example.py +++ b/tutorials/torch_ddp_example.py @@ -383,17 +383,9 @@ def ind_loss( comm = MPI.COMM_WORLD if comm.rank == 0: # Download data at the top, then we don't need to later. - train_loader = DataLoader( - dataset=MNIST( - download=True, root=".", transform=None, train=True - ), # Use MNIST training dataset. - batch_size=2, # Batch size - num_workers=1, - pin_memory=True, - persistent_workers=True, - shuffle=True, # Shuffle data. - ) - del train_loader + dataset = MNIST(download=True, root=".", transform=None, train=True) + dataset = MNIST(download=True, root=".", transform=None, train=False) + del dataset comm.Barrier() pop_size = 2 * comm.size # Breeding population size limits = {