Draft: Lowering Aten op to composite op instead of small ops #8502

Open
wants to merge 1 commit into master
Conversation

Zantares

This PR addresses the second question in this issue: supporting composite ops in training.

Motivation

Composite ops are beneficial for performance optimization, and we aim to apply them to training as well. According to the response in the issue, the community currently has no plan to extend this to training, so I created this draft PR to demonstrate our intention.

Detail

This PR alters the ATen op lowering process for ops that have no 1:1 mapping to an XLA op: instead of lowering to many small XLA ops, it emits a composite call. Later, during optimization, the composite call can easily be replaced with a custom kernel or decomposed back into small ops.

This is still a draft PR, and only GELU is implemented as an example. If the approach is accepted, here are two possible directions:

  1. Keep both the decomposed-op implementation and the composite-call implementation. Use a new environment setting (e.g. XLA_COMPOSITE_OP) to enable this feature, and add an op-list setting to define which ops can be composed (see the usage sketch after this list).
  2. Only retain the composite-call implementation, since it can easily be decomposed by a StableHLO pass. Users can then control the behavior by turning the decomposition pass on or off.
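
As a rough sketch of how suggestion 1 might look from the user side (a hedged example, not part of this PR): the feature flag and the op list would be read from the environment before tracing. Only XLA_COMPOSITE_OP is proposed above; the op-list variable name and its comma-separated format are assumptions for illustration.

import os

# Hypothetical settings: XLA_COMPOSITE_OP is the flag proposed above,
# XLA_COMPOSITE_OP_LIST is an assumed name for the op-list setting.
os.environ["XLA_COMPOSITE_OP"] = "1"
os.environ["XLA_COMPOSITE_OP_LIST"] = "gelu,gelu_backward"

# Import torch_xla after setting the environment so the flags are picked up,
# then build and run the model as usual; listed ops would be lowered as
# stablehlo.composite calls instead of many small XLA ops.
import torch
import torch_xla.core.xla_model as xm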

Example

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
gelu = torch.nn.GELU(approximate="none")

x = torch.tensor([2.0], requires_grad=True, device=device)
y = gelu(x ** 2)
y.backward()  # with this PR, gelu_backward is lowered as a single composite call

print(x.grad)
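
Before the final print triggers execution, the pending graph can be dumped for inspection. A minimal sketch, assuming torch_xla's internal debugging helper _get_xla_tensors_hlo (an internal API that may change between releases); it prints the HLO text form of the same graph whose StableHLO is shown below:

import torch_xla

# Dump the pending computation that produces x.grad (HLO text form).
print(torch_xla._XLAC._get_xla_tensors_hlo([x.grad]))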

With this PR, the generated StableHLO is:

module @SyncTensorsGraph.43 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @composite.gelu_backward.14(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
    %cst = stablehlo.constant dense<0.398942292> : tensor<1xf32>
    %cst_0 = stablehlo.constant dense<-5.000000e-01> : tensor<1xf32>
    %cst_1 = stablehlo.constant dense<5.000000e-01> : tensor<1xf32>
    %cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
    %cst_3 = stablehlo.constant dense<0.707106769> : tensor<1xf32>
    %0 = stablehlo.multiply %arg1, %cst_3 : tensor<1xf32>
    %1 = stablehlo.custom_call @mhlo.erf(%0) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<1xf32>) -> tensor<1xf32>
    %2 = stablehlo.add %1, %cst_2 : tensor<1xf32>
    %3 = stablehlo.multiply %2, %cst_1 : tensor<1xf32>
    %4 = stablehlo.multiply %arg1, %arg1 : tensor<1xf32>
    %5 = stablehlo.multiply %4, %cst_0 : tensor<1xf32>
    %6 = stablehlo.exponential %5 : tensor<1xf32>
    %7 = stablehlo.multiply %arg1, %6 : tensor<1xf32>
    %8 = stablehlo.multiply %7, %cst : tensor<1xf32>
    %9 = stablehlo.add %3, %8 : tensor<1xf32>
    %10 = stablehlo.multiply %arg0, %9 : tensor<1xf32>
    return %10 : tensor<1xf32>
  }
  func.func @main(%arg0: tensor<f32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<1xf32>
    %0 = stablehlo.power %arg1, %cst_0 : tensor<1xf32>
    %1 = stablehlo.composite "composite.gelu_backward" %cst, %0 {composite_attributes = {approximate = "none"}, decomposition = @composite.gelu_backward.14, version = 1 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
    %2 = stablehlo.power %arg1, %cst : tensor<1xf32>
    %3 = stablehlo.reshape %arg0 : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.multiply %2, %3 : tensor<1xf32>
    %5 = stablehlo.multiply %1, %4 : tensor<1xf32>
    return %5 : tensor<1xf32>
  }
}
