Skip to content

Commit

Permalink
Add CPU PJRT plugin for testing (#6253)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored Jan 10, 2024
1 parent 83b7571 commit 2f1334d
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 0 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ http_archive(
patches = [
"//openxla_patches:cache_urls.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:cpu_compile_options.diff",
"//openxla_patches:gpu_compile_options.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
Expand Down
17 changes: 17 additions & 0 deletions openxla_patches/cpu_compile_options.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Backport of https://github.com/openxla/xla/pull/8276
# Remove with next XLA pin update.
diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h
index a350ce7d8..f79b215b7 100644
--- a/xla/pjrt/cpu/cpu_client.h
+++ b/xla/pjrt/cpu/cpu_client.h
@@ -454,6 +454,10 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable {
return Unimplemented("Fingerprinting executable is not supported.");
}

+ StatusOr<CompileOptions> GetCompileOptions() const override {
+ return compile_options_;
+ }
+
private:
friend class TfrtCpuClient;

28 changes: 28 additions & 0 deletions plugins/cpu/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
load(
"@xla//xla:xla.bzl",
"xla_cc_binary",
)

cc_library(
name = "test_cpu_plugin",
srcs = ["test_cpu_plugin.cc"],
hdrs = ["test_cpu_plugin.h"],
visibility = ["//visibility:public"],
deps = [
"@xla//xla/pjrt/c:pjrt_c_api_cpu_internal",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
],
)

xla_cc_binary(
name = "pjrt_c_api_cpu_plugin.so",
linkopts = [
"-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)",
"-Wl,--no-undefined",
],
linkshared = True,
deps = [
":pjrt_c_api_cpu_version_script.lds",
":test_cpu_plugin",
],
)
44 changes: 44 additions & 0 deletions plugins/cpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# CPU PJRT Plugin (testing)

This directory contains an experimental implementation of the PJRT CPU client as
a plugin. This plugin is for testing only and is not officially supported. Use
`PJRT_DEVICE=CPU` with any PyTorch/XLA installation to use built-in CPU support.

The actual implementation of the PJRT C API lives in the main OpenXLA
repository (see `bazel build` command below).

## Building

```bash
# Build PJRT plugin
bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1
# Copy to package dir
cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/

# Build wheel
pip wheel plugins/cpu
# Or install directly
pip install plugins/cpu
```

## Usage

```python
import os

# Log device type
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5'

from torch_xla.experimental import plugins
import torch_xla_cpu_plugin
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Use dynamic plugin instead of built-in CPU support
plugins.use_dynamic_plugins()
plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin())
xr.set_device_type('CPU')

print(xm.xla_device())
```
10 changes: 10 additions & 0 deletions plugins/cpu/pjrt_c_api_cpu_version_script.lds
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Only symbols in the global section are available to other frameworks.
VERS_1.0 {
global:
extern "C" {
GetPjrtApi;
};

local:
*;
};
18 changes: 18 additions & 0 deletions plugins/cpu/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torch_xla_cpu_plugin"
version = "0.0.1"
authors = [
{name = "Will Cromar", email = "[email protected]"},
]
description = "CPU PJRT Plugin for testing only"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cpu_plugin = ["*.so"]

[project.entry-points."torch_xla.plugins"]
cpu = "torch_xla_cpu_plugin:CpuPlugin"
25 changes: 25 additions & 0 deletions plugins/cpu/test_cpu_plugin.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "plugins/cpu/test_cpu_plugin.h"

#include <iostream>

#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_cpu_internal.h"

// Use `test` as the platform name instead of `cpu` so torch_xla treats this
// as an unknown device.
PJRT_Error* test_platform_name(PJRT_Client_PlatformName_Args* args) {
static const std::string platform_name = "test";
args->platform_name = platform_name.c_str();
args->platform_name_size = platform_name.size();
return nullptr;
}

const PJRT_Api* GetPjrtApi() {
// HACK: The CPU client is created as a constexpr, so const-casting is
// undefined behavior. Make a non-const copy of the struct so we can override
// methods. Don't do this for a real plugin.
static PJRT_Api pjrt_api = *pjrt::cpu_plugin::GetCpuPjrtApi();
pjrt_api.PJRT_Client_PlatformName = test_platform_name;

return &pjrt_api;
}
16 changes: 16 additions & 0 deletions plugins/cpu/test_cpu_plugin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef XLA_PJRT_C_PJRT_C_API_CPU_H_
#define XLA_PJRT_C_PJRT_C_API_CPU_H_

#include "xla/pjrt/c/pjrt_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

const PJRT_Api* GetPjrtApi();

#ifdef __cplusplus
}
#endif

#endif // XLA_PJRT_C_PJRT_C_API_CPU_H_
10 changes: 10 additions & 0 deletions plugins/cpu/torch_xla_cpu_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os
from torch_xla.experimental import plugins
from torch_xla._internal import tpu

class CpuPlugin(plugins.DevicePlugin):
def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so')

def physical_chip_count(self) -> int:
return 1

0 comments on commit 2f1334d

Please sign in to comment.