From 61bd1d23392d46652986fd2b8a1dff3ac2728158 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Tue, 23 Jan 2024 10:00:37 -0800
Subject: [PATCH] Improve build process for PJRT plugins (#6314)

---
 .gitignore                                   |  2 +-
 build_util.py                                | 81 +++++++++++++++++
 plugins/cpu/README.md                        |  9 +-
 plugins/cpu/build_util.py                    |  1 +
 plugins/cpu/pyproject.toml                   |  2 +-
 plugins/cpu/setup.py                         |  7 ++
 plugins/cpu/torch_xla_cpu_plugin/__init__.py |  5 +-
 plugins/cuda/README.md                       | 11 +--
 plugins/cuda/build_util.py                   |  1 +
 plugins/cuda/pyproject.toml                  |  8 +-
 plugins/cuda/setup.py                        |  7 ++
 .../cuda/torch_xla_cuda_plugin/__init__.py   |  8 +-
 setup.py                                     | 89 +++----------------
 13 files changed, 131 insertions(+), 100 deletions(-)
 create mode 100644 build_util.py
 create mode 120000 plugins/cpu/build_util.py
 create mode 100644 plugins/cpu/setup.py
 create mode 120000 plugins/cuda/build_util.py
 create mode 100644 plugins/cuda/setup.py

diff --git a/.gitignore b/.gitignore
index e2f9ebf6d15..d58c8bf4a15 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,7 +26,7 @@ torch_xla/csrc/version.cpp
 
 
 # Build system temporary files
-/bazel-*
+bazel-*
 
 # Clangd cache directory
 .cache/*
diff --git a/build_util.py b/build_util.py
new file mode 100644
index 00000000000..78e4bd5e453
--- /dev/null
+++ b/build_util.py
@@ -0,0 +1,81 @@
+import os
+from typing import Iterable
+import subprocess
+import sys
+import shutil
+
+
+def check_env_flag(name: str, default: str = '') -> bool:
+  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
+
+
+def bazel_options_from_env() -> Iterable[str]:
+  bazel_flags = []
+
+  if check_env_flag('DEBUG'):
+    bazel_flags.append('--config=dbg')
+
+  if check_env_flag('TPUVM_MODE'):
+    bazel_flags.append('--config=tpu')
+
+  gcloud_key_file = os.getenv('GCLOUD_SERVICE_KEY_FILE', default='')
+  # Remote cache authentication.
+  if gcloud_key_file:
+    # Temporary workaround to allow PRs from forked repo to run CI. See details at (#5259).
+    # TODO: Remove the check once self-hosted GHA workers are available to CPU/GPU CI.
+    gcloud_key_file_size = os.path.getsize(gcloud_key_file)
+    if gcloud_key_file_size > 1:
+      bazel_flags.append('--google_credentials=%s' % gcloud_key_file)
+      bazel_flags.append('--config=remote_cache')
+  else:
+    if check_env_flag('BAZEL_REMOTE_CACHE'):
+      bazel_flags.append('--config=remote_cache')
+
+  cache_silo_name = os.getenv('SILO_NAME', default='dev')
+  if cache_silo_name:
+    bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' %
+                       cache_silo_name)
+
+  if check_env_flag('BUILD_CPP_TESTS', default='0'):
+    bazel_flags.append('//test/cpp:all')
+    bazel_flags.append('//torch_xla/csrc/runtime:all')
+
+  bazel_jobs = os.getenv('BAZEL_JOBS', default='')
+  if bazel_jobs:
+    bazel_flags.append('--jobs=%s' % bazel_jobs)
+
+  # Build configuration.
+  if check_env_flag('BAZEL_VERBOSE'):
+    bazel_flags.append('-s')
+  if check_env_flag('XLA_CUDA'):
+    bazel_flags.append('--config=cuda')
+  if check_env_flag('XLA_CPU_USE_ACL'):
+    bazel_flags.append('--config=acl')
+
+  return bazel_flags
+
+
+def bazel_build(bazel_target: str,
+                destination_dir: str,
+                options: Iterable[str] = []):
+  bazel_argv = [
+      'bazel', 'build', bazel_target,
+      f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"
+  ]
+
+  # Remove duplicated flags because they confuse bazel
+  flags = set(bazel_options_from_env() + options)
+  bazel_argv.extend(flags)
+
+  print(' '.join(bazel_argv), flush=True)
+  subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)
+
+  target_path = bazel_target.replace('@xla//', 'external/xla/').replace(
+      '//', '').replace(':', '/')
+  output_path = os.path.join('bazel-bin', target_path)
+  output_filename = os.path.basename(output_path)
+
+  if not os.path.exists(destination_dir):
+    os.makedirs(destination_dir)
+
+  shutil.copyfile(output_path, os.path.join(destination_dir, output_filename))
diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md
index b3fedc24d8d..21398769b04 100644
--- a/plugins/cpu/README.md
+++ b/plugins/cpu/README.md
@@ -10,15 +10,10 @@ repository (see `bazel build` command below).
 ## Building
 
 ```bash
-# Build PJRT plugin
-bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1
-# Copy to package dir
-cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/
-
 # Build wheel
-pip wheel plugins/cpu
+pip wheel plugins/cpu --no-build-isolation -v
 # Or install directly
-pip install plugins/cpu
+pip install plugins/cpu --no-build-isolation -v
 ```
 
 ## Usage
diff --git a/plugins/cpu/build_util.py b/plugins/cpu/build_util.py
new file mode 120000
index 00000000000..219f486130c
--- /dev/null
+++ b/plugins/cpu/build_util.py
@@ -0,0 +1 @@
+../../build_util.py
\ No newline at end of file
diff --git a/plugins/cpu/pyproject.toml b/plugins/cpu/pyproject.toml
index 9359a0dc7b5..770d5cb5fb1 100644
--- a/plugins/cpu/pyproject.toml
+++ b/plugins/cpu/pyproject.toml
@@ -12,7 +12,7 @@ description = "CPU PJRT Plugin for testing only"
 requires-python = ">=3.8"
 
 [tool.setuptools.package-data]
-torch_xla_cpu_plugin = ["*.so"]
+torch_xla_cpu_plugin = ["lib/*.so"]
 
 [project.entry-points."torch_xla.plugins"]
 cpu = "torch_xla_cpu_plugin:CpuPlugin"
diff --git a/plugins/cpu/setup.py b/plugins/cpu/setup.py
new file mode 100644
index 00000000000..9182c86a635
--- /dev/null
+++ b/plugins/cpu/setup.py
@@ -0,0 +1,7 @@
+import build_util
+import setuptools
+
+build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so',
+                       'torch_xla_cpu_plugin/lib')
+
+setuptools.setup()
diff --git a/plugins/cpu/torch_xla_cpu_plugin/__init__.py b/plugins/cpu/torch_xla_cpu_plugin/__init__.py
index da7a3234267..142a6b682e2 100644
--- a/plugins/cpu/torch_xla_cpu_plugin/__init__.py
+++ b/plugins/cpu/torch_xla_cpu_plugin/__init__.py
@@ -2,9 +2,12 @@
 from torch_xla.experimental import plugins
 from torch_xla._internal import tpu
 
+
 class CpuPlugin(plugins.DevicePlugin):
+
   def library_path(self) -> str:
-    return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so')
+    return os.path.join(
+        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
 
   def physical_chip_count(self) -> int:
     return 1
diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md
index f5a2647f6e6..89d14c737ea 100644
--- a/plugins/cuda/README.md
+++ b/plugins/cuda/README.md
@@ -7,15 +7,10 @@ repository (see `bazel build` command below).
 ## Building
 
 ```bash
-# Build PJRT plugin
-bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda
-# Copy to package dir
-cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin
-
 # Build wheel
-pip wheel plugins/cuda
+pip wheel plugins/cuda --no-build-isolation -v
 # Or install directly
-pip install plugins/cuda
+pip install plugins/cuda --no-build-isolation -v
 ```
 
 ## Usage
@@ -34,7 +29,7 @@ import torch_xla.runtime as xr
 
 # Use dynamic plugin instead of built-in CUDA support
 plugins.use_dynamic_plugins()
-plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin())
+plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
 xr.set_device_type('CUDA')
 
 print(xm.xla_device())
diff --git a/plugins/cuda/build_util.py b/plugins/cuda/build_util.py
new file mode 120000
index 00000000000..219f486130c
--- /dev/null
+++ b/plugins/cuda/build_util.py
@@ -0,0 +1 @@
+../../build_util.py
\ No newline at end of file
diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml
index 306b30495ea..fd8bbf59f6c 100644
--- a/plugins/cuda/pyproject.toml
+++ b/plugins/cuda/pyproject.toml
@@ -6,13 +6,13 @@ build-backend = "setuptools.build_meta"
 name = "torch_xla_cuda_plugin"
 version = "0.0.1"
 authors = [
-    {name = "Will Cromar", email = "wcromar@google.com"},
+    {name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"},
 ]
-description = "CUDA Plugin"
+description = "PyTorch/XLA CUDA Plugin"
 requires-python = ">=3.8"
 
 [tool.setuptools.package-data]
-torch_xla_cuda_plugin = ["*.so"]
+torch_xla_cuda_plugin = ["lib/*.so"]
 
 [project.entry-points."torch_xla.plugins"]
-gpu = "torch_xla_cuda_plugin:GpuPlugin"
+cuda = "torch_xla_cuda_plugin:CudaPlugin"
diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py
new file mode 100644
index 00000000000..8f6aaf00b74
--- /dev/null
+++ b/plugins/cuda/setup.py
@@ -0,0 +1,7 @@
+import build_util
+import setuptools
+
+build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
+                       'torch_xla_cuda_plugin/lib', ['--config=cuda'])
+
+setuptools.setup()
diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
index f10a412bfaa..d8d159de2ef 100644
--- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py
+++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -1,10 +1,12 @@
 import os
 from torch_xla.experimental import plugins
-from torch_xla._internal import tpu
 
-class GpuPlugin(plugins.DevicePlugin):
+
+class CudaPlugin(plugins.DevicePlugin):
+
   def library_path(self) -> str:
-    return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')
+    return os.path.join(
+        os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
 
   def physical_chip_count(self) -> int:
     # TODO: default to actual device count
diff --git a/setup.py b/setup.py
index 094926bf427..b997490f36c 100644
--- a/setup.py
+++ b/setup.py
@@ -46,30 +46,23 @@
 #   CXX_ABI=""
 #     value for cxx_abi flag; if empty, it is inferred from `torch._C`.
 #
-from __future__ import print_function
-
 from setuptools import setup, find_packages, distutils, Extension, command
-from setuptools.command import develop
-from torch.utils.cpp_extension import BuildExtension
+from setuptools.command import develop, build_ext
 import posixpath
 import contextlib
 import distutils.ccompiler
 import distutils.command.clean
-import glob
-import inspect
-import multiprocessing
-import multiprocessing.pool
 import os
-import platform
-import re
 import requests
 import shutil
 import subprocess
 import sys
 import tempfile
-import torch
 import zipfile
 
+import build_util
+import torch
+
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
 _libtpu_version = '0.1.dev20240118'
@@ -82,10 +75,6 @@ def _get_build_mode():
       return sys.argv[i]
 
 
-def _check_env_flag(name, default=''):
-  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
-
-
 def get_git_head_sha(base_dir):
   xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                         cwd=base_dir).decode('ascii').strip()
@@ -101,7 +90,7 @@ def get_git_head_sha(base_dir):
 
 def get_build_version(xla_git_sha):
   version = os.getenv('TORCH_XLA_VERSION', '2.2.0')
-  if _check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
+  if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
     try:
       version += '+git' + xla_git_sha[:7]
     except Exception:
@@ -135,7 +124,7 @@ def maybe_bundle_libtpu(base_dir):
   with contextlib.suppress(FileNotFoundError):
     os.remove(libtpu_path)
 
-  if not _check_env_flag('BUNDLE_LIBTPU', '0'):
+  if not build_util.check_env_flag('BUNDLE_LIBTPU', '0'):
     return
 
   try:
@@ -201,20 +190,6 @@ def run(self):
     # Copy libtpu.so into torch_xla/lib
     maybe_bundle_libtpu(base_dir)
 
 
-DEBUG = _check_env_flag('DEBUG')
-IS_DARWIN = (platform.system() == 'Darwin')
-IS_WINDOWS = sys.platform.startswith('win')
-IS_LINUX = (platform.system() == 'Linux')
-GCLOUD_KEY_FILE = os.getenv('GCLOUD_SERVICE_KEY_FILE', default='')
-CACHE_SILO_NAME = os.getenv('SILO_NAME', default='dev')
-BAZEL_JOBS = os.getenv('BAZEL_JOBS', default='')
-
-extra_compile_args = []
-cxx_abi = os.getenv(
-    'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
-if cxx_abi is not None:
-  extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
-
 class BazelExtension(Extension):
   """A C/C++ extension that is defined as a Bazel BUILD target."""
@@ -230,7 +205,7 @@ def __init__(self, bazel_target):
     Extension.__init__(self, ext_name, sources=[])
 
 
-class BuildBazelExtension(command.build_ext.build_ext):
+class BuildBazelExtension(build_ext.build_ext):
   """A command that runs Bazel to build a C/C++ extension."""
 
   def run(self):
@@ -246,49 +221,13 @@ def bazel_build(self, ext):
         'bazel', 'build', ext.bazel_target,
         f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
     ]
-    for opt in extra_compile_args:
-      bazel_argv.append("--cxxopt={}".format(opt))
-
-    # Debug build.
-    if DEBUG:
-      bazel_argv.append('--config=dbg')
-
-    if _check_env_flag('TPUVM_MODE'):
-      bazel_argv.append('--config=tpu')
-
-    # Remote cache authentication.
-    if GCLOUD_KEY_FILE:
-      # Temporary workaround to allow PRs from forked repo to run CI. See details at (#5259).
-      # TODO: Remove the check once self-hosted GHA workers are available to CPU/GPU CI.
-      gclout_key_file_size = os.path.getsize(GCLOUD_KEY_FILE)
-      if gclout_key_file_size > 1:
-        bazel_argv.append('--google_credentials=%s' % GCLOUD_KEY_FILE)
-        bazel_argv.append('--config=remote_cache')
-    else:
-      if _check_env_flag('BAZEL_REMOTE_CACHE'):
-        bazel_argv.append('--config=remote_cache')
-      if CACHE_SILO_NAME:
-        bazel_argv.append('--remote_default_exec_properties=cache-silo-key=%s' %
-                          CACHE_SILO_NAME)
-
-    if _check_env_flag('BUILD_CPP_TESTS', default='0'):
-      bazel_argv.append('//test/cpp:all')
-      bazel_argv.append('//torch_xla/csrc/runtime:all')
-
-    if BAZEL_JOBS:
-      bazel_argv.append('--jobs=%s' % BAZEL_JOBS)
-
-    # Build configuration.
-    if _check_env_flag('BAZEL_VERBOSE'):
-      bazel_argv.append('-s')
-    if _check_env_flag('XLA_CUDA'):
-      bazel_argv.append('--config=cuda')
-    if _check_env_flag('XLA_CPU_USE_ACL'):
-      bazel_argv.append('--config=acl')
-
-    if IS_WINDOWS:
-      for library_dir in self.library_dirs:
-        bazel_argv.append('--linkopt=/LIBPATH:' + library_dir)
+
+    cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
+                                              '_GLIBCXX_USE_CXX11_ABI', None)
+    if cxx_abi is not None:
+      bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
+
+    bazel_argv.extend(build_util.bazel_options_from_env())
 
     self.spawn(bazel_argv)
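Note on reuse: the helper added in build_util.py is what each plugin's setup.py calls at wheel-build time, which is why the updated READMEs build with `pip wheel ... --no-build-isolation -v` (the build needs `build_util` and Bazel from the source tree rather than an isolated build environment). Below is a minimal sketch of how a future plugin package could follow the same pattern; the package name, Bazel target, and config flag are illustrative only and are not part of this patch.

```python
# setup.py for a hypothetical torch_xla_foo_plugin package.
# Assumes build_util.py is symlinked into the package directory, as
# plugins/cpu/build_util.py and plugins/cuda/build_util.py are in this patch.
import build_util
import setuptools

# Build the PJRT C API shared library with Bazel and copy it into the
# package's lib/ directory so [tool.setuptools.package-data] can pick it up.
build_util.bazel_build(
    '//plugins/foo:pjrt_c_api_foo_plugin.so',  # hypothetical Bazel target
    'torch_xla_foo_plugin/lib',
    ['--config=foo'])  # extra Bazel options, analogous to ['--config=cuda']

setuptools.setup()
```

Such a package would also expose a `plugins.DevicePlugin` subclass whose `library_path()` points at `lib/pjrt_c_api_foo_plugin.so` and register it under the `torch_xla.plugins` entry-point group, mirroring `CpuPlugin` and `CudaPlugin` above.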