Add simple example for how to use torch_xla #7048

Merged: 5 commits merged into master on May 14, 2024

Conversation

JackCaoG (Collaborator):

TODO:

  1. add readme
  2. add these to TPU tests

JackCaoG requested a review from will-cromar (May 11, 2024 00:35)
JackCaoG changed the title from "Jack cao g/example" to "Add simple example for how to use torch_xla" (May 11, 2024)

will-cromar (Collaborator) left a comment:

Thanks for starting this PR. Our various resnet examples were getting unwieldy, and I like that this is more straightforward.

My only concern here is that TrainResNetBase might be too unwieldy. It sort of looks like a stripped-down accelerate wrapper or LightningModule. IMO these examples should not necessarily focus on reusable code, but more on showcasing the order of execution in a training loop in a really clear way. Flax's examples do this well: https://github.com/google/flax/blob/main/examples/README.md (they have more model definition than we do because the ecosystem doesn't have an equivalent of torch{vision,text,audio})

Ideally we should be able to strip this down to a really simple function like train_resnet, with different variations doing their own setup/teardown. I do see that there is at least one case where a child method accesses some common internal state (run_optimizer needs optimizer), so that may not be possible here.
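
(For context, a minimal sketch of the kind of stripped-down train_resnet loop being suggested here; this is illustrative only, not code from this PR, and the model/loader setup is assumed to happen elsewhere:)

    # Illustrative sketch, not code from this PR. Assumes `model` is already on
    # the XLA device and `train_device_loader` is a device loader over the data.
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm

    def train_resnet(model, train_device_loader, num_epochs=1):
      optimizer = optim.SGD(model.parameters(), weight_decay=1e-4)
      loss_fn = nn.CrossEntropyLoss()
      model.train()
      for epoch in range(1, num_epochs + 1):
        for step, (data, target) in enumerate(train_device_loader):
          optimizer.zero_grad()
          loss = loss_fn(model(data), target)
          loss.backward()
          xm.optimizer_step(optimizer)  # the spot a variation would override
          if step % 10 == 0:
            xm.master_print(f'epoch {epoch} step {step} loss {loss}')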

self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
self.loss_fn = nn.CrossEntropyLoss()

def run_optimizer(self):

will-cromar (Collaborator):

Is the only reason we need a class here to override optimizer.step?

JackCaoG (Collaborator, Author):

yea


class TrainResNetBase():

def __init__(self):

will-cromar (Collaborator):

If this stays a class, use a dataclass here so we can update the batch size etc without having to write out the whole constructor.

Also, IMO, Python's constructors are ugly and I like to hide them as much as possible.
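
(For reference, a rough sketch of the dataclass idea; the field names and defaults below are illustrative, not from the PR:)

    # Hedged sketch of the dataclass suggestion; names and defaults are illustrative.
    from dataclasses import dataclass

    import torch.nn as nn
    import torch.optim as optim
    import torchvision

    @dataclass
    class TrainResNetBase:
      batch_size: int = 128
      num_epochs: int = 1

      def __post_init__(self):
        # The heavyweight setup still lives here, but simple knobs can now be
        # overridden without rewriting the constructor, e.g.
        # TrainResNetBase(batch_size=256).
        self.model = torchvision.models.resnet50()
        self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
        self.loss_fn = nn.CrossEntropyLoss()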

JackCaoG (Collaborator, Author):

I actually don't know how dataclass works, haha. If you don't mind, can you start a PR after this one to fix it?

self.run_optimizer()
tracker.add(self.batch_size)
if step % 10 == 0:
xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))

will-cromar (Collaborator):

This is a good opportunity to be the change we want to see in the world and use xla.sync() instead. There's no performance reason to use a step closure here, and I think the with xla.sync(): form is clearer.

will-cromar (Collaborator):

    with xla.sync():
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
    
    tracker.add(self.batch_size)
    if step % 10 == 0: print(...)

JackCaoG (Collaborator, Author):

Sure, is the sync change already merged?

JackCaoG (Collaborator, Author):

Actually, hmm, if we want to do that we might want a way to disable the mark_step in the loader; otherwise it will be a bit confusing when users turn on PT_XLA_DEBUG.

will-cromar (Collaborator):

I wouldn't worry about novice users enabling our hidden debug options 😛

JackCaoG (Collaborator, Author):

Let me think about this in a follow-up PR. Calling mark_step is not free; one of the optimizations for inference is to remove any additional mark_step that isn't required. If we were to add sync here, we'd want to somehow disable the mark_step in the loader.

examples/train_resnet_base.py (two resolved review comments, outdated)

for epoch in range(1, self.num_epochs + 1):
xm.master_print('Epoch {} train begin {}'.format(
epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
train_loop_fn(self.train_device_loader, epoch)

will-cromar (Collaborator):

Can you just inline the train loop function here?

JackCaoG (Collaborator, Author):

Hmm, dynamo works like torch.compile(trainer, backend="openxla"). For the normal case I can inline it.
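
(For readers unfamiliar with the dynamo path, a small illustrative sketch; the step_fn name and signature are assumptions, not code from this PR:)

    # torch_xla registers the "openxla" backend for torch.compile when imported.
    import torch
    import torch_xla  # noqa: F401  (importing torch_xla registers the backend)

    def step_fn(model, loss_fn, optimizer, data, target):
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      optimizer.step()
      return loss

    compiled_step = torch.compile(step_fn, backend="openxla")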

@@ -0,0 +1,72 @@
from torch_xla import runtime as xr

will-cromar (Collaborator):

Can you move this file to torch_xla/examples so it's included in our package? Then we can run examples like python -m torch_xla.examples.train_resnet_ddp directly from the installed package.

JackCaoG (Collaborator, Author):

Do we want to include them as part of our package, though?

will-cromar (Collaborator):

It would save users from having to clone the correct branch off of GitHub to run the examples. Since that's our typical smoke test to see if an environment is completely broken, I think it's worthwhile.

JackCaoG (Collaborator, Author):

Sure, I can move it.

JackCaoG (Collaborator, Author):

Oh wait... we depend on torchvision in this example, and I don't want to force that dependency into our package. Let's pause this move until we figure out how to deal with that.

JackCaoG (Collaborator, Author):

My thinking is that I want to make it really clear how simple it is to enable certain features (like xla_ddp or profiling). Having them in separate files makes it very clear what code a user needs to add to enable those features.
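
(To illustrate that pattern, a hedged sketch of what a per-feature variant file such as a DDP one could look like; the details below are assumptions, not the merged code:)

    # Hypothetical sketch: enabling DDP by layering on top of the base trainer.
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    import torch_xla.distributed.xla_backend  # registers the "xla" process group backend
    from train_resnet_base import TrainResNetBase

    class TrainResNetDDP(TrainResNetBase):

      def __init__(self):
        dist.init_process_group('xla', init_method='xla://')
        super().__init__()
        # Wrapping the model is the only extra code needed for DDP.
        self.model = DDP(self.model, gradient_as_bucket_view=True)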

JackCaoG requested a review from will-cromar (May 13, 2024 23:37)
JackCaoG merged commit ae63cd1 into master (May 14, 2024)
20 checks passed