❓ Questions and Help
Glad to see that the composite op feature has been added to Torch-XLA. I have tried the feature and have some questions; I hope to get answers/suggestions here:
Some redundant IR ops (starting from custom_call) can't be erased after the composite op is created, e.g. Gelu (a repro sketch follows below):
The erf op in main is useless but is not erased. I have checked the composite op pass: it leaves these useless ops to a later canonicalizer instead of erasing them directly, but the canonicalizer didn't handle them... I guess this is caused by the custom call's side effect.
The question: Can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?
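For context, here is a minimal sketch of how I produce the Gelu composite, assuming the boundary is marked with the StableHLOCompositeBuilder helper from torch_xla.experimental.mark_pattern_utils and exported via torch_xla.stablehlo.exported_program_to_stablehlo; the exact helper names may vary across torch_xla versions:

```python
import torch
import torch.nn.functional as F
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
from torch_xla.stablehlo import exported_program_to_stablehlo


class GeluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Everything between mark_inputs and mark_outputs should be outlined
        # into a stablehlo.composite op named "test.gelu".
        self.builder = StableHLOCompositeBuilder("test.gelu", {"approximate": "none"})

    def forward(self, x):
        x = self.builder.mark_inputs(x)
        y = F.gelu(x)
        return self.builder.mark_outputs(y)


exported = torch.export.export(GeluModel(), (torch.randn(4, 8),))
shlo = exported_program_to_stablehlo(exported)
# The printed module is where the leftover erf call in @main shows up.
print(shlo.get_stablehlo_text())
```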
The composite op feature doesn't work in training. Even though the proposal for this feature currently targets inference (it works with the export API), I tried to enable it in training locally and found that it reported a warning:
UserWarning: xla::mark_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /data4/home/luteng/code/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Then the backward graph is not generated.
The question: Is there any plan to support the composite op feature in training? It seems the only missing part is adding Autograd support for mark_tensor, but I'm an XLA developer and not familiar with PyTorch internals, so I don't know how to add it...
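For illustration, a rough (untested) sketch of the direction I imagine for "adding Autograd for mark_tensor": wrap the marker op in a torch.autograd.Function whose backward is an identity, so gradients flow through the marker. The wrapper and its call site are hypothetical; the real fix would presumably have to live inside torch_xla's marker utilities.

```python
import torch


class MarkTensorIdentity(torch.autograd.Function):
    """Hypothetical wrapper that makes xla::mark_tensor transparent to autograd."""

    @staticmethod
    def forward(ctx, x, *marker_args):
        # marker_args stands in for whatever metadata arguments mark_tensor
        # actually takes (name, position, id, ...); remember how many there
        # are so backward can return a matching number of None gradients.
        ctx.num_marker_args = len(marker_args)
        return torch.ops.xla.mark_tensor(x, *marker_args)

    @staticmethod
    def backward(ctx, grad_out):
        # The marker doesn't change values, so the gradient passes through
        # unchanged; the metadata arguments get no gradient.
        return (grad_out,) + (None,) * ctx.num_marker_args
```

The composite builder would then need to call MarkTensorIdentity.apply(...) instead of torch.ops.xla.mark_tensor(...) directly.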
Can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?
I agree that the erf in main should be removed by DCE; let me try to repro it on my end and investigate.
Regarding composite ops in training, the missing piece may not be adding autograd for mark_tensor. The current StableHLO export flow is integrated with torch.export, and it seems that torch.export doesn't have training support yet.