Add simple example for how to use torch_xla #7048

Merged: 5 commits, May 14, 2024 (showing changes from 2 commits)
72 changes: 72 additions & 0 deletions examples/train_resnet_base.py
@@ -0,0 +1,72 @@
from torch_xla import runtime as xr

Collaborator: Can you move this file to torch_xla/examples so it's included in our package? Then we can run examples like python -m torch_xla.examples.train_resnet_ddp directly from the installed package.

Collaborator Author: Do we want to include them as part of our package though?

Collaborator: It would save users from having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.

Collaborator Author: Sure, I can move it.

Collaborator Author: Oh wait, we depend on torchvision in this example, and I don't want to force that dependency into our package. Let's pause this move until we figure out how to deal with that.

import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import os
import time

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn

time.ctime()


def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):

Collaborator: If this stays a class, use a dataclass here so we can update the batch size etc. without having to write out the whole constructor.

Also, IMO, Python's constructors are ugly and I like to hide them as much as possible.

Collaborator Author: I actually don't know how dataclasses work, haha. If you don't mind, can you start a PR after this one to fix it?
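
For reference, a minimal sketch of the dataclass approach being suggested (illustrative only, not part of this diff; TrainConfig is a hypothetical name whose fields mirror the constructor below):

    from dataclasses import dataclass

    # @dataclass generates __init__ from the field declarations, so a caller
    # can override e.g. batch_size without re-writing the whole constructor.
    @dataclass
    class TrainConfig:
      img_dim: int = 224
      batch_size: int = 128
      num_steps: int = 300
      num_epochs: int = 1

    # Usage sketch: TrainResNetBase(TrainConfig(batch_size=64))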

    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size //
        xm.xrt_world_size())

    self.device = xm.xla_device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):

Collaborator: Is the only reason we need a class here to override optimizer.step?

Collaborator Author: yea

    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
        if step % 10 == 0:
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))

Collaborator: This is a good opportunity to be the change we want to see in the world and use xla.sync() instead. There's no performance reason to use a step closure here, and I think with xla.sync() is clearer.

Collaborator:

    with xla.sync():
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()

    tracker.add(self.batch_size)
    if step % 10 == 0: print(...)

Collaborator Author: Sure, is the sync change already merged?

Collaborator Author: Actually, hmm, if we want to do that we might want a way to disable mark_step from the loader; otherwise it will be a bit confusing when they turn on PT_XLA_DEBUG.

Collaborator: I wouldn't worry about novice users enabling our hidden debug options 😛

Collaborator Author: Let me think about this in a follow-up PR. Calling mark_step is not free; one of the optimizations in inference is to remove any additional mark_step that's not required. If we were to add sync here, we'd want to somehow disable the mark_step in the loader.

        if self.num_steps == step:
          break

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)

Collaborator: Can you just inline the train loop function here?

Collaborator Author: Hmm, dynamo works like torch.compile(trainer, backend="openxla"). For the normal case I can inline it.
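
For reference, a minimal sketch of the dynamo pattern mentioned above (illustrative only, not part of this diff; train_step and the surrounding names are hypothetical):

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = torchvision.models.resnet50().to(device)
    optimizer = optim.SGD(model.parameters(), weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()


    def train_step(data, target):
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      optimizer.step()
      return loss


    # Keeping the step as a standalone function lets dynamo trace it and lower
    # it through XLA via the openxla backend.
    compiled_step = torch.compile(train_step, backend="openxla")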

      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
24 changes: 24 additions & 0 deletions examples/train_resnet_ddp.py
@@ -0,0 +1,24 @@
from train_resnet_base import TrainResNetBase
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
    # Use the XLA backend for torch.distributed; init_method='xla://' picks up
    # rank and world size from the XLA runtime.
    dist.init_process_group('xla', init_method='xla://')
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    # Rebuild the optimizer against the now DDP-wrapped model's parameters.
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
18 changes: 18 additions & 0 deletions examples/train_resnet_xla_ddp.py
@@ -0,0 +1,18 @@
from train_resnet_base import TrainResNetBase
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):

  def run_optimizer(self):
    # xm.optimizer_step all-reduces the gradients across replicas before
    # applying the optimizer update.
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())