This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Building alpa/jaxlib for TPU #924

Closed
Lime-Cakes opened this issue May 7, 2023 · 2 comments

Comments

@Lime-Cakes

I'm currently trying to run some benchmarks of alpa on TPUv2 and TPUv3, but I ran into problems installing alpa on TPU.

First I tried following the installation guide. Since it doesn't mention TPU, I started by installing JAX with TPU support and Google's libtpu, using the following command:
pip install jax[tpu]==0.3.22 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Then I installed alpa following the guide, using the prebuilt jaxlib wheel from https://alpa-projects.github.io/wheels.html (after first uninstalling the original jaxlib):
pip3 install jaxlib==0.3.22 --no-index -f https://alpa-projects.github.io/wheels.html
However, after installing this jaxlib, TPU can no longer be detected as a backend. The error I got was:
RuntimeError: Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client' (set JAX_PLATFORMS='' to automatically choose an available backend)
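To check whether a given jaxlib install is the one missing TPU support, one could look directly for the attribute the error names. The sketch below is a hypothetical helper (`diagnose_tpu_support` is not part of jax or alpa). It assumes the jax/jaxlib 0.3.x layout, in which jax looks up `jaxlib.xla_extension.get_tpu_client`; newer jaxlib releases reorganized these modules, so the attribute check may not apply there.

```python
import importlib.metadata
import importlib.util


def diagnose_tpu_support():
    """Collect a few facts about whether the installed jaxlib can start a TPU backend.

    Hypothetical helper, not part of jax or alpa. Assumes the jax/jaxlib 0.3.x
    layout, where jax looks up ``jaxlib.xla_extension.get_tpu_client`` (the
    attribute the RuntimeError reports as missing).
    """
    report = {
        "jax_version": None,
        "jaxlib_version": None,
        "has_get_tpu_client": None,
        "default_backend": None,
        "error": None,
    }
    if importlib.util.find_spec("jax") is None or importlib.util.find_spec("jaxlib") is None:
        report["error"] = "jax/jaxlib not installed"
        return report

    # Installed distribution versions; jax and jaxlib are expected to match (0.3.22 here).
    for dist in ("jax", "jaxlib"):
        try:
            report[f"{dist}_version"] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            pass

    try:
        from jaxlib import xla_extension
        # False here would match the "no attribute 'get_tpu_client'" error above.
        report["has_get_tpu_client"] = hasattr(xla_extension, "get_tpu_client")
    except ImportError:
        # Module layout differs in this jaxlib version; leave the field as None.
        pass

    try:
        import jax
        # Forces backend initialization; raises RuntimeError if a requested
        # platform (e.g. JAX_PLATFORMS=tpu) cannot be initialized.
        report["default_backend"] = jax.default_backend()
    except (ImportError, RuntimeError) as exc:
        report["error"] = str(exc)
    return report


if __name__ == "__main__":
    for key, value in diagnose_tpu_support().items():
        print(f"{key}: {value}")
```

Run it in the same environment as the benchmark; on an install that produces the error above, `has_get_tpu_client` would be expected to come back `False`.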

It seems the prebuilt wheels do not support TPU, so I tried building jaxlib myself, following the instructions in "Method 2: Install from Source" (https://alpa.ai/install.html#method-2-install-from-source).

The only change I made was replacing --enable_cuda with --enable_tpu. I tried both the latest version of alpa and the v0.2.3 tag; both failed with the following error:
ERROR: /home/user/.cache/bazel/_bazel_user/8f8b2d573666bb995d6273003c450110/external/org_tensorflow/tensorflow/compiler/xla/service/BUILD:7141:11: Compiling tensorflow/compiler/xla/service/pass_context.cc failed: (Exit 1): gcc failed: error executing command (cd /home/user/.cache/bazel/_bazel_user/8f8b2d573666bb995d6273003c450110/execroot/__main__ && \
  exec env - \
    LD_LIBRARY_PATH=:/usr/local/lib \
    PATH=/home/user/mambaforge/envs/a-p39/bin:/home/user/.local/bin:/home/user/mambaforge/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin \
    PWD=/proc/self/cwd \
  /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++0x' -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/service/_objs/pass_context/pass_context.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/service/_objs/pass_context/pass_context.pic.o' -fPIC -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY -iquote external/org_tensorflow -iquote bazel-out/k8-opt/bin/external/org_tensorflow -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/snappy -iquote bazel-out/k8-opt/bin/external/snappy -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/pybind11 -iquote bazel-out/k8-opt/bin/external/pybind11 -iquote external/local_config_python -iquote bazel-out/k8-opt/bin/external/local_config_python -Ibazel-out/k8-opt/bin/external/pybind11/_virtual_includes/pybind11 -isystem external/org_tensorflow/third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/external/org_tensorflow/third_party/eigen3/mkl_include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/pybind11/include -isystem bazel-out/k8-opt/bin/external/pybind11/include -isystem external/local_config_python/python_include -isystem bazel-out/k8-opt/bin/external/local_config_python/python_include '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx '-std=c++17' -fno-canonical-system-headers -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -c external/org_tensorflow/tensorflow/compiler/xla/service/pass_context.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/service/_objs/pass_context/pass_context.pic.o)

@Lime-Cakes
Author

The full build log is available here: https://gist.github.com/Lime-Cakes/a3d738c4cd0ebff97a4b2348899ca01e
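The excerpt in the issue shows only the gcc command Bazel was running; the compiler's own diagnostic (the line explaining *why* `pass_context.cc` failed) typically appears further down in the full log. As an illustration, a small hypothetical helper like the one below could pull the first gcc/clang-style `file:line:col: error:` lines out of a saved log. The function name and the `build.log` filename in the usage line are assumptions, not part of alpa or Bazel.

```python
import re

# gcc/clang diagnostics look like "path/to/file.cc:LINE:COL: error: message"
# (or "fatal error:"). Bazel's own "ERROR: ..." lines do not match this shape.
_DIAG_RE = re.compile(
    r"^(?P<file>[^\s:]+):(?P<line>\d+):(?P<col>\d+): (?:fatal )?error: (?P<msg>.*)$"
)


def compiler_errors(log_text, limit=5):
    """Return up to `limit` (file, line, column, message) tuples found in a build log."""
    found = []
    for raw in log_text.splitlines():
        match = _DIAG_RE.match(raw.strip())
        if match:
            found.append(
                (match["file"], int(match["line"]), int(match["col"]), match["msg"])
            )
            if len(found) >= limit:
                break
    return found


if __name__ == "__main__":
    # Hypothetical usage: save the gist contents to build.log first.
    with open("build.log", encoding="utf-8") as f:
        for entry in compiler_errors(f.read()):
            print(entry)
```

Warnings and "In file included from ..." context lines are skipped, so only the first hard errors are reported.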

@ZYHowell
Collaborator

ZYHowell commented May 15, 2023

We've tested auto-sharding on TPU by building our own jaxlib (see the "build from source" section in our documentation and replace --enable_cuda with --enable_tpu). However, since pjit.auto uses the same code as our auto-sharding, we encourage you to use that instead. As for full 3D parallelism, we have no plans to support it. Please check #433 for more details.
