From e60c2b9a55a8eb4baf1a8a3743de8de6eda9a70b Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 14 May 2024 21:42:48 +0000
Subject: [PATCH] add amp example

---
 examples/train_resnet_amp.py  | 35 ++++++++++++++++++++++++++++++++
 examples/train_resnet_base.py | 38 +++++++++++++++++------------------
 2 files changed, 54 insertions(+), 19 deletions(-)
 create mode 100644 examples/train_resnet_amp.py

diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py
new file mode 100644
index 00000000000..ae541705d71
--- /dev/null
+++ b/examples/train_resnet_amp.py
@@ -0,0 +1,35 @@
+from train_resnet_base import TrainResNetBase
+
+import itertools
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.core.xla_model as xm
+from torch_xla.amp import autocast
+
+
+# For more details check https://github.com/pytorch/xla/blob/master/docs/amp.md
+class TrainResNetXLAAMP(TrainResNetBase):
+
+  def train_loop_fn(self, loader, epoch):
+    tracker = xm.RateTracker()
+    self.model.train()
+    loader = itertools.islice(loader, self.num_steps)
+    for step, (data, target) in enumerate(loader):
+      self.optimizer.zero_grad()
+      # Enables autocasting for the forward pass
+      with autocast(xm.xla_device()):
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+      # TPU amp uses bf16, hence gradient scaling is not necessary. If running with XLA:GPU,
+      # check https://github.com/pytorch/xla/blob/master/docs/amp.md#amp-for-xlagpu.
+      loss.backward()
+      self.run_optimizer()
+      tracker.add(self.batch_size)
+      if step % 10 == 0:
+        xm.add_step_closure(
+            self._train_update, args=(step, loss, tracker, epoch))
+
+
+if __name__ == '__main__':
+  xla_amp = TrainResNetXLAAMP()
+  xla_amp.start_training()
diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py
index e01c9828ba6..5b5cdb92d69 100644
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -13,10 +13,6 @@
 import torch.nn as nn
 
 
-def _train_update(step, loss, tracker, epoch):
-  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
-
-
 class TrainResNetBase():
 
   def __init__(self):
@@ -37,29 +33,33 @@ def __init__(self):
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
 
+  def _train_update(self, step, loss, tracker, epoch):
+    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
+
   def run_optimizer(self):
     self.optimizer.step()
 
-  def start_training(self):
+  def train_loop_fn(self, loader, epoch):
+    tracker = xm.RateTracker()
+    self.model.train()
+    loader = itertools.islice(loader, self.num_steps)
+    for step, (data, target) in enumerate(loader):
+      self.optimizer.zero_grad()
+      output = self.model(data)
+      loss = self.loss_fn(output, target)
+      loss.backward()
+      self.run_optimizer()
+      tracker.add(self.batch_size)
+      if step % 10 == 0:
+        xm.add_step_closure(
+            self._train_update, args=(step, loss, tracker, epoch))
 
-    def train_loop_fn(loader, epoch):
-      tracker = xm.RateTracker()
-      self.model.train()
-      loader = itertools.islice(loader, self.num_steps)
-      for step, (data, target) in enumerate(loader):
-        self.optimizer.zero_grad()
-        output = self.model(data)
-        loss = self.loss_fn(output, target)
-        loss.backward()
-        self.run_optimizer()
-        tracker.add(self.batch_size)
-        if step % 10 == 0:
-          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
+  def start_training(self):
 
     for epoch in range(1, self.num_epochs + 1):
       xm.master_print('Epoch {} train begin {}'.format(
           epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
-      train_loop_fn(self.train_device_loader, epoch)
+      self.train_loop_fn(self.train_device_loader, epoch)
       xm.master_print('Epoch {} train end {}'.format(
           epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
     xm.wait_device_ops()
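
The comment in train_resnet_amp.py notes that TPU autocast uses bf16, so no gradient scaling is needed, and points XLA:GPU users to amp.md. For reference, a minimal sketch of how the same loop could look on XLA:GPU with loss scaling, assuming the torch_xla.amp.GradScaler API described in that document (the class name TrainResNetXLAAMPGPU is illustrative, not part of this patch):

    from train_resnet_base import TrainResNetBase

    import itertools

    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast, GradScaler


    class TrainResNetXLAAMPGPU(TrainResNetBase):

      def train_loop_fn(self, loader, epoch):
        tracker = xm.RateTracker()
        # GradScaler scales the loss so fp16 gradients do not underflow;
        # amp.md also suggests the torch_xla.amp.syncfree optimizers with it.
        scaler = GradScaler()
        self.model.train()
        loader = itertools.islice(loader, self.num_steps)
        for step, (data, target) in enumerate(loader):
          self.optimizer.zero_grad()
          with autocast(xm.xla_device()):
            output = self.model(data)
            loss = self.loss_fn(output, target)
          # Backward on the scaled loss, then unscale and step the optimizer.
          scaler.scale(loss).backward()
          scaler.step(self.optimizer)
          scaler.update()
          tracker.add(self.batch_size)
          if step % 10 == 0:
            xm.add_step_closure(
                self._train_update, args=(step, loss, tracker, epoch))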