
Add conversion pass for Arith ConstantOp #953

Merged (6 commits) on Oct 22, 2024

Conversation

@uazizTT (Contributor) commented Oct 21, 2024

Add an optional pass to convert arith.constant to stablehlo.constant

This is needed to run examples in tt-torch until it is fixed upstream.
https://github.com/AleksKnezevic/tt-torch/issues/1

@nsmithtt (Contributor)

Can you cut and paste an example in the commit message / as a comment on this PR? So torch-mlir is mixing arith and stablehlo together? I suppose xla doesn't have any cases like this so far.

@uazizTT (Contributor, Author) commented Oct 21, 2024

> Can you cut and paste an example in the commit message / as a comment on this PR? So torch-mlir is mixing arith and stablehlo together? I suppose xla doesn't have any cases like this so far.

Yes, it seems torch-mlir converts the Torch dialect to StableHLO, but ops from native MLIR dialects, e.g. arith, are not converted.

I referenced the issue, but tt-torch is a private repo so not everyone can see it; posting the example here for reference:

```mlir
module {
  func.func @main(%arg0: tensor<32x32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %cst = arith.constant dense<1> : tensor<1xi64>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
    %2 = stablehlo.convert %cst : (tensor<1xi64>) -> tensor<1xf32>
    %3 = stablehlo.reshape %2 : (tensor<1xf32>) -> tensor<f32>
    %4 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %5 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<f32>) -> tensor<32x32xf32>
    %6 = stablehlo.multiply %4, %5 : tensor<32x32xf32>
    %7 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %8 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<f32>) -> tensor<32xf32>
    %9 = stablehlo.multiply %7, %8 : tensor<32xf32>
    %10 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %11 = stablehlo.broadcast_in_dim %9, dims = [1] : (tensor<32xf32>) -> tensor<32x32xf32>
    %12 = stablehlo.add %10, %11 : tensor<32x32xf32>
    return %12 : tensor<32x32xf32>
  }
}
```
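For clarity, the proposed optional pass only needs to rewrite the single arith.constant above into the equivalent stablehlo.constant, leaving the rest of the module untouched (an illustrative sketch, not output from the actual pass):

```mlir
// Before the optional conversion pass:
%cst = arith.constant dense<1> : tensor<1xi64>

// After the pass: same value attribute, same result type.
%cst = stablehlo.constant dense<1> : tensor<1xi64>
```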

@uazizTT uazizTT force-pushed the uaziz/arith-constant-conversion branch from 0943350 to 8d2f32a Compare October 21, 2024 15:42
@nsmithtt (Contributor)

I'm wondering if we should support arith -> ttir conversion as part of the stablehlo path and just go straight from arith.constant to ttir.constant.

My thinking is that presumably linalg/tosa might emit the same thing and they'd also need arith conversion. Can you try emitting those dialects from torch-mlir to see what the resulting IR is?

@uazizTT (Contributor, Author) commented Oct 21, 2024

> I'm wondering if we should support arith -> ttir conversion as part of the stablehlo path and just go straight from arith.constant to ttir.constant.

The reason I chose to first convert arith.ConstantOp to stablehlo.ConstantOp is that we apply custom type conversion for the element type and shape when converting stablehlo.ConstantOp to ttir.ConstantOp. Keeping the arith-to-stablehlo conversion lightweight avoids replicating that custom type conversion in both places.
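For illustration, such a thin conversion could be a single rewrite pattern along these lines. This is a pseudocode-style sketch against the MLIR/StableHLO C++ APIs; the struct name and details are assumptions, not code from this PR, and it is not standalone-compilable:

```cpp
// Hypothetical sketch: swap arith.constant for stablehlo.constant,
// reusing the same value attribute so no type conversion happens here.
// The custom element-type/shape conversion stays in the existing
// stablehlo -> ttir lowering.
struct ArithConstantToStableHLO
    : public mlir::OpRewritePattern<mlir::arith::ConstantOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::arith::ConstantOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only tensor-valued constants are relevant for the StableHLO path.
    auto value = llvm::dyn_cast<mlir::ElementsAttr>(op.getValue());
    if (!value)
      return mlir::failure();
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, value);
    return mlir::success();
  }
};
```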

There is an open PR with more changes here:
#802

My thinking is that presumably linalg/tosa might emit the same thing and they'd also need arith conversion. Can you try emitting those dialects from torch-mlir to see what the resulting IR is?

Ok I will try a few examples.

@uazizTT (Contributor, Author) commented Oct 21, 2024

> My thinking is that presumably linalg/tosa might emit the same thing and they'd also need arith conversion. Can you try emitting those dialects from torch-mlir to see what the resulting IR is?

Ok, I ran a few examples and found that:

  • For TorchToTosa, no Arith dialect constants remain in the few examples I ran; some constants are converted to tosa.const and others remain as torch.constant, e.g. %int3 = torch.constant.int 3.

  • For TorchToLinalg, the output is similar to StableHLO: arith.constant is not converted.
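To make the TorchToTosa observation concrete, the split looks roughly like this (illustrative only; apart from the torch.constant.int line quoted above, these lines are assumptions, not actual torch-mlir output):

```mlir
// Some constants become tosa dialect constants...
%0 = "tosa.const"() {value = dense<3.0> : tensor<f32>} : () -> tensor<f32>
// ...while others remain in the torch dialect:
%int3 = torch.constant.int 3
```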

@AleksKnezevic (Contributor)

I had a slight preference for converting straight to ttir, but since torch-mlir only emits arith.constant (and no other arith ops), and since our constant op is so complicated as is, I think we'll have fewer issues overall if we simply do a thin conversion from arith to stablehlo/linalg and continue with the default path, as @uazizTT did in the PR.

@nsmithtt (Contributor)

> I had a slight preference for converting straight to ttir, but since torch-mlir only emits arith.constant (and no other arith ops), and since our constant op is so complicated as is, I think we'll have fewer issues overall if we simply do a thin conversion from arith to stablehlo/linalg and continue with the default path, as @uazizTT did in the PR.

Sounds good, let's move forward with this approach. We can just duplicate it for the linalg path; it's pretty contained since it's just arith.constant.

@uazizTT uazizTT force-pushed the uaziz/arith-constant-conversion branch from 3f40604 to 030040d Compare October 22, 2024 14:09
@uazizTT uazizTT merged commit 684d818 into main Oct 22, 2024
13 checks passed
3 participants