Improve build process for PJRT plugins (#6314)
will-cromar authored Jan 23, 2024
1 parent 269ebcf commit 61bd1d2
Showing 13 changed files with 131 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -26,7 +26,7 @@ torch_xla/csrc/version.cpp


# Build system temporary files
/bazel-*
bazel-*

# Clangd cache directory
.cache/*
81 changes: 81 additions & 0 deletions build_util.py
@@ -0,0 +1,81 @@
import os
from typing import Iterable
import subprocess
import sys
import shutil


def check_env_flag(name: str, default: str = '') -> bool:
  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def bazel_options_from_env() -> Iterable[str]:
  bazel_flags = []

  if check_env_flag('DEBUG'):
    bazel_flags.append('--config=dbg')

  if check_env_flag('TPUVM_MODE'):
    bazel_flags.append('--config=tpu')

  gcloud_key_file = os.getenv('GCLOUD_SERVICE_KEY_FILE', default='')
  # Remote cache authentication.
  if gcloud_key_file:
    # Temporary workaround to allow PRs from forked repo to run CI. See details at (#5259).
    # TODO: Remove the check once self-hosted GHA workers are available to CPU/GPU CI.
    gcloud_key_file_size = os.path.getsize(gcloud_key_file)
    if gcloud_key_file_size > 1:
      bazel_flags.append('--google_credentials=%s' % gcloud_key_file)
      bazel_flags.append('--config=remote_cache')
  else:
    if check_env_flag('BAZEL_REMOTE_CACHE'):
      bazel_flags.append('--config=remote_cache')

  cache_silo_name = os.getenv('SILO_NAME', default='dev')
  if cache_silo_name:
    bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' %
                       cache_silo_name)

  if check_env_flag('BUILD_CPP_TESTS', default='0'):
    bazel_flags.append('//test/cpp:all')
    bazel_flags.append('//torch_xla/csrc/runtime:all')

  bazel_jobs = os.getenv('BAZEL_JOBS', default='')
  if bazel_jobs:
    bazel_flags.append('--jobs=%s' % bazel_jobs)

  # Build configuration.
  if check_env_flag('BAZEL_VERBOSE'):
    bazel_flags.append('-s')
  if check_env_flag('XLA_CUDA'):
    bazel_flags.append('--config=cuda')
  if check_env_flag('XLA_CPU_USE_ACL'):
    bazel_flags.append('--config=acl')

  return bazel_flags


def bazel_build(bazel_target: str,
                destination_dir: str,
                options: Iterable[str] = []):
  bazel_argv = [
      'bazel', 'build', bazel_target,
      f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"
  ]

  # Remove duplicated flags because they confuse bazel
  flags = set(bazel_options_from_env() + options)
  bazel_argv.extend(flags)

  print(' '.join(bazel_argv), flush=True)
  subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)

  target_path = bazel_target.replace('@xla//', 'external/xla/').replace(
      '//', '').replace(':', '/')
  output_path = os.path.join('bazel-bin', target_path)
  output_filename = os.path.basename(output_path)

  if not os.path.exists(destination_dir):
    os.makedirs(destination_dir)

  shutil.copyfile(output_path, os.path.join(destination_dir, output_filename))
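
For orientation, a usage sketch (not part of the diff): both new plugin packages drive this helper from their `setup.py`, optionally layering extra Bazel options on top of the environment-derived ones.

```python
# Sketch of the calling pattern used by the new plugin setup.py files below;
# targets and output directories mirror plugins/cpu and plugins/cuda.
import build_util

# CPU plugin: no extra options; env flags (DEBUG, BAZEL_JOBS, ...) still apply.
build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so',
                       'torch_xla_cpu_plugin/lib')

# CUDA plugin: the caller adds --config=cuda on top of the env-derived flags.
build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
                       'torch_xla_cuda_plugin/lib', ['--config=cuda'])
```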
9 changes: 2 additions & 7 deletions plugins/cpu/README.md
@@ -10,15 +10,10 @@ repository (see `bazel build` command below).
## Building

```bash
# Build PJRT plugin
bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1
# Copy to package dir
cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/

# Build wheel
pip wheel plugins/cpu
pip wheel plugins/cpu --no-build-isolation -v
# Or install directly
pip install plugins/cpu
pip install plugins/cpu --no-build-isolation -v
```

## Usage
1 change: 1 addition & 0 deletions plugins/cpu/build_util.py
2 changes: 1 addition & 1 deletion plugins/cpu/pyproject.toml
@@ -12,7 +12,7 @@ description = "CPU PJRT Plugin for testing only"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cpu_plugin = ["*.so"]
torch_xla_cpu_plugin = ["lib/*.so"]

[project.entry-points."torch_xla.plugins"]
cpu = "torch_xla_cpu_plugin:CpuPlugin"
7 changes: 7 additions & 0 deletions plugins/cpu/setup.py
@@ -0,0 +1,7 @@
import build_util
import setuptools

build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so',
'torch_xla_cpu_plugin/lib')

setuptools.setup()
5 changes: 4 additions & 1 deletion plugins/cpu/torch_xla_cpu_plugin/__init__.py
@@ -2,9 +2,12 @@
from torch_xla.experimental import plugins
from torch_xla._internal import tpu


class CpuPlugin(plugins.DevicePlugin):

def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so')
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')

def physical_chip_count(self) -> int:
return 1
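
As an illustrative sketch (mirroring the CUDA usage example further down; the `'CPU'` device-type string is an assumption here), the packaged plugin can be registered manually like this:

```python
# Hypothetical manual registration of the CPU plugin; the entry point in
# pyproject.toml normally makes it discoverable without explicit code.
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.experimental import plugins
import torch_xla_cpu_plugin

plugins.use_dynamic_plugins()
plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin())
xr.set_device_type('CPU')

print(xm.xla_device())
```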
11 changes: 3 additions & 8 deletions plugins/cuda/README.md
@@ -7,15 +7,10 @@ repository (see `bazel build` command below).
## Building

```bash
# Build PJRT plugin
bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda
# Copy to package dir
cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin

# Build wheel
pip wheel plugins/cuda
pip wheel plugins/cuda --no-build-isolation -v
# Or install directly
pip install plugins/cuda
pip install plugins/cuda --no-build-isolation -v
```

## Usage
@@ -34,7 +29,7 @@ import torch_xla.runtime as xr

# Use dynamic plugin instead of built-in CUDA support
plugins.use_dynamic_plugins()
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin())
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
xr.set_device_type('CUDA')

print(xm.xla_device())
1 change: 1 addition & 0 deletions plugins/cuda/build_util.py
8 changes: 4 additions & 4 deletions plugins/cuda/pyproject.toml
@@ -6,13 +6,13 @@ build-backend = "setuptools.build_meta"
name = "torch_xla_cuda_plugin"
version = "0.0.1"
authors = [
{name = "Will Cromar", email = "wcromar@google.com"},
{name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"},
]
description = "CUDA Plugin"
description = "PyTorch/XLA CUDA Plugin"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cuda_plugin = ["*.so"]
torch_xla_cuda_plugin = ["lib/*.so"]

[project.entry-points."torch_xla.plugins"]
gpu = "torch_xla_cuda_plugin:GpuPlugin"
cuda = "torch_xla_cuda_plugin:CudaPlugin"
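
For context, a sketch (not part of the diff) of how the renamed `cuda` entry point can be looked up under the `torch_xla.plugins` group once the wheel is installed:

```python
# Illustrative lookup of the plugin entry point; assumes the
# torch_xla_cuda_plugin wheel is installed and Python 3.10+
# (keyword filtering on entry_points()).
from importlib.metadata import entry_points

for ep in entry_points(group='torch_xla.plugins'):
  print(ep.name, '->', ep.value)  # e.g. cuda -> torch_xla_cuda_plugin:CudaPlugin

# Loading the entry point yields the CudaPlugin class from __init__.py below.
plugin_cls = next(
    ep for ep in entry_points(group='torch_xla.plugins')
    if ep.name == 'cuda').load()
plugin = plugin_cls()
```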
7 changes: 7 additions & 0 deletions plugins/cuda/setup.py
@@ -0,0 +1,7 @@
import build_util
import setuptools

build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
'torch_xla_cuda_plugin/lib', ['--config=cuda'])

setuptools.setup()
8 changes: 5 additions & 3 deletions plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -1,10 +1,12 @@
import os
from torch_xla.experimental import plugins
from torch_xla._internal import tpu

class GpuPlugin(plugins.DevicePlugin):

class CudaPlugin(plugins.DevicePlugin):

def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

def physical_chip_count(self) -> int:
# TODO: default to actual device count
89 changes: 14 additions & 75 deletions setup.py
@@ -46,30 +46,23 @@
# CXX_ABI=""
# value for cxx_abi flag; if empty, it is inferred from `torch._C`.
#
from __future__ import print_function

from setuptools import setup, find_packages, distutils, Extension, command
from setuptools.command import develop
from torch.utils.cpp_extension import BuildExtension
from setuptools.command import develop, build_ext
import posixpath
import contextlib
import distutils.ccompiler
import distutils.command.clean
import glob
import inspect
import multiprocessing
import multiprocessing.pool
import os
import platform
import re
import requests
import shutil
import subprocess
import sys
import tempfile
import torch
import zipfile

import build_util
import torch

base_dir = os.path.dirname(os.path.abspath(__file__))

_libtpu_version = '0.1.dev20240118'
@@ -82,10 +75,6 @@ def _get_build_mode():
return sys.argv[i]


def _check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def get_git_head_sha(base_dir):
xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=base_dir).decode('ascii').strip()
@@ -101,7 +90,7 @@ def get_git_head_sha(base_dir):

def get_build_version(xla_git_sha):
version = os.getenv('TORCH_XLA_VERSION', '2.2.0')
if _check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
try:
version += '+git' + xla_git_sha[:7]
except Exception:
@@ -135,7 +124,7 @@ def maybe_bundle_libtpu(base_dir):
with contextlib.suppress(FileNotFoundError):
os.remove(libtpu_path)

if not _check_env_flag('BUNDLE_LIBTPU', '0'):
if not build_util.check_env_flag('BUNDLE_LIBTPU', '0'):
return

try:
@@ -201,20 +190,6 @@ def run(self):
# Copy libtpu.so into torch_xla/lib
maybe_bundle_libtpu(base_dir)

DEBUG = _check_env_flag('DEBUG')
IS_DARWIN = (platform.system() == 'Darwin')
IS_WINDOWS = sys.platform.startswith('win')
IS_LINUX = (platform.system() == 'Linux')
GCLOUD_KEY_FILE = os.getenv('GCLOUD_SERVICE_KEY_FILE', default='')
CACHE_SILO_NAME = os.getenv('SILO_NAME', default='dev')
BAZEL_JOBS = os.getenv('BAZEL_JOBS', default='')

extra_compile_args = []
cxx_abi = os.getenv(
'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
if cxx_abi is not None:
extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')


class BazelExtension(Extension):
"""A C/C++ extension that is defined as a Bazel BUILD target."""
@@ -230,7 +205,7 @@ def __init__(self, bazel_target):
Extension.__init__(self, ext_name, sources=[])


class BuildBazelExtension(command.build_ext.build_ext):
class BuildBazelExtension(build_ext.build_ext):
"""A command that runs Bazel to build a C/C++ extension."""

def run(self):
@@ -246,49 +221,13 @@ def bazel_build(self, ext):
'bazel', 'build', ext.bazel_target,
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
]
for opt in extra_compile_args:
bazel_argv.append("--cxxopt={}".format(opt))

# Debug build.
if DEBUG:
bazel_argv.append('--config=dbg')

if _check_env_flag('TPUVM_MODE'):
bazel_argv.append('--config=tpu')

# Remote cache authentication.
if GCLOUD_KEY_FILE:
# Temporary workaround to allow PRs from forked repo to run CI. See details at (#5259).
# TODO: Remove the check once self-hosted GHA workers are available to CPU/GPU CI.
gclout_key_file_size = os.path.getsize(GCLOUD_KEY_FILE)
if gclout_key_file_size > 1:
bazel_argv.append('--google_credentials=%s' % GCLOUD_KEY_FILE)
bazel_argv.append('--config=remote_cache')
else:
if _check_env_flag('BAZEL_REMOTE_CACHE'):
bazel_argv.append('--config=remote_cache')
if CACHE_SILO_NAME:
bazel_argv.append('--remote_default_exec_properties=cache-silo-key=%s' %
CACHE_SILO_NAME)

if _check_env_flag('BUILD_CPP_TESTS', default='0'):
bazel_argv.append('//test/cpp:all')
bazel_argv.append('//torch_xla/csrc/runtime:all')

if BAZEL_JOBS:
bazel_argv.append('--jobs=%s' % BAZEL_JOBS)

# Build configuration.
if _check_env_flag('BAZEL_VERBOSE'):
bazel_argv.append('-s')
if _check_env_flag('XLA_CUDA'):
bazel_argv.append('--config=cuda')
if _check_env_flag('XLA_CPU_USE_ACL'):
bazel_argv.append('--config=acl')

if IS_WINDOWS:
for library_dir in self.library_dirs:
bazel_argv.append('--linkopt=/LIBPATH:' + library_dir)

cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
'_GLIBCXX_USE_CXX11_ABI', None)
if cxx_abi is not None:
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')

bazel_argv.extend(build_util.bazel_options_from_env())

self.spawn(bazel_argv)
