Add simple example for how to use torch_xla #7048
@@ -0,0 +1,2 @@
## Overview
This repo aims to provide some basic examples of how to run an existing PyTorch model with PyTorch/XLA. `train_resnet_base.py` is a minimal trainer that runs ResNet50 with fake data on a single device. Other examples import `train_resnet_base` and demonstrate how to enable different features (distributed training, profiling, dynamo, etc.) on PyTorch/XLA.
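The dynamo path mentioned above is not demonstrated in the files below. As a rough sketch (assuming the `openxla` dynamo backend that recent PyTorch/XLA releases register for `torch.compile`; not part of this PR), enabling it on a trainer like the one in the next file mostly amounts to compiling the model:

```python
# Sketch only: compile the model with the XLA dynamo backend.
# The 'openxla' backend name is an assumption based on recent PyTorch/XLA releases.
import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet50().to(device)
compiled_model = torch.compile(model, backend='openxla')
# compiled_model can then be used in place of model inside the training loop.
```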
@@ -0,0 +1,70 @@
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn


def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):
Reviewer: If this stays a class, use a `dataclass`. Also, IMO, Python's constructors are ugly and I like to hide them as much as possible. (A hedged dataclass sketch is included after this file listing.)
Author: I actually don't know how dataclass works haha. If you don't mind, can you start a PR after this one to fix it?
    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size // xr.world_size())

    self.device = torch_xla.device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
Reviewer: Is the only reason we need a class here to override `run_optimizer`?
Author: yea
    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      loader = itertools.islice(loader, self.num_steps)
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
        if step % 10 == 0:
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
Reviewer: This is a good opportunity to be the change we want to see in the world and use …
Author: sure, is …
Author: actually.. hmm, if we want to do that we might want a way to disable …
Reviewer: I wouldn't worry about novice users enabling our hidden debug options 😛
Author: Let me think about this in a follow-up PR. Calling …
    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
Reviewer: Can you just inline the train loop function here?
Author: Hmm dynamo works like …
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
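On the dataclass suggestion in the review thread above: a minimal sketch of what the configuration part of `TrainResNetBase` could look like as a dataclass (a hypothetical refactor, not part of this PR), assuming the scalar hyperparameters are split out from the device/model setup:

```python
from dataclasses import dataclass


# Sketch only: hoist the scalar hyperparameters from TrainResNetBase.__init__
# into a dataclass; values mirror the ones used above.
@dataclass
class TrainConfig:
  img_dim: int = 224
  batch_size: int = 128
  num_steps: int = 300
  num_epochs: int = 1
  train_dataset_len: int = 1200000  # roughly the size of the ImageNet dataset
```

The model, loader, and optimizer construction would still happen in a regular method (or `__post_init__`), since those depend on the XLA device.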
@@ -0,0 +1,24 @@
from train_resnet_base import TrainResNetBase
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
    dist.init_process_group('xla', init_method='xla://')
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
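Usage note (assuming this file is saved as `train_resnet_ddp.py`, consistent with the packaging discussion at the end of this review): running `python train_resnet_ddp.py` should be enough; `xmp.spawn` starts one process per local XLA device and calls `_mp_fn` in each of them.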
@@ -0,0 +1,25 @@
import os

from train_resnet_base import TrainResNetBase
import torch_xla.debug.profiler as xp

# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

if __name__ == '__main__':
  base = TrainResNetBase()
  profile_port = 9012
  profile_logdir = "/tmp/profile/"
  duration_ms = 30000
  assert os.path.exists(profile_logdir)
  server = xp.start_server(profile_port)
  # Ideally you want to start the profile tracing after the initial compilation,
  # for example at step 5.
  xp.trace_detached(
      f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
  base.start_training()
  # You can view the profile in TensorBoard by
  # 1. pip install tensorflow tensorboard-plugin-profile
  # 2. tensorboard --logdir /tmp/profile/ --port 6006
  # For more detail please take a look at
  # https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm
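Following up on the comment above about starting the trace after the initial compilation: one possible way to do that without touching the trainer (a sketch that reuses the names defined in the listing above; the five-second delay is an arbitrary guess for warm-up time, not something from this PR):

```python
import threading

# Sketch only: start the detached trace from a background timer a few seconds
# into training so the captured window skips the compilation-heavy first steps.
delay_s = 5  # assumption: rough warm-up time; tune for your model and hardware
threading.Timer(
    delay_s,
    xp.trace_detached,
    args=(f'localhost:{profile_port}', profile_logdir),
    kwargs={'duration_ms': duration_ms}).start()
base.start_training()
```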
@@ -0,0 +1,18 @@
from train_resnet_base import TrainResNetBase
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):

  def run_optimizer(self):
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
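Compared with the DDP example, the only change here is the `run_optimizer` override: `xm.optimizer_step` all-reduces the gradients across replicas before applying the optimizer, so no DDP wrapper is needed. To a first approximation (a sketch of the behavior, not the library's actual implementation), the override acts like:

```python
# Sketch only: approximate expansion of xm.optimizer_step(self.optimizer).
def run_optimizer(self):
  xm.reduce_gradients(self.optimizer)  # cross-replica all-reduce of gradients
  self.optimizer.step()
```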
Reviewer: Can you move this file to `torch_xla/examples` so it's included in our package? Then we can run examples like `python -m torch_xla.examples.train_resnet_ddp` directly from the installed package.
Author: do we want to include them as part of our package though?
Reviewer: It would save users having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.
Author: sure, I can move it.
Author: oh wait... we depend on torchvision in this example... I don't want to force that dependency in our package.. Let's pause this move until we figure out how to deal with that..
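On the torchvision concern: one common pattern (a sketch of a possible approach, not something decided in this thread) is to import it lazily, so the package does not gain a hard dependency and only the example that needs it fails, with a clear message:

```python
# Sketch only: keep torchvision out of the package's hard dependencies by
# importing it lazily inside the example that needs it.
def _get_resnet50():
  try:
    import torchvision
  except ImportError as e:
    raise ImportError(
        'This example requires torchvision; install it with '
        '`pip install torchvision`.') from e
  return torchvision.models.resnet50()
```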