From 68728ea6d0be3a6f5b52108edf315ba2744e3434 Mon Sep 17 00:00:00 2001 From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com> Date: Mon, 14 Oct 2024 10:21:11 +0200 Subject: [PATCH] Add CLI to install dependencies (#104) * feat(cli): added commands to install dependencies Some optimum-tpu dependencies are not indexed on pypi, so a cli is povided to install those. * chore(install): update Jetstream Pytorch install procedure Use the python CLI script instead of the shell script. * chore(README): add tests workflow badge * fix(ci): downgrade ubuntu version in doc CI Ubuntu latest image comes with Python 12, that is not compatible with torch xla yet. --- .github/workflows/doc-build.yml | 2 +- .github/workflows/doc-pr-build.yml | 2 +- Makefile | 6 +- README.md | 6 +- install-jetstream-pt.sh | 18 ----- optimum/tpu/cli.py | 113 +++++++++++++++++++++++++++++ pyproject.toml | 6 +- 7 files changed, 124 insertions(+), 29 deletions(-) delete mode 100644 install-jetstream-pt.sh create mode 100644 optimum/tpu/cli.py diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml index 152fd5c3..21a74991 100644 --- a/.github/workflows/doc-build.yml +++ b/.github/workflows/doc-build.yml @@ -16,7 +16,7 @@ on: jobs: build_documentation: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: COMMIT_SHA: ${{ github.event.pull_request.head.sha }} PR_NUMBER: ${{ github.event.number }} diff --git a/.github/workflows/doc-pr-build.yml b/.github/workflows/doc-pr-build.yml index f320f295..2e16867b 100644 --- a/.github/workflows/doc-pr-build.yml +++ b/.github/workflows/doc-pr-build.yml @@ -15,7 +15,7 @@ concurrency: jobs: build_documentation: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: COMMIT_SHA: ${{ github.event.pull_request.head.sha }} PR_NUMBER: ${{ github.event.number }} diff --git a/Makefile b/Makefile index 843fd143..131a378f 100644 --- a/Makefile +++ b/Makefile @@ -90,11 +90,7 @@ tgi_server: VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server jetstream_requirements: - bash install-jetstream-pt.sh - python -m pip install .[jetstream-pt] \ - -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ - -f https://storage.googleapis.com/libtpu-releases/index.html + python optimum/tpu/cli.py install-jetstream-pt --force tgi_test_jetstream: test_installs jetstream_requirements tgi_server find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ diff --git a/README.md b/README.md index f6ff2cf1..830fe735 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Optimum-TPU [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://huggingface.co/docs/optimum/index) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) - +[![Optimum TPU / Test TGI on TPU](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml/badge.svg)](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml) [Tensor Processing Units (TPU)](https://cloud.google.com/tpu) are AI accelerator made by Google to optimize @@ -49,10 +49,10 @@ Please see the [TGI specific documentation](text-generation-inference) on how to ### JetStream Pytorch Engine -`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated command: +`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated CLI command: ```shell -make jetstream_requirements +optimum-tpu install-jetstream-pytorch ``` To enable the support, export the environment variable `JETSTREAM_PT=1`. diff --git a/install-jetstream-pt.sh b/install-jetstream-pt.sh deleted file mode 100644 index d0db2abf..00000000 --- a/install-jetstream-pt.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -THIS_DIR=$(dirname "$0") - -deps_dir=deps -rm -rf $deps_dir -mkdir -p $deps_dir - - -# install torch cpu to avoid GPU requirements -pip install -r $THIS_DIR/requirements.txt -cd $deps_dir -git clone https://github.com/google/jetstream-pytorch.git -cd jetstream-pytorch -git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921 -git submodule update --init --recursive -# We cannot install in a temporary directory because the directory should not be deleted after the script finishes, -# because it will install its dependendencies from that directory. -pip install -e . diff --git a/optimum/tpu/cli.py b/optimum/tpu/cli.py new file mode 100644 index 00000000..a380034b --- /dev/null +++ b/optimum/tpu/cli.py @@ -0,0 +1,113 @@ +import importlib.util +import os +import shutil +import subprocess +import sys +from pathlib import Path + +import click +import typer + + +TORCH_VER = "2.4.0" +JETSTREAM_PT_VER = "ec4ac8f6b180ade059a2284b8b7d843b3cab0921" +DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps") + +app = typer.Typer() + + +def _check_module(module_name: str): + spec = importlib.util.find_spec(module_name) + return spec is not None + + +def _run(cmd: str): + split_cmd = cmd.split() + subprocess.check_call(split_cmd) + + +def _install_torch_cpu(): + # install torch CPU version to avoid installing CUDA dependencies + _run(sys.executable + f" -m pip install torch=={TORCH_VER} --index-url https://download.pytorch.org/whl/cpu") + + +@app.command() +def install_pytorch_xla( + force: bool = False, +): + """ + Installs PyTorch XLA with TPU support. + + Args: + force (bool): When set, force reinstalling even if Pytorch XLA is already installed. + """ + if not force and _check_module("torch") and _check_module("torch_xla"): + typer.confirm( + "PyTorch XLA is already installed. Do you want to reinstall it?", + default=False, + abort=True, + ) + _install_torch_cpu() + _run( + sys.executable + + f" -m pip install torch-xla[tpu]=={TORCH_VER} -f https://storage.googleapis.com/libtpu-releases/index.html" + ) + click.echo() + click.echo(click.style("PyTorch XLA has been installed.", bold=True)) + + +@app.command() +def install_jetstream_pytorch( + deps_path: str = DEFAULT_DEPS_PATH, + yes: bool = False, +): + """ + Installs Jetstream Pytorch with TPU support. + + Args: + deps_path (str): Path where Jetstream Pytorch dependencies will be installed. + yes (bool): When set, proceed installing without asking questions. + """ + if not _check_module("torch"): + _install_torch_cpu() + if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"): + typer.confirm( + "Jetstream Pytorch is already installed. Do you want to reinstall it?", + default=False, + abort=True, + ) + + jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch") + if not yes and os.path.exists(jetstream_repo_dir): + typer.confirm( + f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?", + default=False, + abort=True, + ) + shutil.rmtree(jetstream_repo_dir, ignore_errors=True) + # Create the directory if it does not exist + os.makedirs(deps_path, exist_ok=True) + # Clone and install Jetstream Pytorch + os.chdir(deps_path) + _run("git clone https://github.com/google/jetstream-pytorch.git") + os.chdir("jetstream-pytorch") + _run(f"git checkout {JETSTREAM_PT_VER}") + _run("git submodule update --init --recursive") + # We cannot install in a temporary directory because the directory should not be deleted after the script finishes, + # because it will install its dependendencies from that directory. + _run(sys.executable + " -m pip install -e .") + + _run( + sys.executable + + f" -m pip install torch_xla[pallas]=={TORCH_VER} " + + " -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html" + + " -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html" + + " -f https://storage.googleapis.com/libtpu-releases/index.html" + ) + # Install PyTorch XLA pallas + click.echo() + click.echo(click.style("Jetstream Pytorch has been installed.", bold=True)) + + +if __name__ == "__main__": + sys.exit(app()) diff --git a/pyproject.toml b/pyproject.toml index 13b728ae..b9d4c9d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "transformers == 4.41.1", "torch == 2.4.0", "torch-xla[tpu] == 2.4.0", + 'typer == 0.6.1', "loguru == 0.6.0", "sentencepiece == 0.2.0", ] @@ -103,4 +104,7 @@ filterwarnings = [ "ignore:`do_sample` is set", "ignore:Device capability of jax", "ignore:`tensorflow` can conflict", -] \ No newline at end of file +] + +[project.scripts] +optimum-tpu = "optimum.tpu.cli:app"