Commit

[torch_xla2] Simplify developer setup steps (#6905)
Co-authored-by: qihqi <[email protected]>
will-cromar and qihqi authored Apr 24, 2024
1 parent 6fd448d commit 89efd17
Showing 6 changed files with 48 additions and 163 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/torch_xla2.yml
@@ -34,10 +34,8 @@ jobs:
shell: bash
working-directory: experimental/torch_xla2
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r test_requirements.txt
pip install -e .
pip install -r test-requirements.txt
pip install -e .[cpu]
- name: Run tests
working-directory: experimental/torch_xla2
shell: bash
156 changes: 24 additions & 132 deletions experimental/torch_xla2/README.md
@@ -4,7 +4,8 @@

Currently this is only source-installable. Requires Python version >= 3.10.

### NOTE:
### NOTE:

Please don't install torch-xla from instructions in
https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
In particular, the following are not needed:
@@ -18,153 +19,44 @@ TorchXLA2 and torch-xla have different installation instructions, please follow
the instructions below from scratch (fresh venv / conda environment.)


### 1. Install dependencies

#### 1.0 (optional) Make a virtualenv / conda env, and activate it.

```bash
conda create --name <your_name> python=3.10
conda activate <your_name>
```
Or,
```bash
python -m venv my_venv
source my_venv/bin/activate
```

#### 1.1 Install torch CPU, even if your device has GPU or TPU:
### 1. Installing `torch_xla2`

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```
#### 1.0 (recommended) Make a virtualenv / conda env

Or, follow official instructions in [pytorch.org](https://pytorch.org/get-started/locally/) to install for your OS.
If you are using VSCode, then [you can create a new environment from
UI](https://code.visualstudio.com/docs/python/environments). Select the
`dev-requirements.txt` when asked to install project dependencies.

#### 1.2 Install Jax for either GPU or TPU
Otherwise create a new environment from the command line.

If you are using Google Cloud TPU, then
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
# Option 1: venv
python -m venv my_venv
source my_venv/bin/activate

If you are using a machine with NVidia GPU:
# Option 2: conda
conda create --name <your_name> python=3.10
conda activate <your_name>

```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Either way, install the dev requirements.
pip install -r dev-requirements.txt
```

If you are using a CPU-only machine:
```bash
pip install --upgrade "jax[cpu]"
```
Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.

Or, follow the official instructions in https://jax.readthedocs.io/en/latest/installation.html to install for your OS or Device.
#### 1.1 Install this package

#### 1.3 Install this package
Install `torch_xla2` from source for your platform:

```bash
pip install -e .
pip install -e .[cpu]
pip install -e .[cuda]
pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
```

#### 1.4 (optional) verify installation by running tests
#### 1.2 (optional) verify installation by running tests

```bash
pip install -r test_requirements.txt
pip install -r test-requirements.txt
pytest test
```


## Run a model

Now let's execute a model under torch_xla2. We'll start with a simple model with three
linear layers; in theory, it can be any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

This model `m` contains 2 parts: the weights that are stored inside of the model
and its submodules (`nn.Linear`).

To execute this model with `torch_xla2`, we need to move the tensors involved in the
computation to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`.

We need to move both the weights and the input to XLA devices:

```python
import torch_xla2
from torch.utils import _pytree as pytree
from torch_xla2.tensor import move_to_device

inputs = move_to_device(inputs)
new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict())
m.load_state_dict(new_state_dict, assign=True)

res = m(inputs)

print(type(res)) # outputs XLATensor2
```

### Executing with jax.jit

The above script will execute the model using eager mode Jax as backend. This
does allow executing torch models on TPU, but is often slower than what we can
achieve with `jax.jit`.

`jax.jit` takes a Jax function (i.e. a function that takes jax arrays
and returns jax arrays) and returns the same function, but faster.

We have made the `jax_jit` decorator, which accomplishes the same for functions
that take and return `torch.Tensor`. To use it, the first step is to create
a functional version of this model: the parameters should be passed in
as inputs instead of being attributes on the class:


```python

def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)

```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model
weights with `param`, then call the model. This is equivalent to:

```python
def model_func(param, inputs):
    m.load_state_dict(param)
    # `inputs` is a single tensor here, so it is passed as one argument.
    return m(inputs)
```

Now, we can apply `jax_jit`:

```python
from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```
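
The "functional version" idea described above can be illustrated without torch or torch_xla2. The following is a dependency-free sketch, not the library's implementation: `TinyLinear`, `swapped_params` and `call_with_params` are hypothetical names invented for this example. It only shows the pattern of passing parameters in as an argument and leaving the module's own state untouched afterwards.

```python
from contextlib import contextmanager


class TinyLinear:
    """Hypothetical stand-in for a stateful module computing y = w * x + b."""

    def __init__(self, w, b):
        self.params = {"w": w, "b": b}

    def __call__(self, x):
        return self.params["w"] * x + self.params["b"]


@contextmanager
def swapped_params(module, params):
    """Temporarily replace module.params and restore the original afterwards."""
    original = module.params
    module.params = params
    try:
        yield module
    finally:
        module.params = original


def call_with_params(module, params, x):
    """Call `module` as if its parameters were `params`, without mutating it."""
    with swapped_params(module, params) as swapped:
        return swapped(x)


m = TinyLinear(w=2.0, b=1.0)
result = call_with_params(m, {"w": 10.0, "b": 0.0}, 3.0)
print(result)  # 30.0, computed with the passed-in parameters
print(m(3.0))  # 7.0, the module's own parameters are unchanged
```

Because the parameters are an explicit argument, the function's output depends only on its inputs, which is the property a tracing compiler needs.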


10 changes: 2 additions & 8 deletions experimental/torch_xla2/dev-requirements.txt
@@ -1,9 +1,3 @@
absl-py==2.0.0
flatbuffers==23.5.26
jax==0.4.23
jaxlib==0.4.23
pytest
tensorflow
-f https://download.pytorch.org/whl/torch
torch==2.2.1+cpu
immutabledict
sentencepiece
ruff~=0.3.5
29 changes: 15 additions & 14 deletions experimental/torch_xla2/pyproject.toml
@@ -2,29 +2,30 @@
requires = ["hatchling"]
build-backend = "hatchling.build"


[project]
version = "0.0.1"
name = "torch_xla2"
dependencies = [
"absl-py",
"flatbuffers",
"immutabledict",
"jax>=0.4.24",
"pytest",
"tensorflow",

# Note: Exclude these because otherwise on pip install .
# pip will install libs from pypi which is the GPU version
# of these libs.
# We most likely need CPU version of torch and TPU version of
# jax. So it's best for users to install them by hand
# See more at README.md
# "jax>=0.4.24",
# "jaxlib>=0.4.24",
# "torch",
"tensorflow-cpu",
# Developers should install `dev-requirements.txt` first
"torch>=2.2.1",
]

requires-python = ">=3.10"
license = {file = "LICENSE"}

[project.optional-dependencies]
cpu = ["jax[cpu]"]
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
tpu = ["jax[tpu]"]
cuda = ["jax[cuda12]"]

[tool.pytest.ini_options]
addopts="-n auto"

[tool.ruff]
line-length = 80
indent-width = 2
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test-requirements.txt
@@ -0,0 +1,5 @@
-r dev-requirements.txt
pytest
pytest-xdist
sentencepiece
expecttest
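
The `-r dev-requirements.txt` line is why installing `test-requirements.txt` alone also pulls in the dev requirements. The sketch below is a simplified model of that include mechanism, not pip's actual code: `flatten_requirements` is a hypothetical helper, it only handles the `-r <file>` form, and the post-commit contents of `dev-requirements.txt` are an assumption based on the `+1,3` hunk header above.

```python
import tempfile
from pathlib import Path


def flatten_requirements(path: Path) -> list[str]:
    """Expand `-r other.txt` includes relative to the including file."""
    lines = []
    for raw in path.read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and comments
        if line.startswith("-r "):
            lines.extend(flatten_requirements(path.parent / line[3:].strip()))
        else:
            lines.append(line)
    return lines


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Assumed post-commit contents of the two files.
    (root / "dev-requirements.txt").write_text(
        "-f https://download.pytorch.org/whl/torch\ntorch==2.2.1+cpu\nruff~=0.3.5\n"
    )
    (root / "test-requirements.txt").write_text(
        "-r dev-requirements.txt\npytest\npytest-xdist\nsentencepiece\nexpecttest\n"
    )
    flat = flatten_requirements(root / "test-requirements.txt")

print(flat)
```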
5 changes: 0 additions & 5 deletions experimental/torch_xla2/test_requirements.txt

This file was deleted.
