fix: barebone_mnist + comments
sehoffmann committed Dec 17, 2024
1 parent 1001a49 commit 87b45ca
Showing 1 changed file with 27 additions and 11 deletions.
examples/barebone_mnist.py (38 changes: 27 additions & 11 deletions)
@@ -10,19 +10,26 @@
 
 
 class MNISTStage(dml.Stage):
+    # The pre_stage method is called before the first epoch
+    # It's a good place to load the dataset, create the model, and set up the optimizer
     def pre_stage(self):
-        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-
+        # Load the MNIST dataset
+        # The root_first context manager ensures the root process downloads the dataset before the other processes
         with dml.root_first():
+            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
             train_dataset = datasets.MNIST(root='data', train=True, download=dml.is_root(), transform=transform)
-            self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-            self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler)
-
             val_dataset = datasets.MNIST(root='data', train=False, download=dml.is_root(), transform=transform)
-            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
-            self.val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler)
-
-        self.model = nn.Sequential(
+
+        # For distributed training, we need to shard our dataset across all processes
+        # Here we use the DistributedSampler to do this
+        self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler)
+
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
+        self.val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler)
+
+        # We create our model regularly...
+        model = nn.Sequential(
             nn.Conv2d(1, 16, 3, padding=1),
             nn.ReLU(),
             nn.MaxPool2d(2),
@@ -32,22 +39,31 @@ def pre_stage(self):
             nn.Flatten(),
             nn.Linear(784, 10),
         )
-        self.model = dml.wrap_ddp(self.model, self.device)
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+
+        # ...and then wrap it with dml.wrap_ddp to enable distributed training
+        self.model = dml.wrap_ddp(model, self.device)
+
+        # It's also important to scale the learning rate based on the number of GPUs, dml.scale_lr does this for us
+        # Otherwise, we wouldn't profit from the increased batch size
+        # In practice, you would likely want to combine this with a linear lr rampup during the very first steps as well
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-3))
+
         self.loss = nn.CrossEntropyLoss()
 
+        # Finally, we add columns to the table to track the loss and accuracy
         self.add_column('[Train] Loss', 'train/loss', color='green')
         self.add_column('[Train] Acc.', 'train/accuracy', color='green')
         self.add_column('[Val] Loss', 'val/loss', color='blue')
         self.add_column('[Val] Acc.', 'val/accuracy', color='blue')
 
+    # The run_epoch method is called once per epoch
     def run_epoch(self):
        self._train_epoch()
        self._val_epoch()
 
     def _train_epoch(self):
         self.model.train()
-        self.metric_prefix = 'train'
+        self.metric_prefix = 'train'  # This is used to prefix the metrics in the table
         self.train_sampler.set_epoch(self.current_epoch)
 
         for img, target in self.train_loader:
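
Note: the new comment on the optimizer says that dml.scale_lr scales the base learning rate with the number of GPUs and that, in practice, one would also ramp the learning rate up linearly over the very first steps. The commit itself adds no scheduler; a minimal sketch of such a warmup with stock PyTorch (the stand-in model, base_lr, and warmup_steps below are placeholders, not taken from the example) could look like:

    import torch

    model = torch.nn.Linear(784, 10)      # stand-in for the example's CNN
    base_lr = 1e-3                        # inside the stage this would be dml.scale_lr(1e-3)
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

    # LinearLR multiplies the learning rate by a factor that grows from
    # start_factor to 1.0 over total_iters calls to scheduler.step().
    warmup_steps = 500                    # assumption; tune per workload
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps
    )

    for step in range(warmup_steps):
        optimizer.zero_grad()
        loss = model(torch.randn(32, 784)).sum()   # dummy forward pass
        loss.backward()
        optimizer.step()
        scheduler.step()                  # advance the warmup once per batch

Stepping the scheduler per batch rather than per epoch is what makes this a ramp-up over the first steps, as the comment suggests.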

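The training loop in the diff also calls self.train_sampler.set_epoch(self.current_epoch). With DistributedSampler, that call is what makes the shuffle order differ from epoch to epoch; without it, every epoch replays the same permutation. A self-contained illustration (the toy dataset and the explicit num_replicas/rank arguments are only there so the snippet runs without torchrun):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8).float())   # stand-in for the MNIST training set

    # Passing num_replicas/rank explicitly avoids needing an initialized process group.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)          # seed the shuffle with the epoch number
        for (batch,) in loader:
            print(epoch, batch.tolist())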
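
Finally, the commit switches from assigning the nn.Sequential to self.model directly to building a local model and assigning only the result of dml.wrap_ddp(model, self.device). The internals of wrap_ddp are not shown in this diff; in plain PyTorch, wrapping a model for distributed data parallelism (assuming the process group has already been initialized, e.g. by torchrun) looks roughly like this:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_for_ddp(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
        # Hypothetical helper, not dmlcloud's implementation: move the model to this
        # rank's device and let DDP synchronize gradients across processes.
        model = model.to(device)
        device_ids = [device.index] if device.type == 'cuda' else None
        return DDP(model, device_ids=device_ids)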