From fc408f2335c62d2e3678774ac2dd903836850103 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 1 Nov 2024 22:57:33 +0000 Subject: [PATCH] Add functions to emit custom call to place a buffer to host and device. This is used for host-offloading. Example code of what JAX emits: ```python def policy(prim, *avals, **params) -> Offloadable: return Offloadable(src='device', dst='pinned_host') @functools.partial(jax.remat, policy=policy) def f(x): x = jnp.sin(x) x = jnp.sin(x) return jnp.sum(x) ``` becomes: ```mlir module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = stablehlo.sine %arg0 : tensor<16xf32> %1 = stablehlo.cosine %arg0 : tensor<16xf32> %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32> %3 = stablehlo.cosine %0 : tensor<16xf32> %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32> %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32> %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32> %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32> %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32> %9 = stablehlo.multiply %8, %7 : tensor<16xf32> %10 = stablehlo.multiply %9, %6 : tensor<16xf32> return %10 : tensor<16xf32> } } ``` --- 
test/stablehlo/test_stablehlo_custom_call.py | 19 +++++++++++++- torch_xla/csrc/init_python_bindings.cpp | 8 +++--- torch_xla/csrc/ops/custom_call.cpp | 26 +++++++++++++++---- torch_xla/csrc/ops/custom_call.h | 8 ++++++ torch_xla/csrc/tensor_methods.cpp | 5 ++-- torch_xla/csrc/tensor_methods.h | 3 ++- .../experimental/stablehlo_custom_call.py | 18 +++++++++++-- 7 files changed, 73 insertions(+), 14 deletions(-) diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py index 7291608e506..826d138d856 100644 --- a/test/stablehlo/test_stablehlo_custom_call.py +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -6,7 +6,8 @@ import torch_xla.core.xla_model as xm import torch_xla.experimental.stablehlo_custom_call from torch.library import Library, impl, impl_abstract -from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call +from torch_xla.experimental.stablehlo_custom_call import ( + stablehlo_custom_call, place_to_host, place_to_device) from torch_xla.stablehlo import (StableHLOExportOptions, exported_program_to_stablehlo) @@ -115,6 +116,22 @@ def forward(self, x): # self.assertTrue("api_version = 1" in shlo_text) + def test_place_to_host_device(self): + dev = xm.xla_device() + a = torch.ones(10, device=dev) + b = place_to_host(a) + shlo_text = xm.get_stablehlo([b]) + self.assertTrue("has_side_effect = true" in shlo_text) + self.assertTrue("mhlo.frontend_attributes = {_xla_buffer_placement = \"pinned_host\"}}" in shlo_text) + + a = torch.ones(10, device=dev) + b = place_to_device(a) + shlo_text = xm.get_stablehlo([b]) + self.assertTrue("has_side_effect = true" in shlo_text) + self.assertTrue("mhlo.frontend_attributes = {_xla_buffer_placement = \"device\"}}" in shlo_text) + + + if __name__ == "__main__": test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 77fb9f2f8ab..2ac991bc814 100644 --- 
a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2555,8 +2555,9 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& inputs, const std::string& target, const std::vector>& output_shapes, const std::vector& output_dtypes, bool has_side_effect, - const std::string& backend_config, - const int api_version) -> std::vector { + const std::string& backend_config, const int api_version, + const std::unordered_map& + frontend_attributes) -> std::vector { std::vector dtypes; dtypes.reserve(output_dtypes.size()); for (auto& dtype : output_dtypes) { @@ -2566,7 +2567,8 @@ void InitXlaModuleBindings(py::module m) { auto xtensors = tensor_methods::custom_call( bridge::GetXlaTensors(inputs), target, output_shapes, dtypes, - has_side_effect, backend_config, api_version); + has_side_effect, backend_config, api_version, + frontend_attributes); return bridge::AtenFromXlaTensors(std::move(xtensors)); }); m.def("_xla_tpu_custom_call", diff --git a/torch_xla/csrc/ops/custom_call.cpp b/torch_xla/csrc/ops/custom_call.cpp index 9d60ce89378..3d28df12f82 100644 --- a/torch_xla/csrc/ops/custom_call.cpp +++ b/torch_xla/csrc/ops/custom_call.cpp @@ -8,17 +8,27 @@ namespace torch_xla { -CustomCall::CustomCall(torch::lazy::OpList inputs, - const std::string& call_target, xla::Shape output_shape, - bool has_side_effect, const std::string& backend_config, - const int api_version) +CustomCall::CustomCall( + torch::lazy::OpList inputs, const std::string& call_target, + xla::Shape output_shape, bool has_side_effect, + const std::string& backend_config, const int api_version, + const std::unordered_map& frontend_attributes) : XlaNode(xla_custom_call, inputs, std::move(output_shape), /*num_outputs=*/output_shape.tuple_shapes_size(), torch::lazy::MHash(call_target)), call_target_(call_target), has_side_effect_(has_side_effect), backend_config_(backend_config), - api_version_(api_version) {} + api_version_(api_version), + 
frontend_attributes_(frontend_attributes) {} + +CustomCall::CustomCall(torch::lazy::OpList inputs, + const std::string& call_target, xla::Shape output_shape, + bool has_side_effect, const std::string& backend_config, + const int api_version) + : CustomCall(inputs, call_target, output_shape, has_side_effect, + backend_config, api_version, + std::unordered_map()) {} torch::lazy::NodePtr CustomCall::Clone(torch::lazy::OpList operands) const { return torch_xla::MakeNode(operands, call_target_, @@ -38,6 +48,12 @@ XlaOpVector CustomCall::Lower(LoweringContext* loctx) const { output_shape = output_shape.tuple_shapes(0); } XLA_CHECK(api_version_ >= 0 && api_version_ < 5); + + xla::FrontendAttributes feattr; + feattr.mutable_map()->insert(frontend_attributes_.begin(), + frontend_attributes_.end()); + xla::XlaScopedFrontendAttributesAssignment feattr_assign(inputs[0].builder(), + feattr); xla::XlaOp output = xla::CustomCall( inputs[0].builder(), call_target_, inputs, output_shape, /*opaque=*/backend_config_, diff --git a/torch_xla/csrc/ops/custom_call.h b/torch_xla/csrc/ops/custom_call.h index 69bb613d4b6..172e5f127bd 100644 --- a/torch_xla/csrc/ops/custom_call.h +++ b/torch_xla/csrc/ops/custom_call.h @@ -1,6 +1,8 @@ #ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_ #define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_ +#include + #include "torch_xla/csrc/ir.h" namespace torch_xla { @@ -10,6 +12,11 @@ class CustomCall : public XlaNode { CustomCall(torch::lazy::OpList inputs, const std::string& call_target, xla::Shape output_shape, bool has_side_effect, const std::string& backend_config, const int api_version); + CustomCall( + torch::lazy::OpList inputs, const std::string& call_target, + xla::Shape output_shape, bool has_side_effect, + const std::string& backend_config, const int api_version, + const std::unordered_map& frontend_attributes); std::string ToString() const override; @@ -22,6 +29,7 @@ class CustomCall : public XlaNode { bool has_side_effect_; std::string backend_config_; int 
api_version_; + std::unordered_map frontend_attributes_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bb71d09173f..9c34f84c44a 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -564,7 +564,8 @@ std::vector custom_call( const std::vector& inputs, const std::string& target, const std::vector>& output_shapes, const std::vector& output_dtypes, bool has_side_effect, - const std::string& backend_config, const int api_version) { + const std::string& backend_config, const int api_version, + const std::unordered_map& frontend_attributes) { XLA_CHECK(inputs.size() > 0) << "inputs are empty"; std::vector values; @@ -584,7 +585,7 @@ std::vector custom_call( auto node = torch_xla::MakeNode( values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), - has_side_effect, backend_config, api_version); + has_side_effect, backend_config, api_version, frontend_attributes); std::vector outputs; outputs.reserve(output_shapes.size()); diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 8bd401b168e..c00f4dd52e6 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -92,7 +92,8 @@ std::vector custom_call( const std::vector& inputs, const std::string& target, const std::vector>& output_shapes, const std::vector& output_dtypes, bool has_side_effect, - const std::string& backend_config, const int api_version); + const std::string& backend_config, const int api_version, + const std::unordered_map& frontend_attributes); void custom_sharding_( const XLATensorPtr& input, diff --git a/torch_xla/experimental/stablehlo_custom_call.py b/torch_xla/experimental/stablehlo_custom_call.py index e729d0b7791..b39993221b0 100644 --- a/torch_xla/experimental/stablehlo_custom_call.py +++ b/torch_xla/experimental/stablehlo_custom_call.py @@ -10,10 +10,13 @@ def stablehlo_custom_call(args, output_dtypes, has_side_effect=False, 
backend_config="", - api_version=0): + api_version=0, + frontend_attributes=None): + frontend_attributes = frontend_attributes or {} res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes, output_dtypes, has_side_effect, - backend_config, api_version) + backend_config, api_version, + frontend_attributes) if len(output_shapes) == 1: return res[0] return res @@ -29,3 +32,14 @@ def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node): assert None not in output_shape_dtype output_shape, output_dtype = zip(*output_shape_dtype) return output_shape, output_dtype + + +def place_to_host(a: torch.Tensor): + return stablehlo_custom_call([a], "annotate_device_placement", + [a.shape],[a.dtype], has_side_effect=True, + frontend_attributes={"_xla_buffer_placement": "pinned_host"}) + +def place_to_device(a: torch.Tensor): + return stablehlo_custom_call([a], "annotate_device_placement", + [a.shape],[a.dtype], has_side_effect=True, + frontend_attributes={"_xla_buffer_placement": "device"})