2 questions for the composite op feature #8486

Open
Zantares opened this issue Dec 12, 2024 · 3 comments
Labels: stablehlo StableHLO related work

@Zantares

❓ Questions and Help

I'm glad to see that the composite op feature has been added to Torch-XLA. I have tried this feature and have some questions; I hope to get answers/suggestions here:

  1. Some redundant IR (starting from a custom_call) is not erased after the composite op is created, e.g. for GELU:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class Example(torch.nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        self.gelu = torch.nn.GELU(approximate="none")
        self.composite_op = StableHLOCompositeBuilder("composite.gelu", {"approximate": "none"})

    def forward(self, x):
        x = self.composite_op.mark_inputs(x)
        y = self.gelu(x)
        y = self.composite_op.mark_outputs(y)
        return y

x = torch.randn(10, device=xm.xla_device())
model = Example().to(xm.xla_device())
print(model(x))

input_args = (x, )
exported = torch.export.export(model, input_args)
# print(exported.graph)
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo_text = stablehlo_gm.get_stablehlo_text()  # renamed to avoid shadowing the imported stablehlo module
print(stablehlo_text)

The generated StableHLO is:

module @IrToHlo.16 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %cst = stablehlo.constant dense<0.707106769> : tensor<10xf32>
    %0 = stablehlo.multiply %arg0, %cst : tensor<10xf32>
    %1 = stablehlo.custom_call @mhlo.erf(%0) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<10xf32>) -> tensor<10xf32>
    %2 = stablehlo.composite "composite.gelu" %arg0 {composite_attributes = {approximate = "none"}, decomposition = @composite.gelu.impl} : (tensor<10xf32>) -> tensor<10xf32>
    return %2 : tensor<10xf32>
  }
  func.func private @composite.gelu.impl(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
    %cst_0 = stablehlo.constant dense<0.707106769> : tensor<10xf32>
    %cst_1 = stablehlo.constant dense<5.000000e-01> : tensor<10xf32>
    %0 = stablehlo.multiply %arg0, %cst_1 : tensor<10xf32>
    %1 = stablehlo.multiply %arg0, %cst_0 : tensor<10xf32>
    %2 = stablehlo.custom_call @mhlo.erf(%1) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<10xf32>) -> tensor<10xf32>
    %3 = stablehlo.add %2, %cst : tensor<10xf32>
    %4 = stablehlo.multiply %0, %3 : tensor<10xf32>
    return %4 : tensor<10xf32>
  }
}

The erf op in main is dead code, yet it is not erased. I checked the composite op pass: it leaves these useless ops for a later canonicalizer instead of erasing them directly, but the canonicalizer didn't handle them. I guess this is caused by the custom call being treated as having side effects.

The question: Can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?

  2. The composite op feature doesn't work in training. Even though this feature is currently proposed for inference (it works through the export API), I tried to enable it in training locally and found that it reported a warning:

UserWarning: xla::mark_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /data4/home/luteng/code/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

As a result, the backward graph is not generated.

The question: Is there any plan to support the composite op feature in training? It seems the only missing part is adding Autograd support for mark_tensor, but I'm an XLA developer who is not familiar with PyTorch internals, so I don't know how to add it...
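
For illustration, here is a rough sketch of what I imagine the missing piece could look like: a torch.autograd.Function that performs the marking in forward and passes the gradient through unchanged in backward. This is only a sketch, not the Torch-XLA implementation; the torch.ops.xla.mark_tensor call and its argument list below are assumptions, and the real op schema may differ.

import torch
import torch_xla  # assumption: importing torch_xla registers the xla ops, including mark_tensor

class MarkTensorWithGrad(torch.autograd.Function):
    """Hypothetical wrapper: mark_tensor is semantically an identity on the
    tensor value, so its gradient should simply pass through."""

    @staticmethod
    def forward(ctx, x, name, pos, token_id, is_input, attrs):
        # Hypothetical argument list mirroring how the composite builder marks tensors.
        return torch.ops.xla.mark_tensor(x, name, pos, token_id, is_input, attrs)

    @staticmethod
    def backward(ctx, grad_output):
        # Only the tensor input needs a gradient; the metadata arguments do not.
        return grad_output, None, None, None, None, None

An alternative mentioned in the warning itself is registering a fallthrough for the Autograd key, but as the warning notes that only squashes the message and keeps the previous (possibly silently incorrect) behavior, so an explicit identity gradient like the sketch above seems closer to what training would need.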

lsy323 self-assigned this Dec 13, 2024
lsy323 added the stablehlo StableHLO related work label Dec 13, 2024
@lsy323
Collaborator

lsy323 commented Dec 16, 2024

Hi @Zantares, thank you for reporting the issue!

Can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?

I agree that the erf in main is expected to be removed by DCE; let me try to repro it on my end to investigate.

Regarding the composite op in training, the missing piece may not be adding autograd for mark_tensor. The current StableHLO export flow is integrated with torch.export, and it seems that torch.export doesn't have training support yet.

@Zantares
Author

Thanks for the answer, @lsy323! Based on the reply, we can focus on the 1st question (the redundant ops) in this issue.

For the 2nd question, I have found the ATen op lowering process in Torch-XLA; I'd like to submit a draft PR later to see if it's acceptable.

@Zantares
Author

Hi @lsy323, I added a draft PR #8502 to demonstrate the solution for composite ops in training. I hope to get feedback/suggestions, thanks!
