Skip to content

Commit

Permalink
Add CLI to install dependencies (#104)
Browse files Browse the repository at this point in the history
* feat(cli): added commands to install dependencies

Some optimum-tpu dependencies are not indexed on pypi, so a cli is
provided to install those.

* chore(install): update Jetstream Pytorch install procedure

Use the python CLI script instead of the shell script.

* chore(README): add tests workflow badge

* fix(ci): downgrade ubuntu version in doc CI

Ubuntu latest image comes with Python 3.12, which is not compatible with
torch xla yet.
  • Loading branch information
tengomucho authored Oct 14, 2024
1 parent 16596de commit 68728ea
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ on:

jobs:
build_documentation:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
env:
COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
PR_NUMBER: ${{ github.event.number }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/doc-pr-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ concurrency:

jobs:
build_documentation:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
env:
COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
PR_NUMBER: ${{ github.event.number }}
Expand Down
6 changes: 1 addition & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ tgi_server:
VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server

jetstream_requirements:
bash install-jetstream-pt.sh
python -m pip install .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
python optimum/tpu/cli.py install-jetstream-pytorch --yes

tgi_test_jetstream: test_installs jetstream_requirements tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Optimum-TPU

[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://huggingface.co/docs/optimum/index)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

[![Optimum TPU / Test TGI on TPU](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml/badge.svg)](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml)
</div>

[Tensor Processing Units (TPU)](https://cloud.google.com/tpu) are AI accelerators made by Google to optimize
Expand Down Expand Up @@ -49,10 +49,10 @@ Please see the [TGI specific documentation](text-generation-inference) on how to

### JetStream Pytorch Engine

`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated command:
`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated CLI command:

```shell
make jetstream_requirements
optimum-tpu install-jetstream-pytorch
```

To enable the support, export the environment variable `JETSTREAM_PT=1`.
Expand Down
18 changes: 0 additions & 18 deletions install-jetstream-pt.sh

This file was deleted.

113 changes: 113 additions & 0 deletions optimum/tpu/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import importlib.util
import os
import shutil
import subprocess
import sys
from pathlib import Path

import click
import typer


# Version pinned for the torch and torch-xla packages installed by the commands below.
TORCH_VER = "2.4.0"
# Commit of the jetstream-pytorch repository that is checked out before installing it.
JETSTREAM_PT_VER = "ec4ac8f6b180ade059a2284b8b7d843b3cab0921"
# Default directory (under the user's home) where the jetstream-pytorch repository is cloned.
DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")

# Typer application; the install commands below are registered on it with @app.command().
app = typer.Typer()


def _check_module(module_name: str):
    """Tell whether the import system can locate ``module_name``.

    Args:
        module_name (str): Name of the module to look up.

    Returns:
        bool: True when a module spec was found, False otherwise.
    """
    return importlib.util.find_spec(module_name) is not None


def _run(cmd: "str | list[str] | tuple[str, ...]"):
    """Run a command and fail loudly if it exits with a non-zero status.

    Args:
        cmd: Command to execute. A string is split on whitespace; no shell
            quoting rules are applied, so an argument that itself contains
            whitespace cannot be expressed that way. A list or tuple is used
            verbatim as the argument vector, which is the form to prefer when
            any argument (e.g. an interpreter or directory path) may contain
            spaces.

    Raises:
        subprocess.CalledProcessError: If the command returns a non-zero exit code.
    """
    argv = cmd.split() if isinstance(cmd, str) else list(cmd)
    subprocess.check_call(argv)


def _install_torch_cpu():
    """Install the CPU build of torch pinned to ``TORCH_VER`` with pip.

    The wheel is taken from the PyTorch CPU index to avoid installing CUDA
    dependencies.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Hand pip an explicit argument list instead of a whitespace-split string,
    # so an interpreter path containing spaces is still passed as one argument.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            f"torch=={TORCH_VER}",
            "--index-url",
            "https://download.pytorch.org/whl/cpu",
        ]
    )


@app.command()
def install_pytorch_xla(
    force: bool = False,
):
    """
    Installs PyTorch XLA with TPU support.

    Args:
        force (bool): When set, force reinstalling even if Pytorch XLA is already installed.
    """
    if not force and _check_module("torch") and _check_module("torch_xla"):
        # abort=True: declining the prompt aborts the command instead of returning False.
        typer.confirm(
            "PyTorch XLA is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )
    _install_torch_cpu()
    # Use an explicit argument list rather than a whitespace-split string, so an
    # interpreter path containing spaces is still passed as a single argument.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            f"torch-xla[tpu]=={TORCH_VER}",
            "-f",
            "https://storage.googleapis.com/libtpu-releases/index.html",
        ]
    )
    click.echo()
    click.echo(click.style("PyTorch XLA has been installed.", bold=True))


@app.command()
def install_jetstream_pytorch(
    deps_path: str = DEFAULT_DEPS_PATH,
    yes: bool = False,
):
    """
    Installs Jetstream Pytorch with TPU support.

    Args:
        deps_path (str): Path where Jetstream Pytorch dependencies will be installed.
        yes (bool): When set, proceed installing without asking questions.
    """
    if not _check_module("torch"):
        _install_torch_cpu()
    if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"):
        # abort=True: declining the prompt aborts the command instead of returning False.
        typer.confirm(
            "Jetstream Pytorch is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )

    jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch")
    if not yes and os.path.exists(jetstream_repo_dir):
        typer.confirm(
            f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?",
            default=False,
            abort=True,
        )
    # Remove any previous checkout so that the clone below starts from scratch.
    shutil.rmtree(jetstream_repo_dir, ignore_errors=True)
    # Create the directory if it does not exist
    os.makedirs(deps_path, exist_ok=True)
    # Clone and install Jetstream Pytorch
    os.chdir(deps_path)
    _run("git clone https://github.com/google/jetstream-pytorch.git")
    os.chdir("jetstream-pytorch")
    _run(f"git checkout {JETSTREAM_PT_VER}")
    _run("git submodule update --init --recursive")

    # pip is invoked with an explicit argument list rather than a whitespace-split
    # string, so an interpreter path containing spaces is passed as one argument.
    pip_install = [sys.executable, "-m", "pip", "install"]

    # We cannot install from a temporary directory: the editable install keeps using
    # this checkout (it installs its dependencies from it), so the directory must not
    # be deleted after the script finishes.
    subprocess.check_call(pip_install + ["-e", "."])

    # Install PyTorch XLA pallas
    subprocess.check_call(
        pip_install
        + [
            f"torch_xla[pallas]=={TORCH_VER}",
            "-f",
            "https://storage.googleapis.com/jax-releases/jax_nightly_releases.html",
            "-f",
            "https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html",
            "-f",
            "https://storage.googleapis.com/libtpu-releases/index.html",
        ]
    )
    click.echo()
    click.echo(click.style("Jetstream Pytorch has been installed.", bold=True))


# Allow running this file directly as a script (e.g. `python optimum/tpu/cli.py <command>`);
# whatever app() returns is handed to sys.exit().
if __name__ == "__main__":
    sys.exit(app())
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"transformers == 4.41.1",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
'typer == 0.6.1',
"loguru == 0.6.0",
"sentencepiece == 0.2.0",
]
Expand Down Expand Up @@ -103,4 +104,7 @@ filterwarnings = [
"ignore:`do_sample` is set",
"ignore:Device capability of jax",
"ignore:`tensorflow` can conflict",
]
]

[project.scripts]
optimum-tpu = "optimum.tpu.cli:app"

0 comments on commit 68728ea

Please sign in to comment.