diff --git a/WORKSPACE b/WORKSPACE index 46f2e24c4a7..4b008b47655 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -48,6 +48,7 @@ http_archive( patches = [ "//openxla_patches:cache_urls.diff", "//openxla_patches:constexpr_return.diff", + "//openxla_patches:cpu_compile_options.diff", "//openxla_patches:gpu_compile_options.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", diff --git a/openxla_patches/cpu_compile_options.diff b/openxla_patches/cpu_compile_options.diff new file mode 100644 index 00000000000..2213528451a --- /dev/null +++ b/openxla_patches/cpu_compile_options.diff @@ -0,0 +1,17 @@ +# Backport of https://github.com/openxla/xla/pull/8276 +# Remove with next XLA pin update. +diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h +index a350ce7d8..f79b215b7 100644 +--- a/xla/pjrt/cpu/cpu_client.h ++++ b/xla/pjrt/cpu/cpu_client.h +@@ -454,6 +454,10 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { + return Unimplemented("Fingerprinting executable is not supported."); + } + ++ StatusOr<CompileOptions> GetCompileOptions() const override { ++ return compile_options_; ++ } ++ + private: + friend class TfrtCpuClient; + diff --git a/plugins/cpu/BUILD b/plugins/cpu/BUILD new file mode 100644 index 00000000000..a8818d480cd --- /dev/null +++ b/plugins/cpu/BUILD @@ -0,0 +1,28 @@ +load( + "@xla//xla:xla.bzl", + "xla_cc_binary", +) + +cc_library( + name = "test_cpu_plugin", + srcs = ["test_cpu_plugin.cc"], + hdrs = ["test_cpu_plugin.h"], + visibility = ["//visibility:public"], + deps = [ + "@xla//xla/pjrt/c:pjrt_c_api_cpu_internal", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + ], +) + +xla_cc_binary( + name = "pjrt_c_api_cpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":pjrt_c_api_cpu_version_script.lds", + ":test_cpu_plugin", + ], +) diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md new file mode 100644 
index 00000000000..b3fedc24d8d --- /dev/null +++ b/plugins/cpu/README.md @@ -0,0 +1,44 @@ +# CPU PJRT Plugin (testing) + +This directory contains an experimental implementation of the PJRT CPU client as +a plugin. This plugin is for testing only and is not officially supported. Use +`PJRT_DEVICE=CPU` with any PyTorch/XLA installation to use built-in CPU support. + +The actual implementation of the PJRT C API lives in the main OpenXLA +repository (see `bazel build` command below). + +## Building + +```bash +# Build PJRT plugin +bazel build //plugins/cpu:pjrt_c_api_cpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 +# Copy to package dir +cp bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so plugins/cpu/torch_xla_cpu_plugin/ + +# Build wheel +pip wheel plugins/cpu +# Or install directly +pip install plugins/cpu +``` + +## Usage + +```python +import os + +# Log device type +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' +os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' + +from torch_xla.experimental import plugins +import torch_xla_cpu_plugin +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +# Use dynamic plugin instead of built-in CPU support +plugins.use_dynamic_plugins() +plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) +xr.set_device_type('CPU') + +print(xm.xla_device()) +``` diff --git a/plugins/cpu/pjrt_c_api_cpu_version_script.lds b/plugins/cpu/pjrt_c_api_cpu_version_script.lds new file mode 100644 index 00000000000..f371ee430c9 --- /dev/null +++ b/plugins/cpu/pjrt_c_api_cpu_version_script.lds @@ -0,0 +1,10 @@ +# Only symbols in the global section are available to other frameworks. 
+VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + }; + + local: + *; +}; diff --git a/plugins/cpu/pyproject.toml b/plugins/cpu/pyproject.toml new file mode 100644 index 00000000000..9359a0dc7b5 --- /dev/null +++ b/plugins/cpu/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "torch_xla_cpu_plugin" +version = "0.0.1" +authors = [ + {name = "Will Cromar", email = "wcromar@google.com"}, +] +description = "CPU PJRT Plugin for testing only" +requires-python = ">=3.8" + +[tool.setuptools.package-data] +torch_xla_cpu_plugin = ["*.so"] + +[project.entry-points."torch_xla.plugins"] +cpu = "torch_xla_cpu_plugin:CpuPlugin" diff --git a/plugins/cpu/test_cpu_plugin.cc b/plugins/cpu/test_cpu_plugin.cc new file mode 100644 index 00000000000..e4996a2007e --- /dev/null +++ b/plugins/cpu/test_cpu_plugin.cc @@ -0,0 +1,25 @@ +#include "plugins/cpu/test_cpu_plugin.h" + +#include <string> + +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" + +// Use `test` as the platform name instead of `cpu` so torch_xla treats this +// as an unknown device. +PJRT_Error* test_platform_name(PJRT_Client_PlatformName_Args* args) { + static const std::string platform_name = "test"; + args->platform_name = platform_name.c_str(); + args->platform_name_size = platform_name.size(); + return nullptr; +} + +const PJRT_Api* GetPjrtApi() { + // HACK: The CPU client is created as a constexpr, so const-casting is + // undefined behavior. Make a non-const copy of the struct so we can override + // methods. Don't do this for a real plugin. 
+ static PJRT_Api pjrt_api = *pjrt::cpu_plugin::GetCpuPjrtApi(); + pjrt_api.PJRT_Client_PlatformName = test_platform_name; + + return &pjrt_api; +} diff --git a/plugins/cpu/test_cpu_plugin.h b/plugins/cpu/test_cpu_plugin.h new file mode 100644 index 00000000000..f8f8990d6ff --- /dev/null +++ b/plugins/cpu/test_cpu_plugin.h @@ -0,0 +1,16 @@ +#ifndef XLA_PJRT_C_PJRT_C_API_CPU_H_ +#define XLA_PJRT_C_PJRT_C_API_CPU_H_ + +#include "xla/pjrt/c/pjrt_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +const PJRT_Api* GetPjrtApi(); + +#ifdef __cplusplus +} +#endif + +#endif // XLA_PJRT_C_PJRT_C_API_CPU_H_ diff --git a/plugins/cpu/torch_xla_cpu_plugin/__init__.py b/plugins/cpu/torch_xla_cpu_plugin/__init__.py new file mode 100644 index 00000000000..da7a3234267 --- /dev/null +++ b/plugins/cpu/torch_xla_cpu_plugin/__init__.py @@ -0,0 +1,10 @@ +import os +from torch_xla.experimental import plugins +from torch_xla._internal import tpu + +class CpuPlugin(plugins.DevicePlugin): + def library_path(self) -> str: + return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so') + + def physical_chip_count(self) -> int: + return 1