Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch_xla2] Simplify developer setup steps #6905

Merged
merged 5 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .github/workflows/torch_xla2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ jobs:
shell: bash
working-directory: experimental/torch_xla2
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r test_requirements.txt
pip install -e .
pip install -r test-requirements.txt
pip install -e .[cpu]
- name: Run tests
working-directory: experimental/torch_xla2
shell: bash
Expand Down
62 changes: 25 additions & 37 deletions experimental/torch_xla2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

Currently this is only source-installable. Requires Python version >= 3.10.

### NOTE:
### NOTE:

Please don't install torch-xla from instructions in
https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
In particular, the following are not needed:
Expand All @@ -18,57 +19,44 @@ TorchXLA2 and torch-xla have different installation instructions, please follow
the instructions below from scratch (fresh venv / conda environment.)


### 1. Install dependencies
### 1. Installing `torch_xla2`

#### 1.0 (optional) Make a virtualenv / conda env, and activate it.
#### 1.0 (recommended) Make a virtualenv / conda env

```bash
conda create --name <your_name> python=3.10
conda activate <your_name>
```
Or,
```bash
python -m venv create my_venv
source my_venv/bin/activate
```
If you are using VSCode, then [you can create a new environment from
UI](https://code.visualstudio.com/docs/python/environments). Select the
`dev-requirements.txt` when asked to install project dependencies.

#### 1.1 Install torch CPU, even if your device has GPU or TPU:
Otherwise create a new environment from the command line.

```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

Or, follow official instructions in [pytorch.org](https://pytorch.org/get-started/locally/) to install for your OS.

#### 1.2 Install Jax for either GPU or TPU

If you are using Google Cloud TPU, then
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
# Option 1: venv
python -m venv create my_venv
source my_venv/bin/activate

If you are using a machine with NVidia GPU:
# Option 2: conda
conda create --name <your_name> python=3.10
conda activate <your_name>

```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Either way, install the dev requirements.
pip install -r dev-requirements.txt
```

If you are using a CPU-only machine:
```bash
pip install --upgrade "jax[cpu]"
```
Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.

Or, follow the official instructions in https://jax.readthedocs.io/en/latest/installation.html to install for your OS or Device.
#### 1.1 Install this package

#### 1.3 Install this package
Install `torch_xla2` from source for your platform:

```bash
pip install -e .
pip install -e .[cpu]
pip install -e .[cuda]
pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
```

#### 1.4 (optional) verify installation by running tests
#### 1.2 (optional) verify installation by running tests

```bash
pip install -r test_requirements.txt
pip install -r test-requirements.txt
pytest test
```
```
10 changes: 2 additions & 8 deletions experimental/torch_xla2/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
absl-py==2.0.0
flatbuffers==23.5.26
jax==0.4.23
jaxlib==0.4.23
pytest
tensorflow
-f https://download.pytorch.org/whl/torch
torch==2.2.1+cpu
immutabledict
sentencepiece
ruff~=0.3.5
31 changes: 16 additions & 15 deletions experimental/torch_xla2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@
requires = ["hatchling"]
build-backend = "hatchling.build"


[project]
version = "0.0.1"
name = "torch_xla2"
dependencies = [
"absl-py",
"flatbuffers",
"immutabledict",
"jax>=0.4.24",
"pytest",
"tensorflow",

# Note: Exclude these because otherwise on pip install .
# pip will install libs from pypi which is the GPU version
# of these libs.
# We most likely need CPU version of torch and TPU version of
# jax. So it's best for users to install them by hand
# See more at README.md
# "jax>=0.4.24",
# "jaxlib>=0.4.24",
# "torch",
"tensorflow-cpu",
# Developers should install `dev-requirements.txt` first
"torch>=2.2.1",
]
requires-python = ">=3.10"
license = {file = "LICENSE"}

[project.optional-dependencies]
cpu = ["jax[cpu]"]
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
tpu = ["jax[tpu]"]
cuda = ["jax[cuda12]"]

[tool.pytest.ini_options]
addopts="-n auto"

requires-python = ">=3.10"
license = {file = "LICENSE"}
[tool.ruff]
line-length = 80
indent-width = 2
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-r dev-requirements.txt
pytest
pytest-xdist
sentencepiece
expecttest
5 changes: 0 additions & 5 deletions experimental/torch_xla2/test_requirements.txt

This file was deleted.

Loading