Skip to content

Commit

Permalink
introduce global initial dataloading in main
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw92 committed May 8, 2024
1 parent cb642d0 commit d27069c
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions tutorials/torch_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,35 +221,17 @@ def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
num_workers = NUM_WORKERS
log.info(f"Use {num_workers} workers in dataloader.")

if MPI.COMM_WORLD.rank == 0: # Only root downloads data.
train_loader = DataLoader(
dataset=MNIST(
download=True, root=".", transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
shuffle=True, # Shuffle data.
)

# NOTE: The barrier is only called when the dataset has not been downloaded yet.
if not hasattr(get_data_loaders, "barrier_called"):
MPI.COMM_WORLD.Barrier()

setattr(get_data_loaders, "barrier_called", True)

if MPI.COMM_WORLD.rank != 0:
train_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
shuffle=True, # Shuffle data.
)
# Note that the MNIST dataset has already been downloaded globally by rank 0 in the main part.
train_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=True
), # Use MNIST training dataset.
batch_size=batch_size, # Batch size
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
shuffle=True, # Shuffle data.
)
val_loader = DataLoader(
dataset=MNIST(
download=False, root=".", transform=data_transform, train=False
Expand Down Expand Up @@ -327,6 +309,11 @@ def ind_loss(params: Dict[str, Union[int, float, str]]) -> float:

if __name__ == "__main__":
comm = MPI.COMM_WORLD
if comm.rank == 0: # Download data at the top, then we don't need to later.
dataset = MNIST(download=True, root=".", transform=None, train=True)
dataset = MNIST(download=True, root=".", transform=None, train=False)
del dataset
comm.Barrier()
num_generations = 10 # Number of generations
pop_size = 2 * comm.size # Breeding population size
limits = {
Expand Down

0 comments on commit d27069c

Please sign in to comment.