Add simple example for how to use torch_xla #7048
Conversation
Thanks for starting this PR. Our various resnet examples were getting unwieldy, and I like that this is more straightforward.
My only concern here is that `TrainResNetBase` might be too unwieldy. It sort of looks like a stripped-down `accelerate` wrapper or `LightningModule`. IMO these examples should not necessarily focus on reusable code, but more on showcasing the order of execution in a training loop in a really clear way. Flax's examples do this well: https://github.com/google/flax/blob/main/examples/README.md (they have more model definition than we do because the ecosystem doesn't have an equivalent of torch{vision,text,audio}).
Ideally we should be able to strip this down to a really simple function like `train_resnet`, with different variations doing different setup/teardown. I do see that there is at least one case where a child method accesses some common internal state (`run_optimizer` needs `optimizer`), so that may not be possible here.
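As a reference point, here is a minimal sketch of the flatter shape being suggested. The function name `train_resnet` comes from the comment above; the plain-DataLoader setup and the explicit `xm.mark_step()` are assumptions for illustration, not this PR's code:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm


def train_resnet(model, loader, num_epochs=1):
  # All setup happens inline so the order of execution is visible at a glance.
  device = xm.xla_device()
  model = model.to(device).train()
  optimizer = optim.SGD(model.parameters(), weight_decay=1e-4)
  loss_fn = nn.CrossEntropyLoss()

  for epoch in range(num_epochs):
    for step, (data, target) in enumerate(loader):
      data, target = data.to(device), target.to(device)
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      optimizer.step()
      xm.mark_step()  # cut the graph once per step (no device loader here)
      if step % 10 == 0:
        print(f'epoch {epoch} step {step} loss {loss.item()}')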
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
Is the only reason we need a class here to override `optimizer.step`?
yea
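For context, a sketch of the kind of override that motivates keeping a class; the subclass name is assumed, and the body is a guess at a DDP-style variant rather than this PR's exact code:

import torch_xla.core.xla_model as xm


class TrainResNetXLADDP(TrainResNetBase):  # assumed name for a multi-replica variant

  def run_optimizer(self):
    # xm.optimizer_step all-reduces gradients across replicas before stepping,
    # which is the only behavior this variant needs to change.
    xm.optimizer_step(self.optimizer)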
class TrainResNetBase():

  def __init__(self):
If this stays a class, use a dataclass here so we can update the batch size etc. without having to write out the whole constructor. Also, IMO, Python's constructors are ugly and I like to hide them as much as possible.
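A sketch of the dataclass shape being suggested; the field names mirror attributes used in the diff, and the defaults are illustrative:

from dataclasses import dataclass


@dataclass
class TrainResNetBase:
  # Hyperparameters become fields, so callers can write
  # TrainResNetBase(batch_size=256) without restating the whole constructor.
  batch_size: int = 128
  num_epochs: int = 1

  def __post_init__(self):
    # Model, optimizer, and loader setup that shouldn't be a constructor
    # argument still fits here.
    ...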
I actually don't know how dataclass works haha. If you don't mind, can you start a PR after this one to fix it?
      self.run_optimizer()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
This is a good opportunity to be the change we want to see in the world and use `xla.sync()` instead. There's no performance reason to use a step closure here, and I think `with xla.sync()` is clearer.
with xla.sync():
  self.optimizer.zero_grad()
  output = self.model(data)
  loss = self.loss_fn(output, target)
  loss.backward()
  self.run_optimizer()

tracker.add(self.batch_size)
if step % 10 == 0: print(...)
sure, is the `sync` change already merged?
actually.. hmm, if we want to do that we might want a way to disable the `mark_step` from the loader, otherwise it will be a bit confusing when they turn on `PT_XLA_DEBUG`..
I wouldn't worry about novice users enabling our hidden debug options 😛
Let me think about this in a follow-up PR. Calling `mark_step` is not free; one of the optimizations in inference is to remove any additional `mark_step` that's not required. If we were to add `sync` here, we would want to somehow disable the `mark_step` in the loader.
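To make the concern concrete: `MpDeviceLoader` already cuts the graph once per batch, so wrapping the step body in `xla.sync()` as well would add a second cut per step. A sketch of one alternative, assuming a plain `DataLoader` (the attribute names follow the class in this diff, but this is not the PR's code), where the training step owns the single `mark_step`:

import torch_xla.core.xla_model as xm

device = xm.xla_device()
for step, (data, target) in enumerate(plain_loader):  # plain DataLoader, not MpDeviceLoader
  data, target = data.to(device), target.to(device)
  self.optimizer.zero_grad()
  loss = self.loss_fn(self.model(data), target)
  loss.backward()
  self.run_optimizer()
  xm.mark_step()  # the only graph cut for this step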
    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
Can you just inline the train loop function here?
Hmm, dynamo works like `torch.compile(trainer, backend="openxla")`. For the normal case I can inline them.
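For reference, a minimal sketch of the dynamo case; the `step_fn` split is an assumption about how the compiled variant might look, not this PR's exact code:

import torch


def step_fn(model, loss_fn, optimizer, data, target):
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()
  return loss


# backend="openxla" routes the captured graph through torch_xla, which is why
# the step logic has to stay a standalone callable instead of being inlined
# into the epoch loop.
compiled_step = torch.compile(step_fn, backend="openxla")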
@@ -0,0 +1,72 @@
from torch_xla import runtime as xr
Can you move this file to `torch_xla/examples` so it's included in our package? Then we can run examples like `python -m torch_xla.examples.train_resnet_ddp` directly from the installed package.
do we want to include them as part of our package, though?
It would save users having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.
sure, I can move it.
oh wait... we depend on torchvision in this example... I don't want to force that dependency in our package. Let's pause this move until we figure out how to deal with that.

My thinking is that I want to make it really clear how simple it is to enable some features (like xla_ddp, profiling). Having them in separate files makes it very clear what code the user needs to add to enable those features.
TODO: