Add docs on how to integrate with dynamo #6459

Merged 3 commits on Feb 6, 2024
1 change: 1 addition & 0 deletions CONTRIBUTING.md
@@ -107,6 +107,7 @@ Then run `test/run_tests.sh` and `test/cpp/run_tests.sh` to verify the setup is
### Useful materials
1. [OP Lowering Guide](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md)
2. [CODEGEN MIGRATION GUIDE](https://github.com/pytorch/xla/blob/master/CODEGEN_MIGRATION_GUIDE.md)
3. [Dynamo Integration Guide](https://github.com/pytorch/xla/blob/master/docs/dynamo.md)

### Sharp Edges

4 changes: 4 additions & 0 deletions TROUBLESHOOTING.md
@@ -58,6 +58,10 @@ report sent to us if you have it.

You can enable the PyTorch/XLA debugging tool by setting `PT_XLA_DEBUG=1`, which provides a couple of useful debugging features.

## PyTorch/XLA + Dynamo Debugging Tool

You can enable the PyTorch/XLA + Dynamo debugging tool by setting `XLA_DYNAMO_DEBUG=1`, which provides a couple of useful debugging features.
Collaborator comment:
(Nit-picking)
I see that this is copied from above, but without mentioning those "couple [of] useful debugging features", the statement is redundant.
I suggest we drop everything after the comma (, which provides...).

Collaborator reply:

Oh I can add a suggestion, let me do that.

changm marked this conversation as resolved.
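A minimal sketch of enabling these debugging flags, assuming the environment variables are read when `torch_xla` initializes (setting them in the shell before launching the script works equally well):

```python
import os

# Assumption: set the flags before torch_xla is imported so they are
# picked up when the runtime initializes.
os.environ["PT_XLA_DEBUG"] = "1"        # general PyTorch/XLA debugging output
os.environ["XLA_DYNAMO_DEBUG"] = "1"    # Dynamo-specific debugging output

import torch
import torch_xla.core.xla_model as xm

# Run a trivial computation so the debugging output has something to report.
x = torch.randn(4, device=xm.xla_device())
print((x + 1).cpu())
```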

### Perform An Auto-Metrics Analysis

The debugging tool will analyze the metrics report and provide a summary. Some example output would be
Expand Down
27 changes: 25 additions & 2 deletions docs/dynamo.md
@@ -1,9 +1,32 @@
## TorchDynamo (torch.compile) integration in PyTorch XLA

[TorchDynamo](https://pytorch.org/docs/stable/dynamo/index.html) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in and its biggest feature is to dynamically modify Python bytecode right before it is executed. In the pytorch/xla 2.0 release, PyTorch/XLA provided an experimental backend for the TorchDynamo for both inference and training.
[TorchDynamo](https://pytorch.org/docs/stable/torch.compiler.html) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in, and its biggest feature is to dynamically modify Python bytecode right before it is executed. In the PyTorch/XLA 2.0 release, PyTorch/XLA provided an experimental backend for TorchDynamo for both inference and training.

The XLA bridge works as follows: when Dynamo recognizes a model pattern, it provides a TorchFX graph, and PyTorch/XLA uses its existing Lazy Tensor technology to compile that FX graph and return the compiled function.
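
The contract described above can be illustrated with a toy custom backend; this is not the actual PyTorch/XLA bridge, just a minimal sketch of how Dynamo hands a TorchFX graph to a backend and expects a compiled callable in return:

```python
import torch

# Toy Dynamo backend: receives a TorchFX GraphModule plus example inputs
# and returns a callable. The real 'openxla' backend compiles the FX graph
# through the Lazy Tensor machinery instead of falling back to eager.
def toy_backend(gm: torch.fx.GraphModule, example_inputs):
  gm.graph.print_tabular()  # inspect the captured FX graph
  return gm.forward         # fall back to eager execution

def fn(x):
  return torch.sin(x) + 1

compiled_fn = torch.compile(fn, backend=toy_backend)
print(compiled_fn(torch.randn(4)))
```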

### Integration

PyTorch/XLA currently integrates with Dynamo by passing the `backend='openxla'` argument to `torch.compile`. For example:

```python
import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
```

Currently there are two different backends, which will eventually be merged into a single 'openxla' backend (a short sketch contrasting them follows this list):

* `backend='openxla'` - Useful for training.
* `backend='openxla_eval'` - Useful for inference.
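
A minimal sketch contrasting the two backends; the helper function and tensor shapes here are illustrative assumptions, not part of the original document:

```python
import torch
import torch_xla.core.xla_model as xm

def matmul(a, b):
  return a @ b

device = xm.xla_device()
a = torch.randn(128, 128, device=device)
b = torch.randn(128, 128, device=device)

# 'openxla' goes through aot-autograd, so wrap pure inference in
# torch.no_grad() to avoid saving state for a backward pass.
train_capable = torch.compile(matmul, backend='openxla')
with torch.no_grad():
  out_train_backend = train_capable(a, b)

# 'openxla_eval' skips aot-autograd entirely and only supports inference,
# so no torch.no_grad() context is needed.
eval_only = torch.compile(matmul, backend='openxla_eval')
out_eval_backend = eval_only(a, b)
```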


### Inference
Here is a small code example of running resnet18 with `torch.compile`
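
The resnet18 example itself is collapsed in this diff view. As a rough stand-in, a minimal inference call might look like the following sketch; the model constructor, input shape, and use of `torchvision` are assumptions and may differ from the collapsed example:

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torchvision.models.resnet18().eval().to(device)
compiled_model = torch.compile(model, backend='openxla')

# Inference with the training-capable backend, so wrap in torch.no_grad().
with torch.no_grad():
  sample = torch.randn(1, 3, 224, 224).to(device)
  print(compiled_model(sample).shape)
```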

@@ -40,7 +63,7 @@ timm_vision_transformer | 3.52
geomean | 3.04

Note
1. User will likely see better inference perfomrance by putting the inference execution in a `torch.no_grad` context. `openxla` is a `aot-autograd` backend of `torch.compile`. `Aot-autograd` will attempt to save some states for potential backward. `torch.no_grad` will help `aot-autograd` understand that it is being executed in a inference context.
1. Users will likely see better inference performance by putting the inference execution in a `torch.no_grad` context. `openxla` is an `aot-autograd` backend of `torch.compile`; `aot-autograd` will attempt to save some state for a potential backward pass, and `torch.no_grad` helps `aot-autograd` understand that it is being executed in an inference context.
changm marked this conversation as resolved.
2. Users can also use the `openxla_eval` backend directly without `torch.no_grad`, since `openxla_eval` is not an `aot-autograd` backend and only works for inference.

### Training