
Support non-traceable Custom Ops with opaque arguments #7330

Open
tle-huu opened this issue Jun 22, 2024 · 5 comments
tle-huu commented Jun 22, 2024

🚀 Feature

torch_xla.stablehlo supports exporting a custom op to a StableHLO custom call with tensor arguments.
We would like to be able to export custom ops that take an arbitrary opaque string as an argument to StableHLO.

Motivation

Some custom operations come from external C sources and are used through pybindings during inference.

These operations sometimes take POD structures that are not necessarily tensors as arguments, a little like the opaque Descriptor example in the JAX custom-op tutorial.

Such operations can appear at any point in the model; they are usually ([opaque structs]) -> tensors or (tensors) -> [opaque structs], but we could also imagine an op in the middle of a model with a side effect on an opaque external structure.
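For concreteness, here is a minimal sketch of how such an opaque POD descriptor could be packed into a byte string, similar in spirit to the JAX Descriptor example. The field names and layout (`rows`, `cols`, `scale`) are invented for illustration; only Python's standard `struct` module is used.

```python
import struct

def pack_descriptor(rows: int, cols: int, scale: float) -> bytes:
    # "<iif" = little-endian: two int32 fields followed by a float32.
    return struct.pack("<iif", rows, cols, scale)

def unpack_descriptor(data: bytes):
    return struct.unpack("<iif", data)

opaque = pack_descriptor(3, 3, 0.5)
print(len(opaque))               # 12 bytes: 4 + 4 + 4
print(unpack_descriptor(opaque)) # (3, 3, 0.5)
```

On the C side, the same bytes can be reinterpreted as the corresponding POD struct without any tensor machinery.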

Pitch

Here is example PyTorch code and what the HLO could potentially look like.

The idea is to be able to declare some arguments as "external" so that the export keeps them in the outer function and annotates them with attributes, which would be used downstream to lower them to opaque pointers and sizes.

My example is based on #7017

@impl(m, "custom_op_external", "XLA")
def custom_op_external_xla(external_input, x):
    # Emit a StableHLO custom call; argument order follows #7017:
    # (args, call_target, output_shapes, output_dtypes,
    #  has_side_effect, backend_config, api_version)
    res = stablehlo_custom_call((external_input, x), "custom_op_external",
                                [(external_input.shape[0],), x.shape[1:]],
                                [torch.int8, torch.int8], True,
                                "backend_config", 1)
    return res
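For completeness, the snippet above assumes a library handle `m` and an op declaration along these lines; the exact schema string is my assumption, following the usual `torch.library` pattern:

```python
import torch

# Hypothetical declaration of the custom op used above; the schema
# string is an assumption for illustration.
m = torch.library.Library("my_custom_library", "DEF")
m.define("custom_op_external(Tensor external_input, Tensor x) -> Tensor")
```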

class M(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.external = torch.empty(32)

  def forward(self, x):
    x = torch.sin(x)
    x = torch.ops.my_custom_library.custom_op_external(self.external, x)
    x = x + 1
    return x

ep = torch.export.export(M(), (torch.randn(3, 3), ))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
module @IrToHlo.10 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<32xi8> {external: true}) -> tensor<3xi8> {
    %c = stablehlo.constant dense<1> : tensor<3xi8>
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @custom_op_external(%arg1, %0) {backend_config = "backend_config", has_side_effect = true} : (tensor<32xi8>, tensor<3x3xf32>) -> tensor<3xi8>
    %2 = stablehlo.add %1, %c : tensor<3xi8>
    return %2 : tensor<3xi8>
  }
}

@JackCaoG
Collaborator

@qihqi @lsy323

@qihqi
Collaborator

qihqi commented Jun 25, 2024

Currently I don't think you can register a custom op with torch using types that are not defined in native_functions.yml. That said, they do have str as a supported type.

Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.
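The suggestion above can be sketched as a simple round trip (the payload is a hypothetical opaque byte string): the bytes go into a plain uint8 tensor and come back out unchanged.

```python
import torch

payload = b"opaque-descriptor"            # hypothetical opaque POD bytes
t = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
recovered = bytes(t.tolist())             # read the tensor back as bytes
print(recovered == payload)  # True
```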

@ManfeiBai
Collaborator

Hi @qihqi, is it OK to assign this ticket to you?

@tle-huu
Author

tle-huu commented Jun 25, 2024

Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.

We would like to expose already tested and optimized implementations in C that do not necessarily take tensors.

We could definitely use a "Tensor" object to hold opaque/random bytes and reinterpret-cast in the implementation (and that is part of the idea), but to know whether such a tensor holds an actual tensor or an opaque string, it needs to be annotated somehow when we export from torch to HLO. The annotations can then be used to lower the custom-call arguments accordingly to opaque pointers.

We could maybe use #7046 to introduce annotations, but I was thinking we could perhaps find a more generic solution.
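As a sketch of that reinterpret-cast idea (the two-int32 layout is invented for illustration), the implementation side can decode the uint8 tensor's bytes back into the struct fields:

```python
import struct
import torch

raw = struct.pack("<ii", 7, 42)   # pretend this is an opaque POD struct
t = torch.frombuffer(bytearray(raw), dtype=torch.uint8)  # "tensor" carrier
a, b = struct.unpack("<ii", bytes(t.tolist()))           # reinterpret-cast
print(a, b)  # 7 42
```

Without an annotation on the argument, nothing in the exported program distinguishes this carrier tensor from a real tensor, which is the gap this issue is about.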

@ManfeiBai
Collaborator

Hi @lsy323, is it OK to assign this ticket to you as well, since @qihqi is OOO now?
