-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CPU PJRT plugin for testing (#6253)
- Loading branch information
1 parent
83b7571
commit 2f1334d
Showing
9 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Backport of https://github.com/openxla/xla/pull/8276 | ||
# Remove with next XLA pin update. | ||
diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h | ||
index a350ce7d8..f79b215b7 100644 | ||
--- a/xla/pjrt/cpu/cpu_client.h | ||
+++ b/xla/pjrt/cpu/cpu_client.h | ||
@@ -454,6 +454,10 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { | ||
return Unimplemented("Fingerprinting executable is not supported."); | ||
} | ||
|
||
+ StatusOr<CompileOptions> GetCompileOptions() const override { | ||
+ return compile_options_; | ||
+ } | ||
+ | ||
private: | ||
friend class TfrtCpuClient; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
load( | ||
"@xla//xla:xla.bzl", | ||
"xla_cc_binary", | ||
) | ||
|
||
cc_library( | ||
name = "test_cpu_plugin", | ||
srcs = ["test_cpu_plugin.cc"], | ||
hdrs = ["test_cpu_plugin.h"], | ||
visibility = ["//visibility:public"], | ||
deps = [ | ||
"@xla//xla/pjrt/c:pjrt_c_api_cpu_internal", | ||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs", | ||
], | ||
) | ||
|
||
xla_cc_binary( | ||
name = "pjrt_c_api_cpu_plugin.so", | ||
linkopts = [ | ||
"-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)", | ||
"-Wl,--no-undefined", | ||
], | ||
linkshared = True, | ||
deps = [ | ||
":pjrt_c_api_cpu_version_script.lds", | ||
":test_cpu_plugin", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# CPU PJRT Plugin (testing) | ||
|
||
This directory contains an experimental implementation of the PJRT CPU client as | ||
a plugin. This plugin is for testing only and is not officially supported. Use | ||
`PJRT_DEVICE=CPU` with any PyTorch/XLA installation to use built-in CPU support. | ||
|
||
The actual implementation of the PJRT C API lives in the main OpenXLA | ||
repository (see `bazel build` command below). | ||
|
||
## Building | ||
|
||
```bash | ||
# Build PJRT plugin | ||
bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 | ||
# Copy to package dir | ||
cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/ | ||
|
||
# Build wheel | ||
pip wheel plugins/cpu | ||
# Or install directly | ||
pip install plugins/cpu | ||
``` | ||
|
||
## Usage | ||
|
||
```python | ||
import os | ||
|
||
# Log device type | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' | ||
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' | ||
|
||
from torch_xla.experimental import plugins | ||
import torch_xla_cpu_plugin | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.runtime as xr | ||
|
||
# Use dynamic plugin instead of built-in CPU support | ||
plugins.use_dynamic_plugins() | ||
plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) | ||
xr.set_device_type('CPU') | ||
|
||
print(xm.xla_device()) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Only symbols in the global section are available to other frameworks. | ||
VERS_1.0 { | ||
global: | ||
extern "C" { | ||
GetPjrtApi; | ||
}; | ||
|
||
local: | ||
*; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["setuptools"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "torch_xla_cpu_plugin" | ||
version = "0.0.1" | ||
authors = [ | ||
{name = "Will Cromar", email = "[email protected]"}, | ||
] | ||
description = "CPU PJRT Plugin for testing only" | ||
requires-python = ">=3.8" | ||
|
||
[tool.setuptools.package-data] | ||
torch_xla_cpu_plugin = ["*.so"] | ||
|
||
[project.entry-points."torch_xla.plugins"] | ||
cpu = "torch_xla_cpu_plugin:CpuPlugin" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "plugins/cpu/test_cpu_plugin.h" | ||
|
||
#include <iostream> | ||
|
||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
#include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" | ||
|
||
// Use `test` as the platform name instead of `cpu` so torch_xla treats this | ||
// as an unknown device. | ||
PJRT_Error* test_platform_name(PJRT_Client_PlatformName_Args* args) { | ||
static const std::string platform_name = "test"; | ||
args->platform_name = platform_name.c_str(); | ||
args->platform_name_size = platform_name.size(); | ||
return nullptr; | ||
} | ||
|
||
const PJRT_Api* GetPjrtApi() { | ||
// HACK: The CPU client is created as a constexpr, so const-casting is | ||
// undefined behavior. Make a non-const copy of the struct so we can override | ||
// methods. Don't do this for a real plugin. | ||
static PJRT_Api pjrt_api = *pjrt::cpu_plugin::GetCpuPjrtApi(); | ||
pjrt_api.PJRT_Client_PlatformName = test_platform_name; | ||
|
||
return &pjrt_api; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef XLA_PJRT_C_PJRT_C_API_CPU_H_ | ||
#define XLA_PJRT_C_PJRT_C_API_CPU_H_ | ||
|
||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
const PJRT_Api* GetPjrtApi(); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
#endif // XLA_PJRT_C_PJRT_C_API_CPU_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
from torch_xla.experimental import plugins | ||
from torch_xla._internal import tpu | ||
|
||
class CpuPlugin(plugins.DevicePlugin): | ||
def library_path(self) -> str: | ||
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so') | ||
|
||
def physical_chip_count(self) -> int: | ||
return 1 |