Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to emit custom call to place a buffer to host and device. #8350

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion test/stablehlo/test_stablehlo_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import torch_xla.core.xla_model as xm
import torch_xla.experimental.stablehlo_custom_call
from torch.library import Library, impl, impl_abstract
from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call
from torch_xla.experimental.stablehlo_custom_call import (stablehlo_custom_call,
place_to_host,
place_to_device)
from torch_xla.stablehlo import (StableHLOExportOptions,
exported_program_to_stablehlo)

Expand Down Expand Up @@ -114,6 +116,24 @@ def forward(self, x):
# TODO: api version lost during conversion, or not shown in txt format.
# self.assertTrue("api_version = 1" in shlo_text)

def test_place_to_host_device(self):
dev = xm.xla_device()
a = torch.ones(10, device=dev)
b = place_to_host(a)
shlo_text = xm.get_stablehlo([b])
self.assertTrue("has_side_effect = true" in shlo_text)
self.assertTrue(
"mhlo.frontend_attributes = {_xla_buffer_placement = \"pinned_host\"}}"
in shlo_text)

a = torch.ones(10, device=dev)
b = place_to_device(a)
shlo_text = xm.get_stablehlo([b])
self.assertTrue("has_side_effect = true" in shlo_text)
self.assertTrue(
"mhlo.frontend_attributes = {_xla_buffer_placement = \"device\"}}" in
shlo_text)


if __name__ == "__main__":

Expand Down
8 changes: 5 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2555,8 +2555,9 @@ void InitXlaModuleBindings(py::module m) {
[](const std::vector<at::Tensor>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes, bool has_side_effect,
const std::string& backend_config,
const int api_version) -> std::vector<at::Tensor> {
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>&
frontend_attributes) -> std::vector<at::Tensor> {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
Expand All @@ -2566,7 +2567,8 @@ void InitXlaModuleBindings(py::module m) {

auto xtensors = tensor_methods::custom_call(
bridge::GetXlaTensors(inputs), target, output_shapes, dtypes,
has_side_effect, backend_config, api_version);
has_side_effect, backend_config, api_version,
frontend_attributes);
return bridge::AtenFromXlaTensors(std::move(xtensors));
});
m.def("_xla_tpu_custom_call",
Expand Down
26 changes: 21 additions & 5 deletions torch_xla/csrc/ops/custom_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,27 @@

namespace torch_xla {

CustomCall::CustomCall(torch::lazy::OpList inputs,
const std::string& call_target, xla::Shape output_shape,
bool has_side_effect, const std::string& backend_config,
const int api_version)
CustomCall::CustomCall(
torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes)
: XlaNode(xla_custom_call, inputs, std::move(output_shape),
/*num_outputs=*/output_shape.tuple_shapes_size(),
torch::lazy::MHash(call_target)),
call_target_(call_target),
has_side_effect_(has_side_effect),
backend_config_(backend_config),
api_version_(api_version) {}
api_version_(api_version),
frontend_attributes_(frontend_attributes) {}

CustomCall::CustomCall(torch::lazy::OpList inputs,
const std::string& call_target, xla::Shape output_shape,
bool has_side_effect, const std::string& backend_config,
const int api_version)
: CustomCall(inputs, call_target, output_shape, has_side_effect,
backend_config, api_version,
std::unordered_map<std::string, std::string>()) {}

torch::lazy::NodePtr CustomCall::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<CustomCall>(operands, call_target_,
Expand All @@ -38,6 +48,12 @@ XlaOpVector CustomCall::Lower(LoweringContext* loctx) const {
output_shape = output_shape.tuple_shapes(0);
}
XLA_CHECK(api_version_ >= 0 && api_version_ < 5);

xla::FrontendAttributes feattr;
feattr.mutable_map()->insert(frontend_attributes_.begin(),
frontend_attributes_.end());
xla::XlaScopedFrontendAttributesAssignment feattr_assign(inputs[0].builder(),
feattr);
xla::XlaOp output = xla::CustomCall(
inputs[0].builder(), call_target_, inputs, output_shape,
/*opaque=*/backend_config_,
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/ops/custom_call.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_

#include <unordered_map>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
Expand All @@ -10,6 +12,11 @@ class CustomCall : public XlaNode {
CustomCall(torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version);
CustomCall(
torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes);

std::string ToString() const override;

Expand All @@ -22,6 +29,7 @@ class CustomCall : public XlaNode {
bool has_side_effect_;
std::string backend_config_;
int api_version_;
std::unordered_map<std::string, std::string> frontend_attributes_;
};

} // namespace torch_xla
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,8 @@ std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
const std::string& backend_config, const int api_version) {
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes) {
XLA_CHECK(inputs.size() > 0) << "inputs are empty";

std::vector<torch::lazy::Value> values;
Expand All @@ -584,7 +585,7 @@ std::vector<XLATensorPtr> custom_call(

auto node = torch_xla::MakeNode<CustomCall>(
values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
has_side_effect, backend_config, api_version);
has_side_effect, backend_config, api_version, frontend_attributes);

std::vector<XLATensorPtr> outputs;
outputs.reserve(output_shapes.size());
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
const std::string& backend_config, const int api_version);
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes);

void custom_sharding_(
const XLATensorPtr& input,
Expand Down
23 changes: 21 additions & 2 deletions torch_xla/experimental/stablehlo_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ def stablehlo_custom_call(args,
output_dtypes,
has_side_effect=False,
backend_config="",
api_version=0):
api_version=0,
frontend_attributes=None):
frontend_attributes = frontend_attributes or {}
res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
output_dtypes, has_side_effect,
backend_config, api_version)
backend_config, api_version,
frontend_attributes)
if len(output_shapes) == 1:
return res[0]
return res
Expand All @@ -29,3 +32,19 @@ def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node):
assert None not in output_shape_dtype
output_shape, output_dtype = zip(*output_shape_dtype)
return output_shape, output_dtype


def place_to_host(a: torch.Tensor):
return stablehlo_custom_call(
[a],
"annotate_device_placement", [a.shape], [a.dtype],
has_side_effect=True,
frontend_attributes={"_xla_buffer_placement": "pinned_host"})


def place_to_device(a: torch.Tensor):
return stablehlo_custom_call(
[a],
"annotate_device_placement", [a.shape], [a.dtype],
has_side_effect=True,
frontend_attributes={"_xla_buffer_placement": "device"})
Loading