From b0276da4e7d71e5488bd94d505b61248ed88a1ec Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 26 Apr 2024 17:32:01 -0700 Subject: [PATCH] Add readme for call a model (lost due to merge) --- experimental/torch_xla2/README.md | 93 +++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index fba08f404984..594d53808821 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -60,3 +60,96 @@ pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.ht pip install -r test-requirements.txt pytest test ``` + +## Run a model + +Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model +it can be in theory any instance of `torch.nn.Module`. + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +m = MyModel() + +# Execute this model using torch +inputs = torch.randn(3, 3, 28, 28) +print(m(inputs)) +``` + +This model `m` contains 2 parts: the weights that is stored inside of the model +and it's submodules (`nn.Linear`). + +To execute this model with `torch_xla2`; we need to move the tensors involved in compute +to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`. + +We need move both the weights and the input to xla devices: + +```python +import torch_xla2 +from torch.utils import _pytree as pytree +from torch_xla2.tensor import move_to_device + +inputs = move_to_device(inputs) +new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict()) +m.load_state_dict(new_state_dict, assign=True) + +res = m(inputs) + +print(type(res)) # outputs XLATensor2 +``` + +### Executing with jax.jit + +The above script will execute the model using eager mode Jax as backend. This +does allow executing torch models on TPU, but is often slower than what we can +achieve with `jax.jit`. + +`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array +and returns jax array) into the same function, but faster. + +We have made the `jax_jit` decorator that would accomplish the same with functions +that takes and returns `torch.Tensor`. To use this, the first step is to create +a functional version of this model: this means the parameters should be passed in +as input instead of being attributes on class: + + +```python + +def model_func(param, inputs): + return torch.func.functional_call(m, param, inputs) + +``` +Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html) +from PyTorch to replace the model +weights with `param`, then call the model. This is equivalent to: + +```python +def model_func(param, inputs): + m.load_state_dict(param) + return m(*inputs) +``` + +Now, we can apply `jax_jit` + +```python +from torch_xla2.extra import jax_jit +model_func_jitted = jax_jit(model_func) +print(model_func_jitted(new_state_dict, inputs)) +```