
Commit

add amp example
JackCaoG committed May 14, 2024
1 parent f26c35c commit e60c2b9
Showing 2 changed files with 54 additions and 19 deletions.
35 changes: 35 additions & 0 deletions examples/train_resnet_amp.py
@@ -0,0 +1,35 @@
from train_resnet_base import TrainResNetBase

import itertools

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


# For more details check https://github.com/pytorch/xla/blob/master/docs/amp.md
class TrainResNetXLAAMP(TrainResNetBase):

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      self.optimizer.zero_grad()
      # Enables autocasting for the forward pass
      with autocast(xm.xla_device()):
        output = self.model(data)
        loss = self.loss_fn(output, target)
      # TPU AMP uses bf16, hence gradient scaling is not necessary. If running with XLA:GPU,
      # check https://github.com/pytorch/xla/blob/master/docs/amp.md#amp-for-xlagpu.
      loss.backward()
      self.run_optimizer()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))


if __name__ == '__main__':
  xla_amp = TrainResNetXLAAMP()
  xla_amp.start_training()
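
For completeness, on XLA:GPU the AMP docs linked above describe using a gradient scaler, since GPU autocast runs in float16 rather than bf16. Below is a minimal sketch of how the loop above might be adapted for that case; it is not part of this commit, the subclass name TrainResNetXLAAMPGPU is hypothetical, and it assumes GradScaler from torch_xla.amp. See the linked amp.md (which also discusses syncfree optimizer variants) for the authoritative recipe.

from train_resnet_base import TrainResNetBase

import itertools

import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler


class TrainResNetXLAAMPGPU(TrainResNetBase):  # hypothetical name, for illustration only

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    # Gradient scaling guards fp16 gradients against underflow on XLA:GPU.
    scaler = GradScaler()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      self.optimizer.zero_grad()
      with autocast(xm.xla_device()):
        output = self.model(data)
        loss = self.loss_fn(output, target)
      # Scale the loss before backward; scaler.step() unscales before stepping.
      scaler.scale(loss).backward()
      scaler.step(self.optimizer)
      scaler.update()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))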
38 changes: 19 additions & 19 deletions examples/train_resnet_base.py
@@ -13,10 +13,6 @@
 import torch.nn as nn
 
 
-def _train_update(step, loss, tracker, epoch):
-  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
-
-
 class TrainResNetBase():
 
   def __init__(self):
@@ -37,29 +33,33 @@ def __init__(self):
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
 
+  def _train_update(self, step, loss, tracker, epoch):
+    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
+
   def run_optimizer(self):
     self.optimizer.step()
 
-  def start_training(self):
+  def train_loop_fn(self, loader, epoch):
+    tracker = xm.RateTracker()
+    self.model.train()
+    loader = itertools.islice(loader, self.num_steps)
+    for step, (data, target) in enumerate(loader):
+      self.optimizer.zero_grad()
+      output = self.model(data)
+      loss = self.loss_fn(output, target)
+      loss.backward()
+      self.run_optimizer()
+      tracker.add(self.batch_size)
+      if step % 10 == 0:
+        xm.add_step_closure(
+            self._train_update, args=(step, loss, tracker, epoch))
 
-    def train_loop_fn(loader, epoch):
-      tracker = xm.RateTracker()
-      self.model.train()
-      loader = itertools.islice(loader, self.num_steps)
-      for step, (data, target) in enumerate(loader):
-        self.optimizer.zero_grad()
-        output = self.model(data)
-        loss = self.loss_fn(output, target)
-        loss.backward()
-        self.run_optimizer()
-        tracker.add(self.batch_size)
-        if step % 10 == 0:
-          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
+  def start_training(self):
 
     for epoch in range(1, self.num_epochs + 1):
       xm.master_print('Epoch {} train begin {}'.format(
           epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
-      train_loop_fn(self.train_device_loader, epoch)
+      self.train_loop_fn(self.train_device_loader, epoch)
       xm.master_print('Epoch {} train end {}'.format(
           epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
     xm.wait_device_ops()
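
The point of this refactor is that _train_update and train_loop_fn are now methods of TrainResNetBase, so a variant such as the AMP example only overrides the piece it changes. As a further illustration (not part of this commit, and with a hypothetical class name), a subclass could override run_optimizer alone, for example to use xm.optimizer_step, which all-reduces gradients across replicas before stepping in a multi-process run:

from train_resnet_base import TrainResNetBase

import torch_xla.core.xla_model as xm


class TrainResNetReplicaReduce(TrainResNetBase):  # hypothetical name, for illustration only

  def run_optimizer(self):
    # xm.optimizer_step reduces gradients across replicas, then calls optimizer.step().
    xm.optimizer_step(self.optimizer)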
