Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan committed Feb 5, 2024
1 parent 3e68409 commit bc6b14c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,25 @@ def test_tpu_custom_call_pallas_raise(self):
torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
output.cpu()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add_one_dynamo(self):
    """Run a Pallas add-one kernel through torch.compile (openxla backend)."""
    # Importing the module registers the `xla::tpu_custom_call_` op with
    # torch.library so Dynamo can trace it.
    import torch_xla.experimental.custom_kernel

    # This payload is generated by the following Pallas code:
    # def add_vectors_kernel(x_ref, o_ref):
    #   o_ref[...] = x_ref[...] + 1
    payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"

    # Input, the mutated output buffer, and the reference result.
    src = torch.arange(8, dtype=torch.int).to("xla")
    dst = torch.arange(8, dtype=torch.int).to("xla")
    reference = src + 1

    def add_one_pallas(output, inputs, payload):
        # In-place custom call; mutates `output`.
        torch.ops.xla.tpu_custom_call_(output, inputs, payload)

    compiled = torch.compile(add_one_pallas, backend='openxla', fullgraph=True)
    compiled(dst, [src], payload)

    self.assertTrue(torch.allclose(dst.cpu(), reference.cpu()))


class MNISTComparator(nn.Module):

Expand Down
20 changes: 20 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
import torch_xla

from typing import List
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

# Register the op schema with torch.library so it is visible to Dynamo and
# dispatchable by backend key. `Tensor(a!)` marks `output` as mutated
# in place; the op returns nothing.
XLA_LIB.define(
    "tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",
)

@impl(XLA_LIB, "tpu_custom_call_", "XLA")
def tpu_custom_call_xla_(output: torch.Tensor, inputs: List[torch.Tensor],
                         payload: str) -> None:
  """XLA-backend implementation: forwards to the C++ TPU custom-call binding.

  Args:
    output: destination tensor, mutated in place by the custom call.
    inputs: input tensors passed to the kernel.
    payload: serialized custom-call config (presumably the Mosaic/Pallas
      payload JSON seen in the tests — confirm against callers).
  """
  torch_xla._XLAC._xla_tpu_custom_call_(output, inputs, payload)


@impl(XLA_LIB, "tpu_custom_call_", "CompositeExplicitAutograd")
def tpu_custom_call_(output: torch.Tensor, inputs: List[torch.Tensor],
                     payload: str) -> None:
  """Fallback registration for non-XLA tensors: deliberately a no-op.

  The TPU custom call only makes sense on XLA tensors; for any other
  backend key this intentionally leaves `output` untouched.
  """
  # Do nothing for non-xla tensor.
  return

0 comments on commit bc6b14c

Please sign in to comment.