diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml
index 7c5a88bf430..441addad422 100644
--- a/.github/workflows/torch_xla2.yml
+++ b/.github/workflows/torch_xla2.yml
@@ -34,10 +34,8 @@ jobs:
         shell: bash
         working-directory: experimental/torch_xla2
         run: |
-          pip install pytest absl-py jax[cpu] flatbuffers tensorflow
-          pip install torch --index-url https://download.pytorch.org/whl/cpu
-          pip install -r test_requirements.txt
-          pip install -e .
+          pip install -r test-requirements.txt
+          pip install -e .[cpu]
       - name: Run tests
         working-directory: experimental/torch_xla2
         shell: bash
diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md
index 0dccde701d6..fba08f40498 100644
--- a/experimental/torch_xla2/README.md
+++ b/experimental/torch_xla2/README.md
@@ -4,7 +4,8 @@
 
 Currently this is only source-installable. Requires Python version >= 3.10.
 
-### NOTE:
+### NOTE:
+
 Please don't install torch-xla from instructions in
 https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
 In particular, the following are not needed:
@@ -18,153 +19,44 @@ TorchXLA2 and torch-xla have different installation instructions, please follow
 the instructions below from scratch (fresh venv / conda environment.)
 
-### 1. Install dependencies
-
-#### 1.0 (optional) Make a virtualenv / conda env, and activate it.
-
-```bash
-conda create --name python=3.10
-conda activate
-```
-Or,
-```bash
-python -m venv create my_venv
-source my_venv/bin/activate
-```
-
-#### 1.1 Install torch CPU, even if your device has GPU or TPU:
+### 1. Installing `torch_xla2`
 
-```bash
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-```
+#### 1.0 (recommended) Make a virtualenv / conda env
 
-Or, follow official instructions in [pytorch.org](https://pytorch.org/get-started/locally/) to install for your OS.
+If you are using VSCode, then [you can create a new environment from
+UI](https://code.visualstudio.com/docs/python/environments). Select the
+`dev-requirements.txt` when asked to install project dependencies.
 
-#### 1.2 Install Jax for either GPU or TPU
+Otherwise create a new environment from the command line.
 
-If you are using Google Cloud TPU, then
 ```bash
-pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-```
+# Option 1: venv
+python -m venv create my_venv
+source my_venv/bin/activate
 
-If you are using a machine with NVidia GPU:
+# Option 2: conda
+conda create --name python=3.10
+conda activate
 
-```bash
-pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+# Either way, install the dev requirements.
+pip install -r dev-requirements.txt
 ```
 
-If you are using a CPU-only machine:
-```bash
-pip install --upgrade "jax[cpu]"
-```
+Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.
 
-Or, follow the official instructions in https://jax.readthedocs.io/en/latest/installation.html to install for your OS or Device.
+#### 1.1 Install this package
 
-#### 1.3 Install this package
+Install `torch_xla2` from source for your platform:
 
 ```bash
-pip install -e .
+pip install -e .[cpu]
+pip install -e .[cuda]
+pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
 ```
 
-#### 1.4 (optional) verify installation by running tests
+#### 1.2 (optional) verify installation by running tests
 
 ```bash
-pip install -r test_requirements.txt
+pip install -r test-requirements.txt
 pytest test
 ```
-
-
-## Run a model
-
-Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model
-it can be in theory any instance of `torch.nn.Module`.
-
-```python
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class MyModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.fc1 = nn.Linear(28 * 28, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
-
-    def forward(self, x):
-        x = x.view(-1, 28 * 28)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
-
-m = MyModel()
-
-# Execute this model using torch
-inputs = torch.randn(3, 3, 28, 28)
-print(m(inputs))
-```
-
-This model `m` contains 2 parts: the weights that is stored inside of the model
-and it's submodules (`nn.Linear`).
-
-To execute this model with `torch_xla2`; we need to move the tensors involved in compute
-to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`.
-
-We need move both the weights and the input to xla devices:
-
-```python
-import torch_xla2
-from torch.utils import _pytree as pytree
-from torch_xla2.tensor import move_to_device
-
-inputs = move_to_device(inputs)
-new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict())
-m.load_state_dict(new_state_dict, assign=True)
-
-res = m(inputs)
-
-print(type(res)) # outputs XLATensor2
-```
-
-### Executing with jax.jit
-
-The above script will execute the model using eager mode Jax as backend. This
-does allow executing torch models on TPU, but is often slower than what we can
-achieve with `jax.jit`.
-
-`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array
-and returns jax array) into the same function, but faster.
-
-We have made the `jax_jit` decorator that would accomplish the same with functions
-that takes and returns `torch.Tensor`. To use this, the first step is to create
-a functional version of this model: this means the parameters should be passed in
-as input instead of being attributes on class:
-
-
-```python
-
-def model_func(param, inputs):
-    return torch.func.functional_call(m, param, inputs)
-
-```
-Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
-from PyTorch to replace the model
-weights with `param`, then call the model. This is equivalent to:
-
-```python
-def model_func(param, inputs):
-    m.load_state_dict(param)
-    return m(*inputs)
-```
-
-Now, we can apply `jax_jit`
-
-```python
-from torch_xla2.extra import jax_jit
-model_func_jitted = jax_jit(model_func)
-print(model_func_jitted(new_state_dict, inputs))
-```
-
-
diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt
index 4a32310fbda..004a1d71ad7 100644
--- a/experimental/torch_xla2/dev-requirements.txt
+++ b/experimental/torch_xla2/dev-requirements.txt
@@ -1,9 +1,3 @@
-absl-py==2.0.0
-flatbuffers==23.5.26
-jax==0.4.23
-jaxlib==0.4.23
-pytest
-tensorflow
+-f https://download.pytorch.org/whl/torch
 torch==2.2.1+cpu
-immutabledict
-sentencepiece
\ No newline at end of file
+ruff~=0.3.5
diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml
index d0d2a42dec8..0c2101dbcb9 100644
--- a/experimental/torch_xla2/pyproject.toml
+++ b/experimental/torch_xla2/pyproject.toml
@@ -2,29 +2,30 @@
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
-
 [project]
 version = "0.0.1"
 name = "torch_xla2"
 dependencies = [
     "absl-py",
-    "flatbuffers",
+    "immutabledict",
+    "jax>=0.4.24",
     "pytest",
-    "tensorflow",
-
-    # Note: Exclude these because otherwise on pip install .
-    # pip will install libs from pypi which is the GPU version
-    # of these libs.
-    # We most likely need CPU version of torch and TPU version of
-    # jax. So it's best for users to install them by hand
-    # See more at README.md
-    # "jax>=0.4.24",
-    # "jaxlib>=0.4.24",
-    # "torch",
+    "tensorflow-cpu",
+    # Developers should install `dev-requirements.txt` first
+    "torch>=2.2.1",
 ]
-
 requires-python = ">=3.10"
 license = {file = "LICENSE"}
 
+[project.optional-dependencies]
+cpu = ["jax[cpu]"]
+# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
+tpu = ["jax[tpu]"]
+cuda = ["jax[cuda12]"]
+
 [tool.pytest.ini_options]
 addopts="-n auto"
+
+[tool.ruff]
+line-length = 80
+indent-width = 2
diff --git a/experimental/torch_xla2/test-requirements.txt b/experimental/torch_xla2/test-requirements.txt
new file mode 100644
index 00000000000..1deead455a1
--- /dev/null
+++ b/experimental/torch_xla2/test-requirements.txt
@@ -0,0 +1,5 @@
+-r dev-requirements.txt
+pytest
+pytest-xdist
+sentencepiece
+expecttest
diff --git a/experimental/torch_xla2/test_requirements.txt b/experimental/torch_xla2/test_requirements.txt
deleted file mode 100644
index c8596327236..00000000000
--- a/experimental/torch_xla2/test_requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-pytest
-immutabledict
-sentencepiece
-pytest-xdist
-expecttest