
Commit

Add examples for training (#6929)
qihqi authored Apr 17, 2024
1 parent aee1d37 commit 781ee93
Showing 10 changed files with 819 additions and 3 deletions.
96 changes: 95 additions & 1 deletion experimental/torch_xla2/README.md
@@ -71,4 +71,98 @@ pip install -e .
```bash
pip install -r test_requirements.txt
pytest test
```


## Run a model

Now let's execute a model under torch_xla2. We'll start with a small feed-forward model;
in theory it can be any instance of `torch.nn.Module`.

```python

import torch
import torch.nn.functional as F
import torch_xla2
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch
inputs = (torch.randn(3, 3, 28, 28), )
print(m(*inputs))
```

This model `m` contains two parts: the weights, stored in the model and its
submodules (`nn.Linear`), and the computation defined in `forward`.

To execute this model with `torch_xla2`, we need to move the tensors involved in the
computation to XLA devices. This can be accomplished with `torch_xla2.tensor.move_to_device`.

We need to move both the weights and the inputs to XLA devices:

```python
from torch.utils import _pytree as pytree
from torch_xla2.tensor import move_to_device

inputs = move_to_device(inputs)
new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict())
m.load_state_dict(new_state_dict, assign=True)

res = m(*inputs)

print(type(res)) # outputs XLATensor2
```
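
As an aside (our note, using a helper that appears in the `_diffusion.py` example later in this commit): an `XLATensor2` wraps a JAX array, and `torch_xla2.tensor.unwrap` recovers the underlying array, which is handy when handing results to JAX code:

```python
import torch_xla2.tensor

jax_array = torch_xla2.tensor.unwrap(res)  # the underlying jax array
```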

### Executing with jax.jit

The above script executes the model using eager-mode JAX as the backend. This
already allows executing torch models on TPU, but it is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that transforms a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) into a compiled version of the same function that runs faster.
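
For intuition, here is a minimal, self-contained illustration of `jax.jit` on a plain JAX function (this snippet is our addition, not from the original README):

```python
import jax
import jax.numpy as jnp

def f(x):
    # A tiny elementwise computation; jit traces and compiles it
    # on the first call for each input shape/dtype, then caches it.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f_fast = jax.jit(f)
print(f_fast(jnp.arange(3.0)))  # ~[1., 1., 1.]
```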

We provide a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, the first step is to create
a functional version of the model: this means the parameters should be passed in
as inputs instead of being attributes of the class:


```python

def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model
weights with `param`, then call the model. This is equivalent to:

```python
def model_func(param, inputs):
    m.load_state_dict(param)
    return m(*inputs)
```

Now we can apply `jax_jit`:

```python
from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```
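
As a sketch of where this leads (the names `loss_func` and `labels` are illustrative, not part of the original tutorial), the same decorator can wrap an entire loss computation, which is the kind of function you would differentiate during training:

```python
import torch

def loss_func(param, inputs, labels):
    # Functional call as above, followed by a standard torch loss.
    logits = torch.func.functional_call(m, param, inputs)
    return torch.nn.functional.cross_entropy(logits, labels)

loss_jitted = jax_jit(loss_func)
# Like the weights and inputs, `labels` must first be moved to XLA devices.
```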


115 changes: 115 additions & 0 deletions experimental/torch_xla2/examples/README.md
@@ -0,0 +1,115 @@
## Intro

This readme has a subsection for every example `*.py` file.

Please follow the instructions in [README.md](../README.md) to install torch_xla2,
then install the requirements for all of the examples with

```bash
pip install -r requirements.txt
```



## basic_training.py

This file was constructed by first copying code fragments from this PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then we added a few lines of code to move `torch.Tensor`s onto XLA devices.

Example:

```python
state_dict = pytree.tree_map_only(torch.Tensor,
    torch_xla2.tensor.move_to_device, state_dict)
```

This fragment moves the state_dict to XLA devices; the state_dict is then passed
back to the model via `load_state_dict`, as shown in the sketch below.
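
A minimal sketch of that hand-off (assuming `model` is the tutorial's model object):

```python
# assign=True keeps the moved XLA tensors as the parameters
# instead of copying their values into the existing CPU parameters.
model.load_state_dict(state_dict, assign=True)
```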

Then you can train the model. This shows the minimum changes needed to train a model on XLA
devices. The performance is not optimal because we don't use `jax.jit`; this is intentional,
as the example is meant to showcase the minimum code change.

Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training.py
Training set has 60000 instances
Validation set has 10000 instances
Bag Dress Sneaker T-shirt/top
tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035,
         0.2279],
        [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561,
         0.5985],
        [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504,
         0.8738],
        [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086,
         0.5709]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.325265645980835
EPOCH 1:
batch 1000 loss: 1.041275198560208
batch 2000 loss: 0.6450189483696595
batch 3000 loss: 0.5793989677671343
batch 4000 loss: 0.5170258888280951
batch 5000 loss: 0.4920090722264722
batch 6000 loss: 0.48910293977567926
batch 7000 loss: 0.48058812761632724
batch 8000 loss: 0.47159107415075413
batch 9000 loss: 0.4712311488997657
batch 10000 loss: 0.4675815168160479
batch 11000 loss: 0.43210567891132085
batch 12000 loss: 0.445208148030797
batch 13000 loss: 0.4119230824254337
batch 14000 loss: 0.4190662656680215
batch 15000 loss: 0.4094535468676477
LOSS train 0.4094535468676477 valid XLA
```

## basic_training_jax.py

This file was constructed by first copying code fragments from this PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then we replaced the torch optimizer with an `optax` optimizer, and used `jax.grad`
instead of `torch.Tensor.backward()` to compute gradients.

Then you can train the model with a training loop from the JAX ecosystem, as sketched
below. This is meant to showcase how easy it is to integrate with JAX.
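
A hedged sketch of that pattern (our illustration: `params` is assumed to be a pytree of JAX arrays, and `jax_loss_fn` a pure loss function obtained by wrapping the torch model, e.g. with `torch.func.functional_call` plus torch_xla2's view helpers):

```python
import jax
import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, inputs, labels):
    # Differentiate the pure loss function with respect to params.
    loss, grads = jax.value_and_grad(jax_loss_fn)(params, inputs, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```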

Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py
Training set has 60000 instances
Validation set has 10000 instances
Pullover Ankle Boot Pullover Ankle Boot
tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872,
         0.1854],
        [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010,
         0.2766],
        [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607,
         0.6450],
        [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377,
         0.7453]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4054245948791504
EPOCH 1:
batch 1000 loss: 1.0705260595591972
batch 2000 loss: 1.0997755021179327
batch 3000 loss: 1.0186579653513108
batch 4000 loss: 0.9090727646966116
batch 5000 loss: 0.8309370622411024
batch 6000 loss: 0.8702225417760783
batch 7000 loss: 0.8750176187023462
batch 8000 loss: 0.9652624803795453
batch 9000 loss: 0.8688667197711766
batch 10000 loss: 0.8021814124770199
batch 11000 loss: 0.8000540231048071
batch 12000 loss: 0.9150884484921057
batch 13000 loss: 0.819690621060171
batch 14000 loss: 0.8569030471532278
batch 15000 loss: 0.8740896808278603
LOSS train 0.8740896808278603 valid 2.3132264614105225
```
112 changes: 112 additions & 0 deletions experimental/torch_xla2/examples/_diffusion.py
@@ -0,0 +1,112 @@
import functools

import torch
from time import time
from diffusers import DiffusionPipeline
from torch.utils import _pytree as pytree

import torch_xla2
import torch_xla2.functions
from torch_xla2.extra import torch_view, jax_view

import jax
import torch.func


class CompiledModule:

    def __init__(self, model):
        # Move all weights (parameters and buffers) to XLA devices up front.
        weights = model.state_dict()
        weights.update(model.named_parameters())
        self._weights = pytree.tree_map_only(
            torch.Tensor, torch_xla2.tensor.move_to_device, weights)
        self._model = model

        self._func_jitted_torch = None  # lazily built on first forward()

    def _maybe_move_tensor(self, tensor):
        # Move plain torch.Tensors to XLA devices; leave XLATensor2 and non-tensors alone.
        if isinstance(tensor, torch.Tensor) and not isinstance(
                tensor, torch_xla2.tensor.XLATensor2):
            return torch_xla2.tensor.move_to_device(tensor)
        return tensor

    def _make_jitted(self, args, kwargs):
        # Non-tensor arguments must be marked static for jax.jit.
        static = []
        for i, a in enumerate(args):
            if not isinstance(a, torch.Tensor):
                static.append(i + 1)  # weight is 0
        static_argnames = []
        for k, v in kwargs.items():
            if not isinstance(v, torch.Tensor):
                static_argnames.append(k)

        def f(weights, *args, **kwargs):
            weights, args, kwargs = torch_xla2.tensor.wrap((weights, args, kwargs))
            with torch_xla2.functions.XLAFunctionMode(), torch_xla2.tensor.XLADispatchMode():
                res = torch.func.functional_call(self._model, weights, args, kwargs)
                if isinstance(res, tuple) and len(res) == 1:
                    res = res[0]
            return torch_xla2.tensor.unwrap(res)

        # Note: the original passed only static_argnames; the `static` list of
        # positional indices was computed but unused, so we pass it as static_argnums.
        fjit = jax.jit(
            f, static_argnums=tuple(static), static_argnames=tuple(static_argnames))
        return torch_view(fjit)

    def forward(self, *args, **kwargs):
        (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs))
        if self._func_jitted_torch is None:
            self._func_jitted_torch = self._make_jitted(args, kwargs)
        return self._func_jitted_torch(self._weights, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __getattr__(self, key):
        # Delegate everything else (config, dtype, etc.) to the wrapped model.
        return getattr(self._model, key)


def compile_pipe(pipe):
    pipe.text_encoder = CompiledModule(pipe.text_encoder)
    pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2)
    pipe.unet = CompiledModule(pipe.unet)
    pipe.vae = CompiledModule(pipe.vae)


def main():
    pipe = DiffusionPipeline.from_pretrained(
        # "stabilityai/stable-diffusion-xl-base-0.9",
        "stabilityai/stable-diffusion-xl-base-1.0",
        use_safetensors=True,
    )
    compile_pipe(pipe)

    global_bs = 10
    inference_steps = 20
    resol = 1024
    prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs
    print(f'global batch size {global_bs}',
          f'inference steps {inference_steps}',
          f'Image resolution {resol}',
          flush=True)

    iters = 5
    for i in range(iters):
        prompt = prompts
        # print('per device prompts len', len(prompt))
        # prompt = prompts[rank]
        start = time()
        image = pipe(prompt,
                     num_inference_steps=inference_steps,
                     height=resol,
                     width=resol).images[0]
        print(f'Step {i} inference time {time()-start} sec', flush=True)


if __name__ == '__main__':
    main()