diff --git a/test/test_operations.py b/test/test_operations.py index e67722dfec2..af3bd1d3eac 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1924,6 +1924,25 @@ def test_tpu_custom_call_pallas_raise(self): torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload) output.cpu() + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_tpu_custom_call_pallas_add_one_dynamo(self): + # This payload is generated by the following Pallas code: + # def add_vectors_kernel(x_ref, o_ref): + # o_ref[...] = x_ref[...] + 1 + payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}" + + x = torch.arange(8, dtype=torch.int).to("xla") + expected_output = x + 1 + output = torch.arange(8, dtype=torch.int).to("xla") + + import torch_xla.experimental.custom_kernel + def add_one_pallas(output, inputs, payload): + torch.ops.xla.tpu_custom_call_(output, inputs, payload) + compiled_add_one_pallas = torch.compile(add_one_pallas, backend='openxla', fullgraph=True) + + compiled_add_one_pallas(output, [x], payload) + self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu())) + class MNISTComparator(nn.Module): diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py new file mode 100644 index 00000000000..167cea13cac --- /dev/null +++ b/torch_xla/experimental/custom_kernel.py @@ -0,0 +1,20 @@ +import torch +import torch_xla + +from typing import List +from torch.library import impl +from torch_xla.core.xla_model import XLA_LIB + +XLA_LIB.define( + "tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()", +) + +@impl(XLA_LIB, "tpu_custom_call_", "XLA") +def tpu_custom_call_xla_(output: torch.Tensor, inputs: List[torch.Tensor], payload: str): + torch_xla._XLAC._xla_tpu_custom_call_(output, inputs, payload) + + +@impl(XLA_LIB, "tpu_custom_call_", "CompositeExplicitAutograd") +def tpu_custom_call_(output: torch.Tensor, inputs: List[torch.Tensor], payload: str): + # Do nothing for non-xla tensor. + return