Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add readme to call a model (lost due to merge conflicts) #6986

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions experimental/torch_xla2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,96 @@ pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.ht
pip install -r test-requirements.txt
pytest test
```

## Run a model

Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model
it can be in theory any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

This model `m` contains 2 parts: the weights that is stored inside of the model
and it's submodules (`nn.Linear`).

To execute this model with `torch_xla2`; we need to move the tensors involved in compute
to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`.

We need move both the weights and the input to xla devices:

```python
import torch_xla2
from torch.utils import _pytree as pytree
from torch_xla2.tensor import move_to_device

inputs = move_to_device(inputs)
new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict())
m.load_state_dict(new_state_dict, assign=True)

res = m(inputs)

print(type(res)) # outputs XLATensor2
```

### Executing with jax.jit

The above script will execute the model using eager mode Jax as backend. This
does allow executing torch models on TPU, but is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array
and returns jax array) into the same function, but faster.

We have made the `jax_jit` decorator that would accomplish the same with functions
that takes and returns `torch.Tensor`. To use this, the first step is to create
a functional version of this model: this means the parameters should be passed in
as input instead of being attributes on class:


```python

def model_func(param, inputs):
return torch.func.functional_call(m, param, inputs)

```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model
weights with `param`, then call the model. This is equivalent to:

```python
def model_func(param, inputs):
m.load_state_dict(param)
return m(*inputs)
```

Now, we can apply `jax_jit`

```python
from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```
Loading