update inst, remove some prints (#6866)
qihqi authored Apr 1, 2024
1 parent c1773b6 commit 88f93fe
Showing 6 changed files with 104 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/torch_xla2.yml
@@ -36,10 +36,10 @@ jobs:
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r dev-requirements.txt
pip install -r test_requirements.txt
pip install -e .
- name: Run tests
working-directory: experimental/torch_xla2
shell: bash
run: |
pytest test/
pytest test/
73 changes: 72 additions & 1 deletion experimental/torch_xla2/README.md
@@ -1,3 +1,74 @@
# torchxla2

This directory contains things that are in the top-level git repository
## Install

Currently this package is only installable from source, and it requires Python >= 3.10.

### NOTE:
Please don't install torch-xla following the instructions in
https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
In particular, the following steps are not needed:

* There is no need to build pytorch/pytorch from source.
* There is no need to clone the pytorch/xla project inside a pytorch/pytorch
  git checkout.


TorchXLA2 and torch-xla have different installation instructions; please follow
the instructions below from scratch (in a fresh venv or conda environment).


### 1. Install dependencies

#### 1.0 (optional) Make a virtualenv / conda env, and activate it.

```bash
conda create --name <your_name> python=3.10
conda activate <your_name>
```
Or,
```bash
python -m venv my_venv
source my_venv/bin/activate
```

#### 1.1 Install the CPU build of torch, even if your device has a GPU or TPU:

```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

Or, follow the official instructions at [pytorch.org](https://pytorch.org/get-started/locally/) to install the right build for your OS.

#### 1.2 Install JAX for your device (TPU, GPU, or CPU)

If you are using a Google Cloud TPU:
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

If you are using a machine with an NVIDIA GPU:

```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If you are using a CPU-only machine:
```bash
pip install --upgrade "jax[cpu]"
```

Or, follow the official instructions at https://jax.readthedocs.io/en/latest/installation.html to install the right build for your OS or device.

#### 1.3 Install this package

```bash
pip install -e .
```

#### 1.4 (optional) verify installation by running tests

```bash
pip install -r test_requirements.txt
pytest test
```
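A quick way to sanity-check the editable install (a minimal sketch, assuming the steps above succeeded; `__file__` is standard Python and should point back into this source checkout rather than a wheel from PyPI):

```python
# Smoke test after `pip install -e .`: the import should resolve to the
# local editable checkout rather than an installed wheel.
import torch_xla2

print("torch_xla2 imported from:", torch_xla2.__file__)
```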
16 changes: 13 additions & 3 deletions experimental/torch_xla2/pyproject.toml
@@ -9,12 +9,22 @@ name = "torch_xla2"
dependencies = [
"absl-py",
"flatbuffers",
"jax>=0.4.24",
"jaxlib>=0.4.24",
"pytest",
"tensorflow",
"torch",

# Note: jax, jaxlib and torch are excluded here because `pip install .`
# would otherwise pull the default PyPI wheels, which are the GPU builds
# of these libs.
# We most likely want the CPU build of torch and the TPU build of jax,
# so it's best for users to install them by hand.
# See README.md for details.
# "jax>=0.4.24",
# "jaxlib>=0.4.24",
# "torch",
]

[tool.pytest.ini_options]
addopts="-n auto"

requires-python = ">=3.10"
license = {file = "LICENSE"}
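Because the torch and jax wheels are left for users to install by hand, it can be worth confirming which builds were actually picked up. A minimal check, assuming torch and jax were installed per the README; `torch.version.cuda` and `jax.devices()` are the standard attributes/functions for this:

```python
# Confirm that the hand-installed builds match the intent: CPU torch,
# and jax with whichever backend (TPU/GPU/CPU) was selected.
import torch
import jax

print("torch CUDA build:", torch.version.cuda)  # None for the CPU-only wheel
print("jax devices:", jax.devices())            # lists TPU, GPU, or CPU devices
```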
7 changes: 0 additions & 7 deletions experimental/torch_xla2/requirements.txt

This file was deleted.

4 changes: 4 additions & 0 deletions experimental/torch_xla2/test_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pytest
immutabledict
sentencepiece
pytest-xdist
19 changes: 13 additions & 6 deletions experimental/torch_xla2/torch_xla2/decompositions.py
@@ -25,9 +25,12 @@

aten = torch._ops.ops.aten

@register_decomposition(aten.reflection_pad1d)
@register_decomposition(aten.reflection_pad2d)
@register_decomposition(aten.reflection_pad3d)
def _try_register(op, impl):
    try:
        register_decomposition(op)(impl)
    except:
        pass

@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
@@ -40,9 +43,10 @@ def idx(left, middle, right):
        idx,
    )

_try_register(aten.reflection_pad1d, _reflection_pad)
_try_register(aten.reflection_pad2d, _reflection_pad)
_try_register(aten.reflection_pad3d, _reflection_pad)

@register_decomposition(aten.replication_pad1d)
@register_decomposition(aten.replication_pad3d)
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
@@ -83,4 +87,7 @@ def _reflection_or_replication_pad(
    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(result)
    result = result.contiguous(memory_format=memory_format)
    return result
    return result

_try_register(aten.replication_pad1d, _replication_pad)
_try_register(aten.replication_pad3d, _replication_pad)
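For context, a decomposition registered through `register_decomposition` lands in PyTorch's global decomposition table and can be looked up with `torch._decomp.get_decompositions`. A minimal sketch of such a lookup: it returns an entry whether the registration above succeeded or the op already had an upstream decomposition, which is exactly the case the bare `except` tolerates.

```python
# Look up decomposition table entries for the padding ops handled above.
import torch
from torch._decomp import get_decompositions

aten = torch._ops.ops.aten

table = get_decompositions([aten.reflection_pad2d, aten.replication_pad1d])
print("registered decompositions:", list(table.keys()))
```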
