Add simple example for how to use torch_xla
JackCaoG committed May 11, 2024
1 parent c1b745e commit 433cdc1
Showing 3 changed files with 106 additions and 0 deletions.
70 changes: 70 additions & 0 deletions examples/train_resnet_base.py
@@ -0,0 +1,70 @@
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn

def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):
    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of the ImageNet dataset.
    # Fake data so the example runs without downloading ImageNet; each replica
    # gets its share of the per-epoch samples.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size //
        xr.world_size())

    self.device = xm.xla_device()
    # MpDeviceLoader wraps the loader and moves each batch to the XLA device.
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
    # Subclasses override this hook to customize the optimizer step
    # (e.g. to add gradient reduction across replicas).
    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
        if step % 10 == 0:
          # add_step_closure defers the print until the step's tensors are
          # materialized instead of forcing an early graph execution.
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
        if self.num_steps == step:
          break

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    # Wait for all pending device work to finish before exiting.
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
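A plausible way to run this example (an assumption about the environment, not part of this commit) is to select the backend through the PJRT_DEVICE environment variable and launch the script directly, e.g. PJRT_DEVICE=TPU python examples/train_resnet_base.py, substituting CUDA or CPU for TPU as appropriate.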
22 changes: 22 additions & 0 deletions examples/train_resnet_ddp.py
@@ -0,0 +1,22 @@
from train_resnet_base import TrainResNetBase

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
    # Join the process group through the XLA rendezvous backend.
    dist.init_process_group('xla', init_method='xla://')
    # Wrap the model in DDP so gradients are all-reduced during backward.
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    # Recreate the optimizer so it tracks the DDP-wrapped model's parameters.
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  # Spawn one training process per local XLA device.
  xmp.spawn(_mp_fn, args=())
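This variant would be launched the same way as the base example (again an assumption, not part of the commit), e.g. PJRT_DEVICE=TPU python examples/train_resnet_ddp.py; xmp.spawn forks one training process per local device, and each process joins the group via the xla:// init method.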
14 changes: 14 additions & 0 deletions examples/train_resnet_xla_ddp.py
@@ -0,0 +1,14 @@
from train_resnet_base import TrainResNetBase
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):

  def run_optimizer(self):
    # optimizer_step all-reduces the gradients across replicas before
    # applying the optimizer update.
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
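For contrast with the DDP wrapper above: xm.optimizer_step folds the cross-replica gradient reduction into the optimizer step itself, so no model wrapper is needed. A rough sketch of the equivalent explicit calls, as a simplification for illustration rather than the actual implementation:

  xm.reduce_gradients(self.optimizer)  # all-reduce gradients across replicas
  self.optimizer.step()                # then apply the update locally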
