Add simple example for how to use torch_xla #7048
train_resnet_base.py
@@ -0,0 +1,72 @@
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import os
import time

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn

time.ctime()


def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):
Reviewer: If this stays a class, use a dataclass. Also, IMO, Python's constructors are ugly and I like to hide them as much as possible.

Author: I actually don't know how dataclasses work, haha. If you don't mind, can you start a PR after this one to fix it?
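As a purely hypothetical sketch of the reviewer's suggestion (the TrainConfig name and the grouping are assumptions, not part of this PR), the hyperparameters hard-coded in the constructor could live in a dataclass:

from dataclasses import dataclass


@dataclass
class TrainConfig:
  # Hypothetical grouping of the values set one by one in __init__ below.
  img_dim: int = 224
  batch_size: int = 128
  num_steps: int = 300
  num_epochs: int = 1
  train_dataset_len: int = 1200000  # roughly the size of the ImageNet dataset

The constructor that follows is the code as actually submitted in this PR.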
    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of the ImageNet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size //
        xm.xrt_world_size())

    self.device = xm.xla_device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
Reviewer: Is the only reason we need a class here to override run_optimizer?

Author: yea

(TrainResNetXLADDP below does exactly this, overriding run_optimizer to call xm.optimizer_step; see also the sketch after this file.)
    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
        if step % 10 == 0:
          # add_step_closure defers the host-side print until the step's
          # results are materialized, instead of forcing an early evaluation
          # of `loss` in the middle of the training step.
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
Reviewer: This is a good opportunity to be the change we want to see in the world and use …

Author: sure, is …

Author: actually... hmm, if we want to do that we might want a way to disable …

Reviewer: I wouldn't worry about novice users enabling our hidden debug options 😛

Author: Let me think about this in a follow-up PR. Calling …
        if self.num_steps == step:
          break

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
Reviewer: Can you just inline the train loop function here?

Author: Hmm, dynamo works like …
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
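The run_optimizer thread above asked whether the class mainly exists so that run_optimizer can be overridden. Beyond the DDP and XLA DDP variants that follow, here is a purely hypothetical sketch of such a variant; the class name and the gradient-clipping step are made up for illustration and are not part of this PR:

import torch

from train_resnet_base import TrainResNetBase


class TrainResNetClipGrad(TrainResNetBase):

  def run_optimizer(self):
    # Clip gradients before stepping; data loading, the training loop and
    # logging are all reused from TrainResNetBase unchanged.
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    self.optimizer.step()


if __name__ == '__main__':
  TrainResNetClipGrad().start_training()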
train_resnet_ddp.py
@@ -0,0 +1,24 @@
from train_resnet_base import TrainResNetBase
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
    dist.init_process_group('xla', init_method='xla://')
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    # Rebuild the optimizer so it references the DDP-wrapped model's parameters.
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
train_resnet_xla_ddp.py
@@ -0,0 +1,18 @@
from train_resnet_base import TrainResNetBase
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):

  def run_optimizer(self):
    # xm.optimizer_step all-reduces the gradients across replicas before
    # calling optimizer.step().
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
Reviewer: Can you move this file to torch_xla/examples so it's included in our package? Then we can run examples like python -m torch_xla.examples.train_resnet_ddp directly from the installed package.

Author: Do we want to include them as part of our package, though?

Reviewer: It would save users having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.

Author: Sure, I can move them.

Author: Oh wait... we depend on torchvision in this example. I don't want to force that dependency in our package. Let's pause this move until we figure out how to deal with that.
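One hypothetical way to handle that, sketched here only as a possible direction (not something this PR does): guard the torchvision import inside the example so that only running the example requires torchvision, not installing the package. The error message wording is made up.

# Hypothetical sketch: import torchvision lazily with a clear failure message.
try:
  import torchvision
except ImportError as e:
  raise ImportError(
      'this example requires torchvision: pip install torchvision') from e

Whether that is acceptable is exactly the question the thread leaves open for a follow-up.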