Skip to content

Commit

Permalink
use build_util for CPU too
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Jan 18, 2024
1 parent d8dda3f commit e22cfda
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 38 deletions.
22 changes: 22 additions & 0 deletions build_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
from typing import Iterable
import subprocess
import sys
import shutil

def check_env_flag(name: str, default: str = '') -> bool:
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
Expand Down Expand Up @@ -49,3 +52,22 @@ def bazel_options_from_env() -> Iterable[str]:
bazel_flags.append('--config=acl')

return bazel_flags

def bazel_build(bazel_target: str, destination_dir: str, options: Iterable[str] = []):
bazel_argv = ['bazel', 'build', bazel_target, f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"]

# Remove duplicated flags because they confuse bazel
flags = set(bazel_options_from_env() + options)
bazel_argv.extend(flags)

print(' '.join(bazel_argv), flush=True)
subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)

target_path = bazel_target.replace('@xla//', 'external/xla/').replace('//', '').replace(':', '/')
output_path = os.path.join('bazel-bin', target_path)
output_filename = os.path.basename(output_path)

if not os.path.exists(destination_dir):
os.makedirs(destination_dir)

shutil.copyfile(output_path, os.path.join(destination_dir, output_filename))
9 changes: 2 additions & 7 deletions plugins/cpu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,10 @@ repository (see `bazel build` command below).
## Building

```bash
# Build PJRT plugin
bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1
# Copy to package dir
cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/

# Build wheel
pip wheel plugins/cpu
pip wheel plugins/cpu --no-build-isolation -v
# Or install directly
pip install plugins/cpu
pip install plugins/cpu --no-build-isolation -v
```

## Usage
Expand Down
1 change: 1 addition & 0 deletions plugins/cpu/build_util.py
2 changes: 1 addition & 1 deletion plugins/cpu/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "CPU PJRT Plugin for testing only"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cpu_plugin = ["*.so"]
torch_xla_cpu_plugin = ["lib/*.so"]

[project.entry-points."torch_xla.plugins"]
cpu = "torch_xla_cpu_plugin:CpuPlugin"
6 changes: 6 additions & 0 deletions plugins/cpu/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import build_util
import setuptools

build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so', 'torch_xla_cpu_plugin/lib')

setuptools.setup()
2 changes: 1 addition & 1 deletion plugins/cpu/torch_xla_cpu_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class CpuPlugin(plugins.DevicePlugin):
def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so')
return os.path.join(os.path.dirname(__file__), 'lib/pjrt_c_api_cpu_plugin.so')

def physical_chip_count(self) -> int:
return 1
4 changes: 2 additions & 2 deletions plugins/cuda/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ repository (see `bazel build` command below).

```bash
# Build wheel
pip wheel plugins/cuda
pip wheel plugins/cuda --no-build-isolation -v
# Or install directly
pip install plugins/cuda
pip install plugins/cuda --no-build-isolation -v
```

## Usage
Expand Down
29 changes: 2 additions & 27 deletions plugins/cuda/setup.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,6 @@
import subprocess
import sys
from typing import Iterable
import setuptools
import shutil
import os

import build_util
import setuptools

def _bazel_build(bazel_target: str, destination_path: str, options: Iterable[str] = []):
bazel_argv = ['bazel', 'build', bazel_target, f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"]

# Remove duplicated flags because they confuse bazel
flags = set(build_util.bazel_options_from_env() + options)
bazel_argv.extend(flags)

print(' '.join(bazel_argv), flush=True)
subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)

target_path = bazel_target.replace('@xla//', 'external/xla/').replace(':', '/')
output_path = os.path.join('bazel-bin', target_path)
output_filename = os.path.basename(output_path)

if not os.path.exists(destination_path):
os.makedirs(destination_path)

shutil.copyfile(output_path, os.path.join(destination_path, output_filename))

_bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', 'torch_xla_cuda_plugin/lib', ['--config=cuda'])
build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', 'torch_xla_cuda_plugin/lib', ['--config=cuda'])

setuptools.setup()

0 comments on commit e22cfda

Please sign in to comment.