diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py
index 7291608e506..5923221b4da 100644
--- a/test/stablehlo/test_stablehlo_custom_call.py
+++ b/test/stablehlo/test_stablehlo_custom_call.py
@@ -6,7 +6,9 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.experimental.stablehlo_custom_call
 from torch.library import Library, impl, impl_abstract
-from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call
+from torch_xla.experimental.stablehlo_custom_call import (stablehlo_custom_call,
+                                                          place_to_host,
+                                                          place_to_device)
 from torch_xla.stablehlo import (StableHLOExportOptions,
                                  exported_program_to_stablehlo)
 
@@ -114,6 +116,24 @@ def forward(self, x):
     # TODO: api version lost during conversion, or not shown in txt format.
     # self.assertTrue("api_version = 1" in shlo_text)
 
+  def test_place_to_host_device(self):
+    dev = xm.xla_device()
+    a = torch.ones(10, device=dev)
+    b = place_to_host(a)
+    shlo_text = xm.get_stablehlo([b])
+    self.assertTrue("has_side_effect = true" in shlo_text)
+    self.assertTrue(
+        "mhlo.frontend_attributes = {_xla_buffer_placement = \"pinned_host\"}}"
+        in shlo_text)
+
+    a = torch.ones(10, device=dev)
+    b = place_to_device(a)
+    shlo_text = xm.get_stablehlo([b])
+    self.assertTrue("has_side_effect = true" in shlo_text)
+    self.assertTrue(
+        "mhlo.frontend_attributes = {_xla_buffer_placement = \"device\"}}" in
+        shlo_text)
+
 
 
 if __name__ == "__main__":
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 77fb9f2f8ab..2ac991bc814 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2555,8 +2555,9 @@ void InitXlaModuleBindings(py::module m) {
         [](const std::vector<at::Tensor>& inputs, const std::string& target,
           const std::vector<std::vector<int64_t>>& output_shapes,
           const std::vector<py::object>& output_dtypes, bool has_side_effect,
-          const std::string& backend_config,
-          const int api_version) -> std::vector<at::Tensor> {
+          const std::string& backend_config, const int api_version,
+          const std::unordered_map<std::string, std::string>&
+              frontend_attributes) -> std::vector<at::Tensor> {
          std::vector<at::ScalarType> dtypes;
          dtypes.reserve(output_dtypes.size());
          for (auto& dtype : output_dtypes) {
@@ -2566,7 +2567,8 @@ void InitXlaModuleBindings(py::module m) {
 
          auto xtensors = tensor_methods::custom_call(
              bridge::GetXlaTensors(inputs), target, output_shapes, dtypes,
-              has_side_effect, backend_config, api_version);
+              has_side_effect, backend_config, api_version,
+              frontend_attributes);
          return bridge::AtenFromXlaTensors(std::move(xtensors));
        });
  m.def("_xla_tpu_custom_call",
diff --git a/torch_xla/csrc/ops/custom_call.cpp b/torch_xla/csrc/ops/custom_call.cpp
index 9d60ce89378..3d28df12f82 100644
--- a/torch_xla/csrc/ops/custom_call.cpp
+++ b/torch_xla/csrc/ops/custom_call.cpp
@@ -8,17 +8,27 @@
 
 namespace torch_xla {
 
-CustomCall::CustomCall(torch::lazy::OpList inputs,
-                       const std::string& call_target, xla::Shape output_shape,
-                       bool has_side_effect, const std::string& backend_config,
-                       const int api_version)
+CustomCall::CustomCall(
+    torch::lazy::OpList inputs, const std::string& call_target,
+    xla::Shape output_shape, bool has_side_effect,
+    const std::string& backend_config, const int api_version,
+    const std::unordered_map<std::string, std::string>& frontend_attributes)
     : XlaNode(xla_custom_call, inputs, std::move(output_shape),
               /*num_outputs=*/output_shape.tuple_shapes_size(),
              torch::lazy::MHash(call_target)),
      call_target_(call_target),
      has_side_effect_(has_side_effect),
      backend_config_(backend_config),
-      api_version_(api_version) {}
+      api_version_(api_version),
+      frontend_attributes_(frontend_attributes) {}
+
+CustomCall::CustomCall(torch::lazy::OpList inputs,
+                       const std::string& call_target, xla::Shape output_shape,
+                       bool has_side_effect, const std::string& backend_config,
+                       const int api_version)
+    : CustomCall(inputs, call_target, output_shape, has_side_effect,
+                 backend_config, api_version,
+                 std::unordered_map<std::string, std::string>()) {}
 
 torch::lazy::NodePtr CustomCall::Clone(torch::lazy::OpList operands) const {
  return torch_xla::MakeNode<CustomCall>(operands, call_target_,
@@ -38,6 +48,12 @@ XlaOpVector CustomCall::Lower(LoweringContext* loctx) const {
    output_shape = output_shape.tuple_shapes(0);
  }
  XLA_CHECK(api_version_ >= 0 && api_version_ < 5);
+
+  xla::FrontendAttributes feattr;
+  feattr.mutable_map()->insert(frontend_attributes_.begin(),
+                               frontend_attributes_.end());
+  xla::XlaScopedFrontendAttributesAssignment feattr_assign(inputs[0].builder(),
+                                                           feattr);
  xla::XlaOp output = xla::CustomCall(
      inputs[0].builder(), call_target_, inputs, output_shape,
      /*opaque=*/backend_config_,
diff --git a/torch_xla/csrc/ops/custom_call.h b/torch_xla/csrc/ops/custom_call.h
index 69bb613d4b6..172e5f127bd 100644
--- a/torch_xla/csrc/ops/custom_call.h
+++ b/torch_xla/csrc/ops/custom_call.h
@@ -1,6 +1,8 @@
 #ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_
 #define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_
 
+#include <unordered_map>
+
 #include "torch_xla/csrc/ir.h"
 
 namespace torch_xla {
@@ -10,6 +12,11 @@ class CustomCall : public XlaNode {
  CustomCall(torch::lazy::OpList inputs, const std::string& call_target,
             xla::Shape output_shape, bool has_side_effect,
             const std::string& backend_config, const int api_version);
+  CustomCall(
+      torch::lazy::OpList inputs, const std::string& call_target,
+      xla::Shape output_shape, bool has_side_effect,
+      const std::string& backend_config, const int api_version,
+      const std::unordered_map<std::string, std::string>& frontend_attributes);
 
  std::string ToString() const override;
 
@@ -22,6 +29,7 @@
  bool has_side_effect_;
  std::string backend_config_;
  int api_version_;
+  std::unordered_map<std::string, std::string> frontend_attributes_;
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index bb71d09173f..9c34f84c44a 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -564,7 +564,8 @@ std::vector<XLATensorPtr> custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& target,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
-    const std::string& backend_config, const int api_version) {
+    const std::string& backend_config, const int api_version,
+    const std::unordered_map<std::string, std::string>& frontend_attributes) {
  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
 
  std::vector<torch::lazy::Value> values;
@@ -584,7 +585,7 @@ std::vector<XLATensorPtr> custom_call(
 
  auto node = torch_xla::MakeNode<CustomCall>(
      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version);
+      has_side_effect, backend_config, api_version, frontend_attributes);
 
  std::vector<XLATensorPtr> outputs;
  outputs.reserve(output_shapes.size());
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 8bd401b168e..c00f4dd52e6 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -92,7 +92,8 @@ std::vector<XLATensorPtr> custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& target,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
-    const std::string& backend_config, const int api_version);
+    const std::string& backend_config, const int api_version,
+    const std::unordered_map<std::string, std::string>& frontend_attributes);
 
 void custom_sharding_(
    const XLATensorPtr& input,
diff --git a/torch_xla/experimental/stablehlo_custom_call.py b/torch_xla/experimental/stablehlo_custom_call.py
index e729d0b7791..d27fd1544a5 100644
--- a/torch_xla/experimental/stablehlo_custom_call.py
+++ b/torch_xla/experimental/stablehlo_custom_call.py
@@ -10,10 +10,13 @@ def stablehlo_custom_call(args,
                           output_dtypes,
                           has_side_effect=False,
                           backend_config="",
-                          api_version=0):
+                          api_version=0,
+                          frontend_attributes=None):
+  frontend_attributes = frontend_attributes or {}
   res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
                                          output_dtypes, has_side_effect,
-                                         backend_config, api_version)
+                                         backend_config, api_version,
+                                         frontend_attributes)
   if len(output_shapes) == 1:
     return res[0]
   return res
@@ -29,3 +32,19 @@ def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node):
   assert None not in output_shape_dtype
   output_shape, output_dtype = zip(*output_shape_dtype)
   return output_shape, output_dtype
+
+
+def place_to_host(a: torch.Tensor):
+  return stablehlo_custom_call(
+      [a],
+      "annotate_device_placement", [a.shape], [a.dtype],
+      has_side_effect=True,
+      frontend_attributes={"_xla_buffer_placement": "pinned_host"})
+
+
+def place_to_device(a: torch.Tensor):
+  return stablehlo_custom_call(
+      [a],
+      "annotate_device_placement", [a.shape], [a.dtype],
+      has_side_effect=True,
+      frontend_attributes={"_xla_buffer_placement": "device"})
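
A minimal usage sketch of the two helpers introduced above, modeled on the new test; whether the buffer is actually migrated to pinned host memory depends on the XLA runtime honoring the _xla_buffer_placement frontend attribute on the emitted custom call:

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.stablehlo_custom_call import (place_to_host,
                                                               place_to_device)

    dev = xm.xla_device()
    a = torch.ones(10, device=dev)
    b = place_to_host(a)    # annotate_device_placement with _xla_buffer_placement = "pinned_host"
    c = place_to_device(b)  # annotate the buffer back onto device memory
    # The lowered StableHLO carries the annotations as mhlo.frontend_attributes.
    print(xm.get_stablehlo([c]))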