Add simple example for how to use torch_xla #7048

Merged: 5 commits, May 14, 2024
2 changes: 2 additions & 0 deletions examples/README.md
@@ -0,0 +1,2 @@
## Overview
This repo aims to provide some basic examples of how to run an existing PyTorch model with PyTorch/XLA. train_resnet_base.py is a minimal trainer that runs ResNet50 with fake data on a single device. Other examples will import train_resnet_base and demonstrate how to enable different features (distributed training, profiling, dynamo, etc.) on PyTorch/XLA.
70 changes: 70 additions & 0 deletions examples/train_resnet_base.py
@@ -0,0 +1,70 @@
from torch_xla import runtime as xr
Collaborator:
Can you move this file to torch_xla/examples so it's included in our package? Then we can run examples like python -m torch_xla.examples.train_resnet_ddp directly from the installed package.

Collaborator Author:
Do we want to include them as part of our package, though?

Collaborator:
It would save users from having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.

Collaborator Author:
Sure, I can move it.

Collaborator Author:
Oh wait... we depend on torchvision in this example, and I don't want to force that dependency in our package. Let's pause this move until we figure out how to deal with that.

import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn


def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):
Collaborator:
If this stays a class, use a dataclass here so we can update the batch size etc without having to write out the whole constructor.

Also, IMO, Python's constructors are ugly and I like to hide them as much as possible.

Collaborator Author:
I actually don't know how dataclasses work, haha. If you don't mind, can you start a PR after this one to fix it?
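For readers in the same boat, a minimal sketch of what the suggested dataclass refactor could look like (illustrative only, not part of this PR; field names mirror the constructor below):

    from dataclasses import dataclass

    @dataclass
    class TrainConfig:
      img_dim: int = 224
      batch_size: int = 128
      num_steps: int = 300
      num_epochs: int = 1

    # @dataclass generates __init__ for the fields above, so callers can
    # override individual values without rewriting the whole constructor:
    config = TrainConfig(batch_size=64)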

    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size // xr.world_size())

    self.device = torch_xla.device()
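    # MpDeviceLoader wraps the host-side loader: it moves each batch to the XLA
    # device and, by default, calls mark_step after every batch (the behavior
    # discussed in the review thread below).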
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
Collaborator:
Is the only reason we need a class here so that we can override optimizer.step?

Collaborator Author:
yea

    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      loader = itertools.islice(loader, self.num_steps)
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
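        # add_step_closure defers the host-side print until the step's graph has
        # executed, so reading the loss here does not force an early execution.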
        if step % 10 == 0:
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
Collaborator:
This is a good opportunity to be the change we want to see in the world and use xla.sync() instead. There's no performance reason to use a step closure here, and I think with xla.sync() is clearer.

Collaborator:
    with xla.sync():
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
    
    tracker.add(self.batch_size)
    if step % 10 == 0: print(...)

Collaborator Author:
Sure. Is the sync change already merged?

Collaborator Author:
Actually... hmm, if we want to do that we might want a way to disable mark_step from the loader; otherwise it will be a bit confusing when they turn on PT_XLA_DEBUG.

Collaborator:
I wouldn't worry about novice users enabling our hidden debug options 😛

Collaborator Author:
Let me think about this in a follow-up PR. Calling mark_step is not free; one of the optimizations for inference is to remove any additional mark_step that's not required. If we were to add sync here, we would want to somehow disable the mark_step in the loader.


    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
Collaborator:
Can you just inline the train loop function here?

Collaborator Author:
Hmm, dynamo works like torch.compile(trainer, backend="openxla"). For the normal case I can inline them.
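For context, a rough illustration of the torch.compile pattern being referred to (hypothetical, not code from this PR; it assumes the per-step work has been factored into a step_fn that closes over the model, optimizer, and loss_fn):

    def step_fn(data, target):
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      optimizer.step()
      return loss

    # 'openxla' is the PyTorch/XLA dynamo backend name.
    compiled_step = torch.compile(step_fn, backend='openxla')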

      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
24 changes: 24 additions & 0 deletions examples/train_resnet_ddp.py
@@ -0,0 +1,24 @@
from train_resnet_base import TrainResNetBase
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
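    # init_method='xla://' lets the 'xla' process group pick up rank and world
    # size from the XLA runtime automatically.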
    dist.init_process_group('xla', init_method='xla://')
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
25 changes: 25 additions & 0 deletions examples/train_resnet_profile.py
@@ -0,0 +1,25 @@
import os

from train_resnet_base import TrainResNetBase
import torch_xla.debug.profiler as xp

# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

if __name__ == '__main__':
  base = TrainResNetBase()
  profile_port = 9012
  profile_logdir = "/tmp/profile/"
  duration_ms = 30000
  assert os.path.exists(profile_logdir)
  server = xp.start_server(profile_port)
  # Ideally you want to start the profile tracing after the initial compilation,
  # for example at step 5.
  xp.trace_detached(
      f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
  base.start_training()
  # You can view the profile in TensorBoard by:
  # 1. pip install tensorflow tensorboard-plugin-profile
  # 2. tensorboard --logdir /tmp/profile/ --port 6006
  # For more detail please take a look at https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm
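A minimal sketch (an assumption, not part of this PR) of one way to start the trace only after the compilation-heavy first steps, by replacing the trace_detached call above with a delayed one; warmup_s is a rough guess and the variables are the ones defined in this example:

    import threading

    warmup_s = 60  # hypothetical warm-up delay; tune for your model and device
    threading.Timer(
        warmup_s, lambda: xp.trace_detached(
            f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
    ).start()
    base.start_training()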
18 changes: 18 additions & 0 deletions examples/train_resnet_xla_ddp.py
@@ -0,0 +1,18 @@
from train_resnet_base import TrainResNetBase
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):

  def run_optimizer(self):
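    # xm.optimizer_step all-reduces (averages) gradients across replicas before
    # calling optimizer.step().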
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())