diff --git a/tutorials/torch_example.py b/tutorials/torch_example.py index 49a27b60..e671a52a 100755 --- a/tutorials/torch_example.py +++ b/tutorials/torch_example.py @@ -221,35 +221,17 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]: num_workers = NUM_WORKERS log.info(f"Use {num_workers} workers in dataloader.") - if MPI.COMM_WORLD.rank == 0: # Only root downloads data. - train_loader = DataLoader( - dataset=MNIST( - download=True, root=".", transform=data_transform, train=True - ), # Use MNIST training dataset. - batch_size=batch_size, # Batch size - num_workers=num_workers, - pin_memory=True, - persistent_workers=True, - shuffle=True, # Shuffle data. - ) - - # NOTE barrier only called, when dataset has not been downloaded yet - if not hasattr(get_data_loaders, "barrier_called"): - MPI.COMM_WORLD.Barrier() - - setattr(get_data_loaders, "barrier_called", True) - - if MPI.COMM_WORLD.rank != 0: - train_loader = DataLoader( - dataset=MNIST( - download=False, root=".", transform=data_transform, train=True - ), # Use MNIST training dataset. - batch_size=batch_size, # Batch size - num_workers=num_workers, - pin_memory=True, - persistent_workers=True, - shuffle=True, # Shuffle data. - ) + # Note that the MNIST dataset has already been downloaded before globally by rank 0 in the main part. + train_loader = DataLoader( + dataset=MNIST( + download=False, root=".", transform=data_transform, train=True + ), # Use MNIST training dataset. + batch_size=batch_size, # Batch size + num_workers=num_workers, + pin_memory=True, + persistent_workers=True, + shuffle=True, # Shuffle data. + ) val_loader = DataLoader( dataset=MNIST( download=False, root=".", transform=data_transform, train=False @@ -327,6 +309,11 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float: if __name__ == "__main__": comm = MPI.COMM_WORLD + if comm.rank == 0: # Download data at the top, then we don't need to later. + dataset = MNIST(download=True, root=".", transform=None, train=True) + dataset = MNIST(download=True, root=".", transform=None, train=False) + del dataset + comm.Barrier() num_generations = 10 # Number of generations pop_size = 2 * comm.size # Breeding population size limits = {