Skip to content

Commit

Permalink
Reformat with yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Nov 1, 2024
1 parent fc408f2 commit 374d4ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
15 changes: 9 additions & 6 deletions test/stablehlo/test_stablehlo_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import torch_xla.core.xla_model as xm
import torch_xla.experimental.stablehlo_custom_call
from torch.library import Library, impl, impl_abstract
from torch_xla.experimental.stablehlo_custom_call import (
stablehlo_custom_call, place_to_host, place_to_device)
from torch_xla.experimental.stablehlo_custom_call import (stablehlo_custom_call,
place_to_host,
place_to_device)
from torch_xla.stablehlo import (StableHLOExportOptions,
exported_program_to_stablehlo)

Expand Down Expand Up @@ -115,21 +116,23 @@ def forward(self, x):
# TODO: api version lost during conversion, or not shown in txt format.
# self.assertTrue("api_version = 1" in shlo_text)


def test_place_to_host_device(self):
  """Check that place_to_host/place_to_device emit the expected StableHLO.

  Each placement helper should lower to a custom call that (a) is marked
  `has_side_effect = true` so the compiler cannot DCE it, and (b) carries
  the `_xla_buffer_placement` frontend attribute naming the target memory
  space ("pinned_host" or "device").
  """
  dev = xm.xla_device()

  # Host placement: annotation must survive into the exported StableHLO text.
  a = torch.ones(10, device=dev)
  b = place_to_host(a)
  shlo_text = xm.get_stablehlo([b])
  # assertIn gives a clearer failure message than assertTrue(x in y).
  self.assertIn("has_side_effect = true", shlo_text)
  self.assertIn(
      "mhlo.frontend_attributes = {_xla_buffer_placement = \"pinned_host\"}}",
      shlo_text)

  # Device placement: same checks with the "device" placement attribute.
  a = torch.ones(10, device=dev)
  b = place_to_device(a)
  shlo_text = xm.get_stablehlo([b])
  self.assertIn("has_side_effect = true", shlo_text)
  self.assertIn(
      "mhlo.frontend_attributes = {_xla_buffer_placement = \"device\"}}",
      shlo_text)


if __name__ == "__main__":
Expand Down
21 changes: 13 additions & 8 deletions torch_xla/experimental/stablehlo_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def stablehlo_custom_call(args,
backend_config="",
api_version=0,
frontend_attributes=None):
frontend_attributes = frontend_attributes or {}
frontend_attributes = frontend_attributes or {}
res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
output_dtypes, has_side_effect,
backend_config, api_version,
backend_config, api_version,
frontend_attributes)
if len(output_shapes) == 1:
return res[0]
Expand All @@ -35,11 +35,16 @@ def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node):


def place_to_host(a: torch.Tensor):
  """Annotate `a` so XLA places its buffer in pinned host memory.

  Lowers to an `annotate_device_placement` custom call carrying the
  frontend attribute `_xla_buffer_placement = "pinned_host"`. The call is
  marked side-effecting so the compiler does not optimize it away.

  Args:
    a: tensor whose buffer placement is being annotated.

  Returns:
    A tensor with the same shape and dtype as `a`, carrying the
    host-placement annotation.
  """
  return stablehlo_custom_call(
      [a],
      "annotate_device_placement", [a.shape], [a.dtype],
      has_side_effect=True,
      frontend_attributes={"_xla_buffer_placement": "pinned_host"})


def place_to_device(a: torch.Tensor):
  """Annotate `a` so XLA places its buffer in device memory.

  Mirror of `place_to_host`: lowers to an `annotate_device_placement`
  custom call with frontend attribute `_xla_buffer_placement = "device"`,
  marked side-effecting so the compiler does not optimize it away.

  Args:
    a: tensor whose buffer placement is being annotated.

  Returns:
    A tensor with the same shape and dtype as `a`, carrying the
    device-placement annotation.
  """
  return stablehlo_custom_call(
      [a],
      "annotate_device_placement", [a.shape], [a.dtype],
      has_side_effect=True,
      frontend_attributes={"_xla_buffer_placement": "device"})

0 comments on commit 374d4ab

Please sign in to comment.