Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple StableHLO Composite outputs #6295

Merged
merged 6 commits into from
Feb 4, 2024

Conversation

chunnienc
Copy link
Collaborator

@chunnienc chunnienc commented Jan 11, 2024

This PR adds the support for marking multiple tensors in StableHLOCompositeBuilder.mark_outputs.

Example:

class Model(torch.nn.Module):

      def __init__(self):
        super().__init__()

      def forward(self, x, y):
        builder = StableHLOCompositeBuilder("sample_composite")
        x, y = builder.mark_inputs(x, y)
        a = x + y
        b = x - y
        c = x + 1
        a, b, c = builder.mark_outputs(a, b, c)
        return a + b + c

SHLO:

module @IrToHlo.42 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> tensor<5x5xf32> {
    %0:3 = stablehlo.custom_call @stablehlo.composite(%arg0, %arg1) {called_computations = [@sample_composite.impl], composite.backend_config = {attributes = {}, name = "sample_composite"}} : (tensor<5x5xf32>, tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>)
    %1 = stablehlo.add %0#0, %0#1 : tensor<5x5xf32>
    %2 = stablehlo.add %1, %0#2 : tensor<5x5xf32>
    return %2 : tensor<5x5xf32>
  }
  func.func private @sample_composite.impl(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>) {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<5x5xf32>
    %1 = stablehlo.add %arg0, %arg1 : tensor<5x5xf32>
    %2 = stablehlo.subtract %arg0, %arg1 : tensor<5x5xf32>
    %3 = stablehlo.add %arg0, %0 : tensor<5x5xf32>
    return %1, %2, %3 : tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>
  }
}

@chunnienc chunnienc requested a review from lsy323 January 11, 2024 08:15
@miladm
Copy link
Collaborator

miladm commented Jan 11, 2024

Suggesting to include references to issues on public URLs on GH. Thanks very much.

@chunnienc chunnienc marked this pull request as ready for review February 1, 2024 06:17
@chunnienc chunnienc requested a review from qihqi February 1, 2024 18:09
Copy link
Collaborator

@lsy323 lsy323 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, left a few comments/questions. Thanks @chunnienc!

@chunnienc chunnienc merged commit 535d398 into master Feb 4, 2024
18 checks passed
amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants