
Commit

clean up comments
mcw92 committed May 8, 2024
1 parent 5fa49b2 commit cb642d0
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions tutorials/torch_ddp_example.py
@@ -1,7 +1,7 @@
"""
Toy example for HP optimization / NAS in Propulate, using a simple CNN trained on MNIST in a data-parallel fashion.
-This script was tested on two compute nodes with 4 GPUs each. Note that you need to adapt ``GPUS_PER_NODE`` in l. 25.
+This script was tested on two compute nodes with 4 GPUs each. Note that you need to adapt ``GPUS_PER_NODE`` in l. 29.
"""

import datetime as dt
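Note: as the docstring says, each process has to be pinned to one of its node's GPUs. A rough, illustrative sketch of how a constant like ``GPUS_PER_NODE`` is typically used for that (the variable names here are assumptions, not necessarily the tutorial's own code):

import torch
from mpi4py import MPI

GPUS_PER_NODE = 4  # Adapt to the number of GPUs available on each compute node.

# Map the global MPI rank to a local GPU index on its node (illustrative only).
rank = MPI.COMM_WORLD.rank
device = torch.device(f"cuda:{rank % GPUS_PER_NODE}" if torch.cuda.is_available() else "cpu")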
@@ -135,7 +135,7 @@ def get_data_loaders(
val_dataset = MNIST(download=False, root=".", transform=data_transform, train=False)
if (
subgroup_comm.size > 1
-): # need to make the samplers use the torch world to distributed data
+): # Make the samplers use the torch world to distribute data
train_sampler = datadist.DistributedSampler(train_dataset)
val_sampler = datadist.DistributedSampler(val_dataset)
else:
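Note: the reworded comment describes a common pattern. When more than one torch rank shares the training, each rank must read a distinct shard of the data, so a ``DistributedSampler`` replaces plain shuffling. A minimal, self-contained sketch of that decision (the dataset here is a stand-in, not MNIST):

import torch
import torch.distributed as dist
import torch.utils.data.distributed as datadist
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-in for the MNIST training set.
train_dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))

# Use a DistributedSampler only when several torch ranks share the training.
world_size = dist.get_world_size() if dist.is_initialized() else 1
train_sampler = datadist.DistributedSampler(train_dataset) if world_size > 1 else None
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=(train_sampler is None),  # The sampler already shuffles and shards per rank.
    sampler=train_sampler,
)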
@@ -150,7 +150,7 @@ def get_data_loaders(
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
-shuffle=(train_sampler is None), # Shuffle data.
+shuffle=(train_sampler is None), # Shuffle data only if no sampler is provided.
sampler=train_sampler,
)
val_loader = DataLoader(
@@ -290,7 +290,7 @@ def ind_loss(
conv_layers = params["conv_layers"] # Number of convolutional layers
activation = params["activation"] # Activation function
lr = params["lr"] # Learning rate
-gamma = params["gamma"]
+gamma = params["gamma"] # Learning rate reduction factor

epochs = 20
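Note: the newly documented ``gamma`` is the factor by which the learning rate is decayed. A hedged sketch of how ``lr`` and ``gamma`` typically enter an optimizer and scheduler (the tutorial's actual optimizer and schedule may differ; the values below are illustrative):

import torch

# Illustrative values; in the tutorial they come from the params dict.
lr, gamma, epochs = 0.1, 0.7, 20
model = torch.nn.Linear(10, 10)  # Stand-in for Net(conv_layers, activation).

optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)
# StepLR multiplies the learning rate by gamma after every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

for _ in range(epochs):
    # ... one epoch of training and validation would run here ...
    optimizer.step()     # Stand-in for a real training step.
    scheduler.step()     # Decay the learning rate by the factor gamma.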

@@ -302,9 +302,8 @@ def ind_loss(
activation = activations[activation] # Get activation function.
loss_fn = torch.nn.NLLLoss()

-model = Net(conv_layers, activation)
# Set up neural network with specified hyperparameters.
-# model.best_accuracy = 0.0 # Initialize the model's best validation accuracy.
+model = Net(conv_layers, activation)

train_loader, val_loader = get_data_loaders(
batch_size=8, subgroup_comm=subgroup_comm
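Note: the lookup ``activations[activation]`` above implies a small table mapping the string hyperparameter to a torch activation class. A plausible sketch of such a mapping (the exact entries in the tutorial may differ):

import torch

# Hypothetical mapping from the string hyperparameter to an activation class.
activations = {
    "relu": torch.nn.ReLU,
    "sigmoid": torch.nn.Sigmoid,
    "tanh": torch.nn.Tanh,
}
activation = activations["relu"]  # E.g., if params["activation"] == "relu".
layer = activation()  # Instantiate the chosen activation for use in the network.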
@@ -428,7 +427,7 @@ def ind_loss(
migration_probability=config.migration_probability, # Migration probability
pollination=config.pollination, # Whether to use pollination or migration
checkpoint_path=config.checkpoint, # Checkpoint path
-# ----- SPECIFIC FOR MULTI-RANK UCS ----
+# ----- SPECIFIC FOR MULTI-RANK UCS -----
ranks_per_worker=2, # Number of ranks per (multi rank) worker
)
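
Note: with ``ranks_per_worker=2``, each loss evaluation receives a two-rank MPI sub-communicator (the ``subgroup_comm`` passed to ``ind_loss``), so one network can be trained data-parallel by two ranks. An illustrative sketch of what splitting a world communicator into two-rank groups looks like (this is not Propulate's internal implementation):

from mpi4py import MPI

ranks_per_worker = 2
world = MPI.COMM_WORLD
# Ranks 0 and 1 form worker 0, ranks 2 and 3 form worker 1, and so on.
color = world.rank // ranks_per_worker
subgroup_comm = world.Split(color=color, key=world.rank)
print(f"World rank {world.rank} is rank {subgroup_comm.rank} of worker {color}.")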
