Skip to content

Commit

Permalink
Add conversion pass for Arith ConstantOp (#953)
Browse files Browse the repository at this point in the history
* Add pass to convert arith Dialect ops to stablehloDialect ops starting with ConstantOp. Add test for arith.constant to ttir.constant conversion.
  • Loading branch information
uazizTT authored Oct 22, 2024
1 parent d6d6aea commit 684d818
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H
#define TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::tt {

#ifdef TTMLIR_ENABLE_STABLEHLO
std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass();
#endif

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_STABLEHLOTOTTIR_H
1 change: 1 addition & 0 deletions include/ttmlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_CONVERSION_PASSES_H

#ifdef TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#endif
#include "ttmlir/Conversion/TTIRToTTMetal/TTIRToTTMetal.h"
Expand Down
5 changes: 5 additions & 0 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ let summary = "Convert StableHLO dialect to TTIR dialect.";
let constructor = "createConvertStableHLOToTTIRPass()";
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::tt::ttir::TTIRDialect"];
}
def ConvertArithToStableHLO : Pass<"convert-arith-to-stablehlo", "::mlir::ModuleOp"> {
let summary = "Convert Arith Dialect to StableHLO dialect.";
let constructor = "createConvertArithToStableHLOPass()";
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::arith::ArithDialect"];
}
#endif

def ConvertTosaToTTIR : Pass<"convert-tosa-to-ttir", "::mlir::ModuleOp"> {
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ struct StableHLOToTTIRPipelineOptions
// Currently this pass fails if module has a name, so keeping the
// optimization OFF by default until that issue is fixed on llvm side.
llvm::cl::init(false)};
Option<bool> arithDialectConversionsEnabled{
*this, "enable-arith-to-stablehlo",
llvm::cl::desc("Enable Arith to StableHLO conversion pass."),
// Currently torch-mlir front-end does not convert ConstantOp for Arith
// Dialect to StableHLO. This pass makes those conversions until this
// is fixed in the upstream torch-mlir.
llvm::cl::init(true)};
};

void createStableHLOToTTIRPipeline(
Expand Down
94 changes: 94 additions & 0 deletions lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h"

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>

#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

using namespace mlir;
using namespace mlir::tt;

namespace mlir::tt::ttir {

#define GEN_PASS_DEF_CONVERTARITHTOSTABLEHLO
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir

namespace {

class ArithToStableHLOConstantOpConversionPattern
: public OpConversionPattern<mlir::arith::ConstantOp> {

using OpConversionPattern<mlir::arith::ConstantOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::arith::ConstantOp srcOp,
mlir::arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(srcOp,
srcOp.getValue());
return success();
}
};

struct ConvertArithToStableHLOPass
: public ttir::impl::ConvertArithToStableHLOBase<
ConvertArithToStableHLOPass> {
void runOnOperation() final {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<mlir::stablehlo::StablehloDialect>();
target.addLegalOp<mlir::tensor::EmptyOp>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
target.addLegalOp<mlir::func::ReturnOp>();

// For now keep the same type assuming StableHLO ops operate on builtin
// tensor.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) {
assert(isa<RankedTensorType>(type) &&
"only ranked tensor type supported");
return type;
});
RewritePatternSet patterns(&getContext());

// Convert Arith ConstantOp to StableHLO ConstantOp
patterns.add<ArithToStableHLOConstantOpConversionPattern>(typeConverter,
&getContext());

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace mlir::tt {

std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass() {
return std::make_unique<ConvertArithToStableHLOPass>();
}

} // namespace mlir::tt
1 change: 1 addition & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
add_mlir_library(TTMLIRStableHLOToTTIR
StableHLOToTTIRPatterns.cpp
StableHLOToTTIRPass.cpp
ArithToStableHLOPass.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"

#include <llvm/ADT/ArrayRef.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace mlir::tt::ttir {
#ifdef TTMLIR_ENABLE_STABLEHLO
void createStableHLOToTTIRPipeline(
OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options) {
if (options.arithDialectConversionsEnabled) {
pm.addPass(createConvertArithToStableHLOPass());
}
pm.addPass(createConvertStableHLOToTTIRPass());
if (options.removeDeadValuesEnabled) {
pm.addPass(mlir::createRemoveDeadValuesPass());
Expand Down
15 changes: 15 additions & 0 deletions test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_constant attributes {} {
func.func public @test_splat() -> tensor<64xf32> {
%0 = arith.constant dense<0.3> : tensor<64xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
return %0 : tensor<64xf32>
}

func.func public @test_multiple() -> tensor<2x2xf32> {
%0 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
return %0 : tensor<2x2xf32>
}
}

0 comments on commit 684d818

Please sign in to comment.