Skip to content

Commit

Permalink
Convert stablehlo.composite ops to func.call ops before converting to…
Browse files Browse the repository at this point in the history
… TTIR so that the TTIR inliner can inline them
  • Loading branch information
LPanosTT committed Oct 24, 2024
1 parent 2571a98 commit 915008c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ add_mlir_library(TTMLIRStableHLOToTTIR
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
StablehloPasses
)
2 changes: 2 additions & 0 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "stablehlo/transforms/Passes.h"
#include "ttmlir/Conversion/Passes.h"

namespace mlir::tt::ttir {
Expand All @@ -20,6 +21,7 @@ void createStableHLOToTTIRPipeline(
if (options.arithDialectConversionsEnabled) {
pm.addPass(createConvertArithToStableHLOPass());
}
pm.addPass(stablehlo::createStablehloLegalizeCompositeToCallPass());
pm.addPass(createConvertStableHLOToTTIRPass());
if (options.removeDeadValuesEnabled) {
pm.addPass(mlir::createRemoveDeadValuesPass());
Expand Down
16 changes: 16 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/composite_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_eltwise_add attributes {} {
func.func private @add_impl(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

func.func public @main(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%results = stablehlo.composite "jit_eltwise_add.my_add" %arg0, %arg1 {
decomposition = @add_impl
} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = call @add_impl [[C:.*]]
return %results : tensor<13x21x3xf32>
}
}

0 comments on commit 915008c

Please sign in to comment.